ML Interview Q Series: How can you adjust output probabilities after training on a downsampled imbalanced binary classification dataset?
📚 Browse the full ML Interview series here.
Comprehensive Explanation
One frequent approach to dealing with highly imbalanced data (such as 0.2% positive vs. 99.8% negative) is to downsample the majority class. This effectively rebalances the dataset in the training phase. However, because the classifier is trained on a dataset that does not mirror the actual class distribution of the real population, its predicted probabilities will be systematically biased. Specifically, the downsampled model will overestimate the probability of belonging to the minority class.
To correct for this discrepancy, we can apply a probability adjustment step after the model outputs its initial estimate. The main idea is to use knowledge about the true prevalence of each class in the original population and adjust the model’s predicted probabilities accordingly.
The Probability Adjustment Formula
When we train a model on downsampled data, the fraction of positives in the training set (call this fraction q) differs from the true fraction of positives in the real population (call this p). If the model outputs a probability \hat{p}_{model}(Y=1|X) based on the downsampled distribution, we need to rescale it to reflect the true distribution, which has proportion p of positives.
One commonly referenced formula to rebalance the probabilities is derived from Bayes’ Theorem, where we treat the classifier’s output as if it were modeling a distribution with different priors. A concise expression for the corrected probability can be written as:

\hat{p}_{corrected}(Y=1|X) = \frac{\hat{p}_{model}(Y=1|X) \cdot r}{\hat{p}_{model}(Y=1|X) \cdot r + \left(1 - \hat{p}_{model}(Y=1|X)\right)}, \quad r = \frac{p/(1-p)}{q/(1-q)}

Where:
• \hat{p}_{model}(Y=1|X) is the probability of the positive class predicted by the model trained on the downsampled data.
• p is the actual fraction of positives in the full population (0.2%, or 0.002, in our example).
• q is the fraction of positives in the downsampled training set. After keeping only 1% of the majority class while retaining all minority instances, q is significantly larger than 0.002 (roughly 0.167).
• p/(1-p) is the odds of the positive class in the actual population.
• q/(1-q) is the odds of the positive class in the downsampled training set.
This formula effectively scales the model’s predicted odds by the ratio of the “true odds” to the “training-set odds,” which corresponds to a constant shift in log-odds space. In probability space the effect is uneven: predictions very close to 0 or 1 barely move in absolute terms, while predictions in between are pulled down sharply toward the true base rate to correct for the inflated minority proportion in the training set.
Detailed Breakdown of Terms
• \hat{p}_{model}(Y=1|X) is the raw probability estimated from the model trained on the downsampled data.
• p is the true fraction of positives in the real-world population.
• (1-p) is the true fraction of negatives in the real-world population.
• q is the fraction of positives in the downsampled training set.
• (1-q) is the fraction of negatives in the downsampled training set.
• p/(1-p) is the odds of being positive vs. negative in the true population.
• q/(1-q) is the corresponding odds in the training set.
By multiplying \hat{p}_{model}(Y=1|X) by the factor (p/(1-p)) / (q/(1-q)) and renormalizing, we adjust for how much more or less likely the training set is to produce a positive compared to the real population.
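As a quick worked example with the numbers used throughout this article: with p = 0.002 and q = 0.167, the correction factor is (0.002/0.998) / (0.167/0.833), which is approximately 0.01. A raw model output of 0.9 therefore becomes (0.9 × 0.01) / (0.9 × 0.01 + 0.1), about 0.083, and a raw output of 0.5 becomes roughly 0.01, a large downward shift that reflects how rare positives actually are in the full population.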
Implementation Example in Python
Below is a Python snippet demonstrating how you might implement this probability correction post-processing step on your predictions. We assume you already have:
• model_probs = an array of predicted probabilities from the downsampled model.
• p = the known fraction of positives in the actual population (0.002 in the example).
• q = the fraction of positives in your downsampled training set.
import numpy as np

def correct_probabilities(model_probs, p, q):
    # Convert the true and training-set class fractions to odds
    true_odds = p / (1 - p)
    sample_odds = q / (1 - q)
    ratio = true_odds / sample_odds

    # Rescale the predicted odds and convert back to a probability
    corrected = (model_probs * ratio) / (model_probs * ratio + (1 - model_probs))
    return corrected

# Example usage:
model_probs = np.array([0.01, 0.2, 0.9])  # Predicted probabilities from the downsampled model
p = 0.002   # True fraction of positives in the full population
q = 0.167   # Fraction of positives in the downsampled training set
corrected_probs = correct_probabilities(model_probs, p, q)
print("Corrected probabilities:", corrected_probs)
You would then use corrected_probs as the final probability estimates on your full population.
Why This Correction Is Needed
When you downsample the majority class (in this case, label=0), your training dataset no longer reflects the real-world prevalence of the positive class. The model thus learns from a distribution that has an artificially inflated fraction of positives. If you were to directly apply the model’s output to the real population, you would systematically over-predict positives. Correcting the probabilities with the formula above aligns the model’s predictions with the actual base rate in the population.
Handling Changing Distributions
A potential complication arises if the real-world prevalence p shifts over time. If p changes, you must re-estimate it and adjust your corrections accordingly. Failing to do so will lead to miscalibrations in your predicted probabilities.
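As a minimal sketch of that recalibration loop, assuming you periodically receive a labeled, non-downsampled batch from production and reusing the correct_probabilities function defined earlier (the batch below is simulated):

import numpy as np

# Stand-in for a recent batch of labeled, non-downsampled production data
recent_labels = np.random.binomial(1, 0.004, size=100_000)

p_new = recent_labels.mean()     # re-estimated real-world prevalence (here, about 0.4%)
q = 0.167                        # fraction of positives in the downsampled training set (unchanged)

new_scores = np.array([0.05, 0.4, 0.85])   # fresh raw outputs from the downsampled model
recalibrated = correct_probabilities(new_scores, p_new, q)
print("Recalibrated:", recalibrated)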
Choice of Downsampling vs. Other Methods
In practice, there are multiple approaches to deal with imbalance, including:
Downsampling.
Oversampling (e.g., SMOTE).
Class weighting in the loss function.
Focal loss adjustments.
Each method handles imbalance in a different way, but the main objective is the same: ensure that the learned model is not overly biased toward the majority class and that final probabilities are well-calibrated.
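For comparison with downsampling, here is a minimal sketch of the class-weighting option using scikit-learn; the synthetic dataset is purely illustrative. Note that weighted training can distort predicted probabilities in a similar way, so a calibration check on a real-distribution holdout is still worthwhile.

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic imbalanced data purely for illustration (~2% positives)
X, y = make_classification(n_samples=5000, weights=[0.98, 0.02], random_state=0)

# Class weighting keeps every negative example instead of discarding rows
clf = LogisticRegression(class_weight="balanced", max_iter=1000)
clf.fit(X, y)

probs = clf.predict_proba(X)[:, 1]   # still verify calibration against the true base rate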
Potential Follow-up Questions
Can the same approach be extended to multi-class classification?
When addressing more than two classes, you can still attempt to adjust predicted probabilities if class priors shift. However, the math is a bit more involved because each class has its own prior. In principle, you would want to adjust each predicted probability by the ratio of its true prior to its sampled prior, ensuring the final predictions sum to 1. The common practice is to apply a correction similar to the binary case but for each class independently in a one-vs-rest style.
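A minimal sketch of that renormalization approach, assuming you know each class’s true prior p_k and sampled prior q_k (all numbers below are illustrative):

import numpy as np

def correct_probabilities_multiclass(model_probs, true_priors, sampled_priors):
    # Rescale each class probability by p_k / q_k, then renormalize so rows sum to 1
    weights = np.asarray(true_priors) / np.asarray(sampled_priors)
    adjusted = model_probs * weights
    return adjusted / adjusted.sum(axis=1, keepdims=True)

# Illustrative numbers: three classes that were heavily rebalanced during training
model_probs = np.array([[0.2, 0.5, 0.3]])
true_priors = np.array([0.90, 0.08, 0.02])
sampled_priors = np.array([0.34, 0.33, 0.33])
print(correct_probabilities_multiclass(model_probs, true_priors, sampled_priors))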
What if I do not know the exact prevalence p in the real-world population?
If p is unknown or only loosely estimated, you can attempt to approximate it from historical data or from a holdout sample that has not been subjected to downsampling. Another possibility is to dynamically estimate p online if you have enough labeled data streaming from real-world usage. If your p is highly uncertain, the calibration step might be less accurate, and you might consider alternative methods such as well-calibrated ensemble approaches or active learning to refine your p estimates.
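As a small sketch, assuming you can set aside a labeled sample that was never downsampled, you can estimate p directly and attach a rough binomial confidence interval to see how uncertain the correction factor is (the data below are simulated placeholders):

import numpy as np

holdout_labels = np.random.binomial(1, 0.002, size=200_000)  # stand-in for real holdout labels

p_hat = holdout_labels.mean()
n = holdout_labels.size
stderr = np.sqrt(p_hat * (1 - p_hat) / n)
print(f"Estimated p = {p_hat:.5f} (95% CI roughly +/- {1.96 * stderr:.5f})")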
What performance metrics should I focus on for an imbalanced problem?
Common metrics include precision, recall, F1-score, or PR AUC (precision-recall area under the curve). Accuracy is usually not informative when the class imbalance is severe. It might be beneficial to track metrics such as recall (to minimize false negatives) or precision (to avoid flooding the system with false positives), depending on the application’s risk tolerance.
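A short sketch of computing these metrics with scikit-learn, assuming the labels and corrected scores come from a validation set that reflects the real class ratio (the arrays are placeholders):

import numpy as np
from sklearn.metrics import average_precision_score, precision_score, recall_score, f1_score

y_true = np.array([0, 0, 0, 0, 1, 0, 0, 1, 0, 0])                            # placeholder labels
scores = np.array([0.01, 0.02, 0.4, 0.05, 0.7, 0.1, 0.03, 0.6, 0.2, 0.01])   # corrected probabilities
preds = (scores >= 0.5).astype(int)

print("PR AUC:   ", average_precision_score(y_true, scores))
print("Precision:", precision_score(y_true, preds))
print("Recall:   ", recall_score(y_true, preds))
print("F1:       ", f1_score(y_true, preds))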
Why do some people prefer weighting over downsampling?
Class weighting integrates the prevalence corrections directly into the training objective by imposing a heavier penalty on errors for the minority class. This allows you to train on the full dataset rather than discarding any majority-class samples. It can be particularly advantageous when you have enough computational resources to handle the original data or if you risk losing important examples by downsampling. However, weighting can sometimes lead to overfitting on minority classes, particularly if your minority class data is noisy.
How to validate if the corrected probabilities are well-calibrated?
You can use calibration curves (reliability diagrams) or calibration metrics such as Brier scores or the Expected Calibration Error (ECE). After training your model and applying probability corrections, evaluate these metrics on a validation or holdout set that reflects the real distribution. If there is a mismatch, you may need to fine-tune the correction factors or your original sampling strategy.
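A minimal sketch of such a check, assuming scikit-learn and a holdout set drawn from the true distribution (the arrays below are simulated placeholders, so they will naturally look miscalibrated):

import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.metrics import brier_score_loss

y_true = np.random.binomial(1, 0.002, size=50_000)   # stand-in holdout labels
probs = np.random.beta(1, 400, size=50_000)          # stand-in corrected probabilities

print("Brier score:", brier_score_loss(y_true, probs))
frac_pos, mean_pred = calibration_curve(y_true, probs, n_bins=10, strategy="quantile")
for mp, fp in zip(mean_pred, frac_pos):
    print(f"mean predicted {mp:.4f} -> observed rate {fp:.4f}")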
These follow-up explanations and potential pitfalls demonstrate the nuanced understanding that FANG-level interviews often require.
Below are additional follow-up questions
Does adjusting probabilities with this formula guarantee perfect calibration in practice?
Calibration adjustment is an important step, but it does not guarantee that your model will be perfectly calibrated on all data. The correction formula hinges on a few assumptions:
• Correct Priors: The adjustment assumes you know the true prevalence p exactly. If your estimate of p is off, your probabilities will still be misaligned.
• Model Fit Quality: Even if you have the correct priors, the model itself might be underfitting or overfitting. If the logistic shape or the decision boundaries in the feature space do not reflect reality, then simply scaling odds will not fully fix miscalibration.
• Consistent Data Distribution: If the data distribution has shifted from training to deployment (e.g., new user demographics or changes in how features are measured), the probabilities may again become miscalibrated.
In practice, it is helpful to verify calibration on a holdout set. If you observe miscalibration, you might apply additional calibration methods such as Platt scaling, isotonic regression, or temperature scaling on top of the rebalancing.
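As one example, here is a sketch of isotonic recalibration fitted on a real-distribution holdout, assuming scikit-learn (the scores and labels below are placeholders):

import numpy as np
from sklearn.isotonic import IsotonicRegression

holdout_scores = np.array([0.001, 0.003, 0.01, 0.02, 0.08, 0.15, 0.3, 0.6])  # prior-corrected scores
holdout_labels = np.array([0, 0, 0, 0, 0, 1, 0, 1])

# Monotone mapping from score to calibrated probability, clipping scores outside the fitted range
iso = IsotonicRegression(out_of_bounds="clip")
iso.fit(holdout_scores, holdout_labels)

new_scores = np.array([0.005, 0.12, 0.5])
print(iso.predict(new_scores))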
How do you handle changing sampling fractions during model development?
Sometimes the fraction of the negative class retained (i.e., how much you downsample) may shift across different training sessions or experiments:
• Dynamic Downsampling: If you decide to vary your sampling rate for any reason (e.g., trying a smaller or larger fraction of negatives), you must keep track of the new sampling ratio and recalculate q for your corrected probability formula.
• Consistent Label Distribution Logging: Always log the distribution of positives vs. negatives in each training set (a minimal bookkeeping sketch follows this list). If you lose track of how many negatives you sampled, it becomes impossible to apply a principled correction.
• Retraining vs. Correction Only: If you drastically change the sampling fraction, it might be necessary to re-tune hyperparameters or even re-check your model architecture. The act of changing sampling fractions can affect how the model’s learned decision boundary evolves.
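A tiny sketch of that bookkeeping, assuming the training labels are available as a NumPy array; the metadata file name is arbitrary:

import json
import numpy as np

y_train = np.array([0] * 500 + [1] * 100)   # placeholder downsampled training labels

q = float(y_train.mean())                    # fraction of positives actually used in training
with open("training_metadata.json", "w") as f:
    json.dump({"positive_fraction_q": q, "n_train": int(y_train.size)}, f)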
Can rebalancing lead to overestimation of positive probabilities if the minority class is heterogeneous?
In highly imbalanced problems, the minority class itself can contain diverse subgroups. When you downsample the majority class, you risk:
• Clustering Effects: If certain sub-populations within the minority group dominate, your model may overweight those subgroups at the expense of others.
• Limited Representation: The minority class might include atypical outliers or mislabeled samples that become more influential in a smaller dataset. Overestimating the probability of certain outliers can inflate overall positive probability predictions.
One mitigation approach is to ensure that, within the minority class, you have enough data or you oversample carefully (e.g., SMOTE) to maintain internal diversity. Pair this with robust validation checks to confirm that your post-training probability predictions are consistent across subgroups.
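If you take the careful-oversampling route, a minimal sketch with the imbalanced-learn package (assuming it is installed; the synthetic data are illustrative) looks like this:

from collections import Counter
from imblearn.over_sampling import SMOTE
from sklearn.datasets import make_classification

# Synthetic imbalanced data purely for illustration
X, y = make_classification(n_samples=5000, weights=[0.98, 0.02], random_state=0)

# Generate synthetic minority samples until the classes are balanced
X_res, y_res = SMOTE(random_state=0).fit_resample(X, y)
print("Before:", Counter(y), "After:", Counter(y_res))

Keep in mind that oversampling changes the training-set prior just as downsampling does, so the same kind of probability correction applies afterward.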
What happens if I want to set a custom threshold for classification instead of 0.5?
When using the adjusted probabilities, you may still decide to pick a decision threshold that diverges from 0.5 for various operational reasons. For example, you might use a threshold that balances precision and recall in a manner appropriate for your application:
• Thresholding vs. Calibration: Adjusting probabilities to account for the real-world class prior is a calibration step. After calibration, you can still shift your threshold up or down to favor precision or recall (see the sketch after this list). The probabilities remain calibrated, but your decision boundary changes.
• Impact of Shifting Thresholds: A threshold different from 0.5 does not break the logic of calibration; it simply chooses how you trade off false positives and false negatives.
• Multiple Thresholds: In certain real-world scenarios, you might use different thresholds for different segments of your data. This can happen in settings with different costs or risk tolerances. Even in such cases, the core principle is to start with well-calibrated probabilities, then apply thresholding rules as business needs dictate.
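A sketch of choosing an operating threshold from the precision-recall trade-off on calibrated scores, assuming scikit-learn (the arrays are placeholders):

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 0, 1, 0, 0, 1, 0, 0, 0])
scores = np.array([0.01, 0.02, 0.10, 0.60, 0.05, 0.03, 0.40, 0.20, 0.01, 0.02])  # calibrated probabilities

precision, recall, thresholds = precision_recall_curve(y_true, scores)
f1 = 2 * precision * recall / np.clip(precision + recall, 1e-12, None)
best = np.argmax(f1[:-1])   # the last precision/recall point has no associated threshold
print("Chosen threshold:", thresholds[best], "precision:", precision[best], "recall:", recall[best])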
How might rebalancing interact with a deep learning architecture where we are not explicitly outputting probabilities?
While many neural networks do output a probability-like score (e.g., via a sigmoid for binary classification), in some architectures you might have raw logits or specialized loss functions:
• Raw Logits: If your network outputs logits, you usually apply a sigmoid or softmax to convert logits into probabilities. You can still apply the rebalance correction on the sigmoid-transformed output, or equivalently as a constant shift in log-odds space (see the sketch after this list).
• Weighted Loss: Another strategy in neural networks is to incorporate class weights or focal loss. If you do so, the final logits might inherently reflect a certain weighting of classes. You still need to calibrate if the class distribution in your training set differs from production.
• Post-Processing Pipeline: For complex pipelines where the neural network’s output is further transformed or integrated with other models, track how each stage might need rebalancing. If the final stage aggregates outputs from multiple sources, you might do a final recalibration step based on a combined holdout set that reflects the true distribution.
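As a concrete illustration of the logit route, the same prior correction can be applied as a constant shift in log-odds space before the sigmoid. This sketch uses plain NumPy and assumes logits holds the network’s raw outputs:

import numpy as np

def correct_logits(logits, p, q):
    # Shifting the logit by log(true odds) - log(training odds) is equivalent
    # to the probability-space correction used earlier in this article.
    shift = np.log(p / (1 - p)) - np.log(q / (1 - q))
    return 1.0 / (1.0 + np.exp(-(logits + shift)))

logits = np.array([-2.0, 0.0, 2.5])   # hypothetical raw network outputs
print(correct_logits(logits, p=0.002, q=0.167))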
What if the cost of misclassifications is not symmetric?
In many business or medical applications, the cost of a false positive (labeling a negative sample as positive) is drastically different from the cost of a false negative (labeling a positive sample as negative):
• Effect on Probabilities vs. Thresholding: Probability calibration by rebalancing helps produce more accurate probability estimates for each class, but it does not inherently account for different misclassification costs. You still need to select a decision threshold that minimizes your expected cost.
• Risk-Based Scoring: In high-stakes settings, you might compute an expected cost or expected utility based on the calibrated probability and each misclassification’s cost. With calibrated probabilities, the expected-cost-minimizing rule is to predict positive whenever the probability exceeds cost(false positive) / (cost(false positive) + cost(false negative)), which equalizes the expected cost of the two error types at the boundary (a short sketch follows this list).
• Rebalancing + Cost Sensitivity: You can combine rebalancing with cost-sensitive learning. For instance, in your training objective, weigh errors in the minority class more heavily if the cost of a false negative is particularly high. Even after that, you might still perform a final calibration step if the class distributions are heavily skewed in real-world data.
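A short sketch of that cost-based rule with illustrative (made-up) costs:

import numpy as np

cost_fp = 1.0      # cost of investigating a benign case (illustrative)
cost_fn = 500.0    # cost of missing a true positive (illustrative)

# With calibrated probabilities, flag anything above this threshold
threshold = cost_fp / (cost_fp + cost_fn)   # roughly 0.002 here
print("Cost-minimizing threshold:", threshold)

probs = np.array([0.0005, 0.003, 0.05])     # calibrated probabilities
print("Decisions:", probs >= threshold)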
How do data preprocessing steps (feature scaling, missing value imputation, etc.) interact with downsampling?
Data preprocessing typically remains unaffected by the ratio in which you sample from each class, but a few subtleties exist:
• Distribution of Feature Values: Downsampling might cause a shift in certain feature statistics if the majority class has a wide range of values. For instance, if you randomly pick only 1% of negatives, you might lose rare sub-populations or extreme values in the negative class.
• Missing Values: If your data cleaning or imputation strategy is learned from training data (e.g., computing the mean or median of a feature in the training set), that statistic might be biased if your sampling procedure drastically reduces the negative samples. You can mitigate this by computing any necessary statistics or transformations before you downsample (see the sketch after this list).
• Normalization or Standardization: If you are standardizing features, be consistent. Compute means and standard deviations from the entire data (or at least from a representative subset) before applying the sampling step. Otherwise, your feature distributions might not match the real-world scenario.
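A brief sketch of fitting preprocessing statistics before downsampling, assuming scikit-learn transformers; all data and variable names below are illustrative:

import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_full = rng.normal(size=(10_000, 5))          # stand-in for the full, non-downsampled feature matrix
y_full = rng.binomial(1, 0.002, size=10_000)   # stand-in labels at the true prevalence

# Fit imputation and scaling statistics on the FULL data ...
imputer = SimpleImputer(strategy="median").fit(X_full)
scaler = StandardScaler().fit(imputer.transform(X_full))

# ... then downsample negatives for training
neg_idx = np.where(y_full == 0)[0]
keep_neg = rng.choice(neg_idx, size=int(0.01 * neg_idx.size), replace=False)
train_idx = np.concatenate([np.where(y_full == 1)[0], keep_neg])

X_train = scaler.transform(imputer.transform(X_full[train_idx]))
y_train = y_full[train_idx]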
How can we confirm that our rebalanced model performs better than a model trained on the original distribution?
One might question whether rebalancing truly adds value or whether the model trained on the original distribution (even if it sees 99.8% negatives) could suffice:
• Validation Strategy: One approach is to keep a proper validation set that reflects the real distribution (i.e., do not downsample in the validation set). Compare the performance metrics (e.g., AUC, PR AUC, F1-score) of a model trained on rebalanced data (with probability correction) vs. a model trained on the full data but perhaps with class weighting.
• Metric Breakdown: Check not just global metrics but also calibration curves and confusion matrix metrics. A model might show a similar AUC but drastically different precision and recall distributions when you look at specific score thresholds.
• Resource Constraints: Training on the full dataset might be computationally expensive. If rebalancing allows you to train more quickly without sacrificing performance, that can justify its usage.
If the data distribution changes drastically in production, is downsampling still valid?
Production data can evolve. For instance, a fraud detection system might see a shift in fraud prevalence:
• Need for Monitoring: Continuously monitor the distribution of your incoming data. If your fraction of positives changes from 0.2% to 1%, the old calibration factor is no longer accurate.
• Adaptive Recalibration: You might implement a pipeline that re-estimates p on a rolling basis and applies a dynamic correction factor. This approach requires frequent retraining or re-calibration if the shift is significant.
• Hybrid Approaches: In some extreme cases, you can combine rebalancing with active learning. The model can selectively request more labeled data for ambiguous or minority-class samples, ensuring that your representation of the minority class remains up-to-date.
How might explainability or model interpretability be affected by rebalancing?
• Feature Importance Shifts: Decision trees or gradient boosted machines might shift which features they consider most discriminative after rebalancing, because features that help distinguish the minority class become more emphasized.
• Local Interpretations: Tools like LIME or SHAP rely on the local behavior of the model. When the model is trained on rebalanced data, the local neighborhoods for minority samples might differ from those in the original distribution.
• Communication to Stakeholders: If stakeholders expect “raw” probabilities, you’ll need to explain that the model’s raw outputs are adjusted due to intentional downsampling. Presenting both the raw probability and the corrected probability may help illustrate the effect of rebalancing on interpretability.