ML Case-study Interview Question: Multi-Modal, Multi-Task Learning for Hierarchical E-commerce Product Classification
Case-Study question
A fast-growing e-commerce platform hosts millions of sellers who list billions of products across many categories. Each product has unstructured text data (titles, descriptions, tags) in multiple languages and one or more images. The goal is to automatically classify these products into a hierarchical taxonomy with over 5,000 leaves spanning seven levels. The existing classification system uses only text, struggles with low coverage, and ignores non-English items. Design a robust solution that:
Handles multi-lingual text and image inputs.
Achieves high coverage and high accuracy.
Supports near real-time predictions at scale.
Produces hierarchical outputs without violating parent-child relationships.
Addresses class imbalance and sensitive category errors.
Explain how you would build, train, and deploy this system. Describe your proposed model architecture, data ingestion pipeline, feature engineering, training strategy, inference pipeline, and confidence thresholding approach.
Detailed Solution
Model Architecture
A multi-task, multi-class framework fits this problem. Convert text features into embeddings with a pre-trained multi-lingual model, and extract image features with a convolutional network pre-trained on a large visual dataset. Concatenate the two embeddings and pass them through shared dense layers. Maintain a separate multi-class output layer for each level of the taxonomy, and feed each parent level's predictions forward as an extra input to the child level's output layer.
This structure helps share information across levels. It also allows child-level errors to inform parent-level weights during backpropagation. The model can generate raw logits for each level without hard constraints on parent-child consistency. Enforce consistent hierarchical paths only during inference.
Core Training Objective
Class imbalance is addressed with weighted cross-entropy. Let y_{i,c} be the true label indicator, \hat{y}_{i,c} the predicted probability for class c of sample i, and w_c the class weight for c. The loss for each level is
$$L = -\frac{1}{N} \sum_{i=1}^{N} \sum_{c=1}^{C} w_c \, y_{i,c} \log \hat{y}_{i,c}$$
where N is the batch size, C is the number of classes at that level, and w_c is higher for under-represented classes. All level-specific losses are added and backpropagated together.
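As a minimal sketch of this loss, the plain-Python functions below compute the weighted cross-entropy for one taxonomy level and sum it across levels. The function names and the inverse-frequency weighting scheme are illustrative assumptions; the text above only requires that w_c be higher for under-represented classes.

```python
import math
from collections import Counter

def inverse_frequency_weights(labels, num_classes):
    """Assumed weighting scheme: w_c = N / (C * count_c), so rare classes weigh more."""
    counts = Counter(labels)
    n = len(labels)
    # Classes absent from `labels` get weight 0.0 because they contribute no loss terms.
    return [n / (num_classes * counts[c]) if counts[c] else 0.0 for c in range(num_classes)]

def weighted_cross_entropy(probs, labels, weights, eps=1e-12):
    """Batch mean of -w_c * log(p_c), where c is each sample's true class index."""
    total = 0.0
    for p, y in zip(probs, labels):
        # With one-hot targets, only the true class term of the inner sum survives.
        total += -weights[y] * math.log(max(p[y], eps))
    return total / len(labels)

def multi_level_loss(probs_per_level, labels_per_level, weights_per_level):
    """Sum of the per-level losses, one term per taxonomy level."""
    return sum(
        weighted_cross_entropy(p, y, w)
        for p, y, w in zip(probs_per_level, labels_per_level, weights_per_level)
    )
```

With labels [0, 0, 0, 1] and two classes, this scheme gives the rare class a weight of 2.0 and the common class 2/3, so each class contributes equally to the batch loss.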
Data Parallel Training
Training uses large volumes of data. Split the dataset across multiple machines. Each machine processes a chunk in parallel. Synchronize gradients after each mini-batch. This reduces training time significantly while preserving model quality.
A simple high-level snippet (TensorFlow-style):
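import tensorflow as tf

# MirroredStrategy replicates the model across the GPUs of one machine;
# tf.distribute.MultiWorkerMirroredStrategy is the multi-machine counterpart.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    # build_hierarchical_model and custom_weighted_loss are placeholders defined elsewhere.
    model = build_hierarchical_model()  # define multi-task architecture
    optimizer = tf.keras.optimizers.Adam()
    model.compile(optimizer=optimizer, loss=custom_weighted_loss)

def distributed_train(dataset):
    # NUM_EPOCHS and STEPS are placeholder training settings.
    model.fit(dataset, epochs=NUM_EPOCHS, steps_per_epoch=STEPS, verbose=1)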
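To illustrate why synchronizing gradients after each mini-batch preserves model quality, the toy sketch below shows one synchronous data-parallel step in plain Python. The one-parameter linear model and the function names are assumptions chosen for illustration only. With equal-sized shards, averaging the per-worker gradients reproduces the full-batch gradient.

```python
def mse_gradient(w, batch):
    """Gradient of mean squared error for the 1-D model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)

def data_parallel_step(w, batch, num_workers, lr):
    """One synchronous step: shard the batch, compute local gradients, average, update.

    Assumes len(batch) is divisible by num_workers so every shard has the same size.
    """
    shard_size = len(batch) // num_workers
    shards = [batch[i * shard_size:(i + 1) * shard_size] for i in range(num_workers)]
    # In a real cluster each worker computes its gradient concurrently.
    local_grads = [mse_gradient(w, shard) for shard in shards]
    # Averaging the local gradients plays the role of the all-reduce step.
    avg_grad = sum(local_grads) / num_workers
    return w - lr * avg_grad
```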
Inference and Hierarchical Consistency
Obtain raw predictions for each level. Select the top class at Level 1. Restrict Level 2 choices to the children of that selected Level 1 class. Repeat down the hierarchy. This ensures outputs remain valid and consistent. Apply thresholds at each level to discard low-confidence predictions.
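The top-down procedure above can be sketched in plain Python as follows. The function name, the dictionary-based data layout, and the choice to stop at the deepest confident ancestor (rather than discard the whole prediction) are assumptions made for illustration.

```python
def decode_path(level_probs, children, roots, thresholds):
    """Greedy top-down decoding restricted to children of the chosen parent.

    level_probs: list of dicts, one per level, mapping class name -> probability.
    children:    dict mapping a class name to the list of its child class names.
    roots:       list of valid Level 1 class names.
    thresholds:  list of minimum confidences, one per level.
    Returns a list of (class name, confidence) pairs forming a valid path.
    """
    path = []
    candidates = roots
    for probs, threshold in zip(level_probs, thresholds):
        if not candidates:
            break  # the chosen node is a leaf above the deepest level
        # Only classes under the selected parent are eligible at this level.
        best = max(candidates, key=lambda c: probs.get(c, 0.0))
        confidence = probs.get(best, 0.0)
        if confidence < threshold:
            break  # keep the confident ancestors, drop the uncertain tail
        path.append((best, confidence))
        candidates = children.get(best, [])
    return path
```

Because each level's candidates come only from the previous choice, a high-scoring class under a different parent can never enter the path, which is what keeps the output consistent with the taxonomy.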
Coverage vs Accuracy Trade-offs
Higher confidence thresholds yield higher precision and lower coverage. Keep thresholds that minimize merchant friction. Some categories, like religious or ceremonial items, need extra care. Tailor thresholds to ensure minimal misclassification in sensitive areas.
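One way to pick a threshold is to scan a labeled validation set for the lowest confidence cutoff that still meets a target precision, which maximizes coverage at that precision. The sketch below assumes such a validation set exists; the function name and return format are hypothetical.

```python
def pick_threshold(confidences, correct, target_precision):
    """Return the lowest threshold whose accepted predictions meet the target precision.

    confidences: model confidence for each validation item.
    correct:     parallel list of booleans, True when the prediction was right.
    Items with confidence >= threshold count as accepted.
    Returns (threshold, precision, coverage), or None if no threshold qualifies.
    """
    ranked = sorted(zip(confidences, correct), key=lambda pair: pair[0], reverse=True)
    best = None
    hits = 0
    for i, (conf, ok) in enumerate(ranked, start=1):
        hits += ok
        # Evaluate only at the last item of a run of tied confidences.
        if i < len(ranked) and ranked[i][0] == conf:
            continue
        precision = hits / i
        if precision >= target_precision:
            # Later (lower) qualifying thresholds overwrite earlier ones.
            best = (conf, precision, i / len(ranked))
    return best
```

Running this separately for each sensitive branch with a higher target precision yields the tailored thresholds described above.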
Practical Implementation Details
Use streaming data pipelines to capture new products, and batch them into a distributed training pipeline at set intervals. Integrate real-time inference endpoints to classify products on creation. Store the results, and reclassify a product when new merchant data changes its confidence. Leverage store-level signals if patterns emerge (for instance, a store name that strongly suggests a domain).
Follow-up question: Handling Multi-Lingual Text Variations
Explain how you would manage text features in languages with fewer training samples, especially for categories with sparse coverage.
Answer
Use a multi-lingual encoder trained on many languages. This helps learn common semantic structures. Augment with domain-specific synonyms. Employ class weighting for low-resource categories. If certain languages are underrepresented, fine-tune the encoder with targeted examples. Gather more labeled data via crowd-sourcing or merchant feedback loops. Include store-level metadata, like location or domain keywords, which might hint at a region-specific product.
Follow-up question: Mitigating Over-Confidence and Sensitive Mislabelling
How would you reduce the chances of confidently misclassifying a sensitive item?
Answer
Implement separate, stricter thresholds for sensitive branches. For categories such as religious, ceremonial, or medical items, raise the minimum confidence bar. Maintain a fallback mechanism that flags uncertain predictions for manual review. If a product is flagged, store teams or external reviewers can label it. Feed that label back into the training set to improve the model. Monitor these categories closely to ensure that the final system's coverage/precision balance aligns with business and legal requirements.
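A minimal sketch of per-branch thresholds with a manual-review fallback might look like the following. The category names, threshold values, and function name are all hypothetical placeholders; real values would come from validation data and business requirements.

```python
DEFAULT_THRESHOLD = 0.80  # assumed bar for ordinary branches
SENSITIVE_THRESHOLDS = {  # assumed higher bars for sensitive branches
    "Religious & Ceremonial": 0.95,
    "Medical": 0.97,
}

def route_prediction(path, confidence):
    """Accept a prediction or send it to manual review.

    path: list of category names from Level 1 down to the predicted node.
    The highest threshold of any sensitive category on the path applies.
    """
    required = max(
        [DEFAULT_THRESHOLD]
        + [SENSITIVE_THRESHOLDS[c] for c in path if c in SENSITIVE_THRESHOLDS]
    )
    return "accept" if confidence >= required else "manual_review"
```

Items routed to "manual_review" are the ones whose reviewer labels get fed back into the training set.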
Follow-up question: Ensuring Real-Time Inference at Scale
How would you serve this model to achieve sub-second latency for large traffic?
Answer
Export the model as a single graph that includes the hierarchical post-processing logic. Host it on a scalable inference platform. Use GPUs or specialized accelerators with an autoscaling setup. Cache the text and image embeddings of each item so that an update recomputes only the modality that changed: if only the text changes, re-embed the text and reuse the cached image embedding. Distribute requests across multiple inference nodes. Monitor latency metrics, and optimize the pipeline (batching, concurrency tuning, efficient data transfer protocols) until the entire path stays under the latency target.
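The embedding-reuse idea can be sketched with a content-addressed cache, shown below in plain Python. The class and function names are hypothetical, the encoders are stand-in callables, and the in-memory dictionary is an assumption; a production system would more likely use a shared external store.

```python
import hashlib

class EmbeddingCache:
    """Caches embeddings by content hash so unchanged inputs are not re-encoded."""

    def __init__(self, encoder):
        self._encoder = encoder  # any callable mapping bytes -> list of floats
        self._store = {}
        self.misses = 0  # counts how often the encoder actually ran

    def get(self, content: bytes):
        key = hashlib.sha256(content).hexdigest()
        if key not in self._store:
            self.misses += 1
            self._store[key] = self._encoder(content)
        return self._store[key]

def embed_product(title: str, image_bytes: bytes, text_cache, image_cache):
    """Return the concatenated text and image embeddings for one product."""
    return text_cache.get(title.encode("utf-8")) + image_cache.get(image_bytes)
```

When a merchant edits only the title, the image hash is unchanged, so the more expensive image encoder is skipped and only the text is re-embedded.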