ML Interview Q Series: How do Classification And Regression Trees generate decision boundaries for classification, and how do they generate numeric predictions for regression?
📚 Browse the full ML Interview series here.
Comprehensive Explanation
The CART (Classification And Regression Trees) algorithm uses a binary tree structure to split data recursively and make final predictions either as class labels (in classification) or as numerical values (in regression). Although both classification and regression trees share the same high-level methodology, they differ in the splitting criteria and the way the leaf nodes produce outputs.
Building Classification Trees
When producing a classification tree, CART seeks to partition the feature space in such a way that each partition (leaf node) is as "pure" as possible in terms of class distribution. Impurity is measured with the Gini index, which is the standard choice in CART, or (less commonly) entropy.
Below is the core mathematical expression for the Gini index of a node containing samples from K classes:

Gini = 1 - \sum_{k=1}^{K} p_k^2

Where p_k represents the fraction of samples belonging to class k in the node. This means that if all samples in a node belong to exactly one class, the Gini index is 0, indicating maximum purity. As class proportions spread out more uniformly, the Gini index increases.
CART searches over possible split points across features to find the split that best reduces impurity from the parent node to its child nodes. It does this by considering every candidate feature and every possible threshold (for continuous features) or subset (for categorical features) that partitions the data into two child nodes. The chosen split is the one that maximizes the weighted decrease in impurity, often referred to as the Gini gain.
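To make this concrete, here is a minimal sketch (toy NumPy label arrays, not the library implementation) of the Gini index and the weighted impurity decrease that CART evaluates for a candidate split:

```python
import numpy as np

def gini(labels):
    """Gini impurity of a node: 1 - sum_k p_k^2."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def gini_decrease(parent, left, right):
    """Weighted decrease in Gini impurity for a candidate split."""
    n, n_l, n_r = len(parent), len(left), len(right)
    return gini(parent) - (n_l / n) * gini(left) - (n_r / n) * gini(right)

# Hypothetical node with two classes, split by some threshold
parent = np.array([0, 0, 0, 1, 1, 1, 1, 1])
left, right = parent[:3], parent[3:]       # a perfect split in this toy case
print(gini(parent))                        # ~0.469
print(gini_decrease(parent, left, right))  # ~0.469 (child impurity drops to 0)
```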
The process continues recursively, creating new branches and child nodes until a stopping criterion is met, such as:
The node is pure enough.
The node has too few samples.
The maximum tree depth is reached.
The improvement in impurity reduction is below a threshold.
The final leaf nodes of the classification tree represent the predicted class for samples that fall into those regions.
Building Regression Trees
When building a regression tree, the objective is to predict a continuous value. Instead of the Gini index or entropy, the CART algorithm typically uses mean squared error (or sometimes mean absolute error, depending on the implementation) to measure the homogeneity of the target variable in a node.
Below is the core mathematical expression commonly used for regression tree splitting, where the mean squared error (MSE) of a node is minimized:

MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

Here y_i is the actual target value for sample i, and \hat{y}_i is the predicted value (the mean of the target values in the node). N is the total number of samples in the node. The CART algorithm tries different split points to minimize the weighted MSE of the child nodes, i.e., to maximize the reduction in MSE between the parent and its children.
After the tree is fully grown, each leaf node holds the mean of the target values of the samples that fall into that leaf. When a new input is passed through the tree, the final prediction is the mean value of the corresponding leaf node.
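For intuition, the following sketch (toy data, NumPy only) shows how the weighted MSE reduction for one candidate threshold would be computed; note that the leaf prediction is simply the mean of the targets that reach it:

```python
import numpy as np

def node_mse(y):
    """MSE of a node when predicting the node mean."""
    return np.mean((y - y.mean()) ** 2) if len(y) else 0.0

def mse_reduction(x, y, threshold):
    """Reduction in weighted MSE if we split on x <= threshold."""
    left, right = y[x <= threshold], y[x > threshold]
    n, n_l, n_r = len(y), len(left), len(right)
    return node_mse(y) - (n_l / n) * node_mse(left) - (n_r / n) * node_mse(right)

x = np.array([1.0, 2.0, 3.0, 10.0, 11.0, 12.0])
y = np.array([1.1, 0.9, 1.0, 5.2, 4.8, 5.0])
print(mse_reduction(x, y, threshold=3.0))  # large reduction: the split separates the two regimes
```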
Algorithmic Overview
CART follows a greedy, top-down search:
Start with the entire dataset in the root node.
Evaluate all possible splits and select the best one based on either Gini impurity decrease (for classification) or MSE reduction (for regression).
Recursively repeat this process on the resulting child nodes.
Stop the process when a stopping criterion is triggered or if further splits no longer offer meaningful improvement.
The final tree can later be pruned (reduced in size) to avoid overfitting and generalize better.
Example in Python
Below is a minimal Python snippet using scikit-learn for building both a classification tree and a regression tree.
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris, load_diabetes  # load_boston has been removed from scikit-learn; load_diabetes is used instead
# Classification Tree Example
iris = load_iris()
X_clf_train, X_clf_test, y_clf_train, y_clf_test = train_test_split(
iris.data, iris.target, test_size=0.2, random_state=42
)
clf = DecisionTreeClassifier(criterion='gini', max_depth=3, random_state=42)
clf.fit(X_clf_train, y_clf_train)
print("Classification Tree Accuracy:", clf.score(X_clf_test, y_clf_test))
# Regression Tree Example
diabetes = load_diabetes()
X_reg_train, X_reg_test, y_reg_train, y_reg_test = train_test_split(
    diabetes.data, diabetes.target, test_size=0.2, random_state=42
)
reg = DecisionTreeRegressor(criterion='squared_error', max_depth=3, random_state=42)  # 'mse' was renamed to 'squared_error'
reg.fit(X_reg_train, y_reg_train)
print("Regression Tree R^2:", reg.score(X_reg_test, y_reg_test))
No matter the library you use, the main idea remains consistent: classification trees aim to reduce class impurity, whereas regression trees aim to minimize the variance of numeric values in the leaf nodes.
What if the data is very high-dimensional?
When the data has a large number of features, the splitting process becomes more computationally expensive. One must be mindful of:
Potential overfitting when the tree becomes overly complex.
The curse of dimensionality leading to many splits that might not generalize.
Effective feature selection or dimensionality reduction can help, as sketched below.
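For instance, a simple univariate feature-selection step can be placed in front of the tree; a minimal sketch with scikit-learn's SelectKBest (the choice of k=2 here is arbitrary):

```python
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Keep only the k most informative features (by ANOVA F-score) before growing the tree.
model = make_pipeline(SelectKBest(f_classif, k=2),
                      DecisionTreeClassifier(max_depth=3, random_state=42))
print(model.fit(X, y).score(X, y))
```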
How does pruning help with overfitting?
Pruning involves removing sub-trees that contribute more to overfitting than to improving predictive performance. It can be done by:
Pre-pruning: stopping early if the split's improvement is not sufficient.
Post-pruning: growing the full tree and then cutting back branches that do not improve validation set performance.
This helps avoid an overly deep tree that memorizes noise in the training data.
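In scikit-learn, post-pruning is exposed through cost-complexity pruning (the ccp_alpha parameter). A minimal sketch that selects the pruning strength on held-out data:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Compute candidate pruning strengths from the fully grown tree,
# then keep the alpha whose pruned tree scores best on the held-out split.
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
best = max(
    (DecisionTreeClassifier(ccp_alpha=a, random_state=42).fit(X_train, y_train)
     for a in path.ccp_alphas),
    key=lambda tree: tree.score(X_test, y_test),
)
print("Best pruned tree depth:", best.get_depth())
```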
Could we use different impurity measures?
Yes. For classification, besides the Gini index, one might use:
Entropy (from information theory)
Misclassification error
For regression, one might use:
Mean absolute error
A custom loss function that matches the problem’s needs
Each choice can slightly alter the shape of the learned tree, but the overall methodology remains the same.
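In recent scikit-learn versions, switching the measure is just a matter of the criterion argument; a brief sketch:

```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Entropy (information gain) instead of Gini for classification
clf_entropy = DecisionTreeClassifier(criterion='entropy', max_depth=3)

# Mean absolute error instead of squared error for regression
reg_mae = DecisionTreeRegressor(criterion='absolute_error', max_depth=3)
```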
Are there any limitations of CART?
CART can handle tabular datasets effectively, but it may exhibit:
Instability, where small changes in the data can lead to very different splits.
Overfitting when allowed to grow without constraints.
Inability to capture linear relationships as efficiently as linear models.
Bias toward features with many distinct values (especially with minimal regularization).
Ensembling methods like Random Forests or Gradient Boosted Trees often mitigate these limitations by combining multiple trees or sequentially refining them.
How do we handle missing values?
In many implementations, CART can handle missing data via:
Surrogate splits: using other features that correlate strongly with the primary split.
Skipping samples with missing values or imputing them.
The chosen strategy can impact performance, so it must be applied carefully, especially in real-world settings with significant amounts of missing data.
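scikit-learn's trees, for example, do not implement surrogate splits, so a common pattern there is to impute before fitting; a minimal sketch with a toy array containing NaNs:

```python
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.pipeline import make_pipeline
from sklearn.tree import DecisionTreeClassifier

# Toy data with missing entries; median imputation fills them before the tree sees them.
X = np.array([[1.0, 2.0], [np.nan, 3.0], [4.0, 5.0], [6.0, np.nan]])
y = np.array([0, 0, 1, 1])

model = make_pipeline(SimpleImputer(strategy='median'), DecisionTreeClassifier(max_depth=2))
model.fit(X, y)
print(model.predict([[5.0, np.nan]]))  # the pipeline imputes at prediction time too
```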
What if the target variable for regression is heavily skewed?
If the target has a heavy-tailed distribution, the mean might not be the best central measure. Common approaches include:
Transform the target variable (e.g., log transform).
Use a more robust loss function such as mean absolute error.
This can provide splits better suited to capturing the core structure of the data.
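One convenient way to apply a log transform without changing the prediction scale is scikit-learn's TransformedTargetRegressor; a sketch on hypothetical log-normal targets:

```python
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.tree import DecisionTreeRegressor

# Fit the tree on log1p(y); predictions are automatically mapped back with expm1.
reg = TransformedTargetRegressor(
    regressor=DecisionTreeRegressor(max_depth=4),
    func=np.log1p,
    inverse_func=np.expm1,
)

# Hypothetical right-skewed target
rng = np.random.default_rng(0)
X = rng.uniform(size=(200, 3))
y = np.exp(rng.normal(size=200))  # log-normal, heavy right tail
reg.fit(X, y)
print(reg.predict(X[:3]))
```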
How do we interpret decision trees in a business setting?
Decision trees are relatively easy to interpret:
Each internal node is a question (split) about a feature threshold.
Each path from root to leaf can be viewed as a set of rules for classification or prediction.
One can visualize the tree, enabling non-technical stakeholders to understand the logic.
That interpretability is one of the main advantages over more complex models like neural networks.
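scikit-learn's export_text makes this rule view explicit (plot_tree gives the graphical version); a short sketch on the iris data:

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
clf = DecisionTreeClassifier(max_depth=2, random_state=42).fit(iris.data, iris.target)

# Each printed line is a human-readable rule: feature threshold -> predicted class
print(export_text(clf, feature_names=list(iris.feature_names)))
```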
Could ensembles of trees be preferable?
Ensembles such as Random Forests or Gradient Boosted Trees often achieve higher accuracy or better generalization by:
Averaging multiple tree predictions (bagging in Random Forest).
Sequentially improving tree ensembles by focusing on residual errors (boosting).
They typically outperform a single tree but sacrifice some interpretability.
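A quick comparison sketch, treating the ensembles as drop-in replacements for a single tree in scikit-learn:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
for model in (DecisionTreeClassifier(max_depth=3),
              RandomForestClassifier(n_estimators=200, random_state=42),
              GradientBoostingClassifier(random_state=42)):
    scores = cross_val_score(model, X, y, cv=5)
    print(type(model).__name__, round(scores.mean(), 3))
```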
How can hyperparameters affect performance?
Hyperparameters such as max_depth, min_samples_split, min_samples_leaf, and criterion (gini, entropy, squared_error, etc.) can significantly influence:
Complexity and depth of the tree.
How well the model generalizes.
The resulting interpretability.
Careful tuning of these hyperparameters is crucial and is usually done via cross-validation.
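In practice this is typically a cross-validated grid (or randomized) search; a minimal sketch:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
param_grid = {
    "max_depth": [2, 3, 5, None],
    "min_samples_leaf": [1, 5, 10],
    "criterion": ["gini", "entropy"],
}
search = GridSearchCV(DecisionTreeClassifier(random_state=42), param_grid, cv=5)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```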
How does CART differ from ID3, C4.5, or C5.0?
ID3, C4.5, and C5.0 are decision tree methods that split on information gain (ID3) or gain ratio (C4.5/C5.0), both based on entropy, and they allow multiway splits on categorical features. CART uses strictly binary splits with the Gini index for classification and, unlike the others, also handles regression natively. C4.5 and its successors handle continuous features, missing values, and pruning in their own specialized ways, but the core idea of greedy recursive partitioning is the same across all of them.
What is the final takeaway about CART?
CART constructs binary decision trees that are conceptually straightforward to interpret and implement. For classification, it reduces class impurity, while for regression, it reduces variance or error in numeric outputs. Proper pruning, hyperparameter tuning, and careful handling of data are vital for achieving optimal performance. The algorithm remains a cornerstone for many ensemble methods, making it essential for a robust understanding of modern machine learning approaches.