ML Interview Q Series: How does the concept of a Random Forest build upon the foundations of Decision Trees?
Comprehensive Explanation
A Random Forest is fundamentally an ensemble method that leverages multiple Decision Trees to produce more robust and accurate predictions. Decision Trees themselves are single-model structures used for classification or regression, but they often suffer from high variance (i.e., they are prone to overfitting). By aggregating the outputs of many Decision Trees that are trained in parallel on slightly different subsets of data and features, Random Forests achieve better generalization and lower variance.
Decision Trees
A Decision Tree is a model that splits the input feature space into regions, where each split is chosen based on a criterion (for classification, criteria like Gini impurity or entropy are commonly used; for regression, criteria like mean squared error are used). The Gini impurity of a node for classification is

\text{Gini}(D) = 1 - \sum_{i=1}^{C} p_i^2

Here, D is the dataset at a particular node, C is the number of classes, and p_i is the proportion of samples belonging to the i-th class within that node. The lower the impurity, the more homogeneous the node.
Within a Decision Tree, we repeatedly partition the data based on the feature or threshold that yields the greatest reduction in impurity. While this process can fit data extremely well, a single tree can easily overfit, particularly if it grows without constraints.
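To make the splitting criterion concrete, here is a minimal sketch in plain NumPy (the helper names gini_impurity and split_impurity are illustrative, not from any library) that computes the Gini impurity of a node and the weighted impurity of a candidate split:

import numpy as np

def gini_impurity(labels):
    # Gini(D) = 1 - sum_i p_i^2 over the class proportions p_i at this node
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def split_impurity(left_labels, right_labels):
    # Weighted average impurity of the two child nodes produced by a split
    n_left, n_right = len(left_labels), len(right_labels)
    n = n_left + n_right
    return (n_left / n) * gini_impurity(left_labels) + (n_right / n) * gini_impurity(right_labels)

print(gini_impurity(np.array([0, 0, 0, 0])))               # 0.0  (pure node)
print(gini_impurity(np.array([0, 0, 1, 1])))               # 0.5  (maximally mixed binary node)
print(split_impurity(np.array([0, 0]), np.array([1, 1])))  # 0.0  (a perfect split)

The tree-growing procedure evaluates many candidate splits and keeps the one with the lowest weighted child impurity, i.e., the largest impurity reduction.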
Random Forest Formation
In a Random Forest, we create multiple Decision Trees, each trained on different subsets (sampled with replacement, also known as bagging) of the training data. Additionally, when building each tree, only a random subset of features is considered for each split. This randomness ensures the trees are de-correlated and do not all focus on the same dominant features. The final prediction is typically the majority vote (in classification) or the average (in regression) of the individual trees.
Below is a generic representation of how the ensemble prediction in a Random Forest (with M trees) is computed for a regression scenario:

\hat{y}(x) = \frac{1}{M} \sum_{i=1}^{M} T_i(x)

Where x is the input feature vector, T_i(x) is the prediction of the i-th Decision Tree, and M is the total number of trees in the Random Forest.
Because each tree is trained on a slightly different subset of data, and each node split in every tree considers a random subset of features, the Random Forest introduces diversity among the trees. This diversity makes the ensemble robust to noise and helps reduce overfitting compared to using a single Decision Tree.
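To make the bagging-plus-voting mechanics concrete, the following is a rough sketch that hand-rolls a small forest out of scikit-learn Decision Trees; the dataset choice, the number of trees M, and the variable names are purely illustrative:

import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
rng = np.random.RandomState(0)
M = 25  # number of trees in this hand-rolled ensemble

trees = []
for _ in range(M):
    # Bootstrap sample: draw len(X) row indices with replacement
    idx = rng.randint(0, len(X), size=len(X))
    # max_features='sqrt' makes each split consider only a random subset of features
    tree = DecisionTreeClassifier(max_features='sqrt', random_state=rng)
    tree.fit(X[idx], y[idx])
    trees.append(tree)

# Aggregate by majority vote (for regression, average the per-tree predictions instead)
all_preds = np.stack([t.predict(X) for t in trees]).astype(int)  # shape (M, n_samples)
majority = np.apply_along_axis(lambda votes: np.bincount(votes).argmax(), 0, all_preds)
print("Ensemble training accuracy:", (majority == y).mean())

In practice one would simply use RandomForestClassifier or RandomForestRegressor, which implement this bootstrap-plus-feature-subsampling recipe (along with conveniences such as out-of-bag scoring) far more efficiently.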
Practical Insights and Implementation Details
One major advantage is that Random Forests require relatively less tuning compared to many other advanced models. However, there are several parameters that can significantly affect performance:
• Number of trees: Increasing this usually improves performance up to a point, but also increases training time.
• Maximum depth of each tree: Deeper trees can capture more complex patterns but may risk overfitting if other forms of regularization (like restricting minimum samples per leaf) are not used.
• Number of features randomly chosen at each split: This controls how correlated the trees are. A lower value increases diversity among trees but may reduce the strength of individual trees.
• Bootstrapping size: Typically, each tree gets as many samples (with replacement) as the size of the original dataset, but variations are possible.
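As a rough guide to how these knobs surface in scikit-learn, here is an illustrative (not prescriptive) instantiation; the specific values are placeholders:

from sklearn.ensemble import RandomForestClassifier

rf = RandomForestClassifier(
    n_estimators=300,       # number of trees
    max_depth=None,         # maximum depth of each tree (None = grow until leaves are pure)
    min_samples_leaf=1,     # regularization on tree growth (larger values give shallower effective trees)
    max_features='sqrt',    # number of features randomly considered at each split
    bootstrap=True,         # draw a bootstrap sample (with replacement) for each tree
    max_samples=None,       # bootstrap size (None = same size as the training set)
    random_state=42,
)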
Random Forests also allow measuring feature importance by tracking how much each feature contributes to reducing node impurity across all the trees. This can be valuable for interpretability and feature selection.
Example in Python
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
# Load sample data
X, y = load_iris(return_X_y=True)
# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# Create the Random Forest model
rf_model = RandomForestClassifier(n_estimators=100, max_depth=None, max_features='sqrt', random_state=42)
# Train the model
rf_model.fit(X_train, y_train)
# Predict on test
y_pred = rf_model.predict(X_test)
# Evaluate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("Test Accuracy:", accuracy)
In this example, we create a RandomForestClassifier with 100 trees, let each tree grow to its full depth, and consider the square root of the total number of features at each split. This is a common setting for classification tasks.
Possible Follow-Up Questions
How does bagging in Random Forests help reduce overfitting?
Bagging (bootstrap aggregating) helps reduce overfitting by training each tree on a slightly different subset of data sampled with replacement. This naturally injects variability in the training sets, so each tree learns different patterns. When their predictions are aggregated, random fluctuations or overfitting tendencies of individual trees are smoothed out, yielding lower variance overall. The effect is further enhanced by randomly selecting a subset of features at each split.
Why do we randomly choose a subset of features for each split in a Random Forest?
If every tree considered all features for splitting, a few strongly predictive features might dominate the splits across all trees, causing them to become more correlated. This defeats the purpose of using an ensemble, as correlated models tend to make similar mistakes. By restricting each split to a random subset of features, we encourage the trees to learn diverse patterns, which increases the overall ensemble’s robustness.
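One way to see this decorrelation effect empirically is to compare how often individual trees agree with one another when every feature is available at every split versus when only a random subset is. The sketch below is illustrative: the synthetic dataset, the agreement metric, and the helper name mean_pairwise_agreement are assumptions, and the exact numbers will vary:

import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Noisy synthetic data with only a few informative features (illustrative setup)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=3,
                           flip_y=0.1, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

def mean_pairwise_agreement(max_features):
    rf = RandomForestClassifier(n_estimators=50, max_features=max_features,
                                random_state=0).fit(X_train, y_train)
    preds = np.stack([t.predict(X_test) for t in rf.estimators_])  # (n_trees, n_test)
    pairs = [(preds[i] == preds[j]).mean()
             for i in range(len(preds)) for j in range(i + 1, len(preds))]
    return np.mean(pairs)

# Trees that may use every feature at each split typically agree (are correlated) more often
print("agreement, all features :", mean_pairwise_agreement(1.0))
print("agreement, sqrt features:", mean_pairwise_agreement('sqrt'))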
How do Random Forests provide an estimate of feature importance?
During training, each node split reduces the impurity (e.g., Gini or entropy for classification) or variance (for regression). We can sum up all the impurity reductions attributed to a feature across all trees and normalize this sum to get a measure of how important each feature has been in partitioning the data. This measure can be averaged and scaled to compare features directly. Another approach is permutation-based importance, which involves randomly permuting the values of a feature and assessing how much the model’s accuracy degrades.
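Both flavours of importance are available in scikit-learn. A brief sketch, assuming the rf_model, X_test, and y_test objects from the example above are still in scope:

from sklearn.inspection import permutation_importance

# Impurity-based importances: normalized total impurity reduction attributed to each feature
print("Impurity-based importances:", rf_model.feature_importances_)

# Permutation importances: drop in score when a feature's values are randomly shuffled
perm = permutation_importance(rf_model, X_test, y_test, n_repeats=10, random_state=42)
print("Permutation importances:", perm.importances_mean)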
What is out-of-bag (OOB) error and how is it used?
In bagging, each tree is trained on a bootstrap sample, typically the same size as the original dataset but drawn with replacement. The probability that a given instance is never selected in n draws with replacement is (1 - 1/n)^n, which approaches e^{-1} ≈ 0.37 for large n. So, on average, each bootstrap sample contains approximately 63% of the unique instances in the original data, leaving around 37% of instances (the “out-of-bag” samples) unseen by that tree. Each tree can be evaluated on its OOB samples, providing an approximately unbiased estimate of its performance without needing a separate validation set. By aggregating the performance across all OOB samples and trees, we get the overall OOB error estimate for the Random Forest.
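In scikit-learn the OOB estimate can be requested directly when bootstrapping is enabled. A minimal sketch, assuming the X_train and y_train split from the earlier example (the variable name rf_oob is illustrative):

from sklearn.ensemble import RandomForestClassifier

rf_oob = RandomForestClassifier(n_estimators=200, oob_score=True, bootstrap=True, random_state=42)
rf_oob.fit(X_train, y_train)
# Accuracy estimated only from each tree's predictions on its own out-of-bag samples
print("OOB score:", rf_oob.oob_score_)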
Can Random Forests handle high-dimensional data?
Random Forests often perform well in high-dimensional scenarios because the random feature selection at each split mitigates the risk of focusing on a small set of dominant features. However, for extremely high-dimensional problems with limited data, performance can still degrade if there are many noisy features. Techniques like feature selection or dimensionality reduction can still be beneficial.
Does increasing the number of trees in a Random Forest always improve performance?
Adding more trees generally reduces the variance in predictions, up to a certain point. However, after a large enough ensemble has been constructed, the performance gains from adding more trees tend to plateau. Furthermore, computational cost grows with more trees, so in practice, there’s a trade-off between performance, training time, and inference time.
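One way to observe the plateau is to train forests of increasing size and track test accuracy. A small sketch, assuming the train/test split from the earlier example (the tree counts chosen here are arbitrary):

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

for n in [1, 5, 25, 100, 400]:
    rf_n = RandomForestClassifier(n_estimators=n, random_state=42).fit(X_train, y_train)
    acc = accuracy_score(y_test, rf_n.predict(X_test))
    print(f"{n:>4} trees -> test accuracy {acc:.3f}")

Beyond some ensemble size the accuracy curve typically flattens, while training and inference cost keep growing linearly in the number of trees.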
In what scenarios might a single Decision Tree be preferable over a Random Forest?
A single Decision Tree is much faster to train and is easier to interpret. In scenarios where model interpretability in a strictly hierarchical splitting sense is paramount and the dataset is relatively small or simple, a single tree might suffice. However, if performance and generalization are priorities, a Random Forest usually outperforms a single tree.