Machine Learning with scikit-learn
Build, tune, and evaluate models using scikit-learn pipelines with reproducible ML workflows.
Decision Trees and Forests
Decision Trees & Forests (scikit-learn): Intuition, Code, and When to Use Them
"Imagine playing 20 Questions with a robot — except each question is learned from data."
You're coming in hot from linear/logistic regression and regression metrics, and you already have the statistical intuition from earlier modules. Good. Trees and forests are the next logical step: flexible, non-linear models that answer "which question next?" at every split and, when assembled into forests, form a crowd that stabilizes wild decisions.
What these models are and why they matter
- Decision Trees: A tree is a flowchart-like model that splits the feature space into regions using simple decision rules (e.g., age > 30?). Each internal node asks a question, each branch is an answer, each leaf is a prediction. Works for classification and regression.
- Random Forests: An ensemble of many decision trees grown on bootstrapped samples and random feature subsets. Think: a jury of slightly biased experts whose majority vote is far less noisy than any single expert.
Why use them after regression? Because trees capture nonlinearities and interactions automatically — you don't have to hand-engineer polynomial terms or interaction features. And because you care about uncertainty, remember: forests reduce variance (they're less likely to overfit) but need calibration for probabilities.
Short analogy: 20 Questions & a Jury
- A single decision tree = one friend playing 20 Questions who gets dramatic and overfits to your last game.
- A random forest = 100 friends each playing different versions of the game, then voting — the crowd corrects individual weirdness.
This ties to the bias-variance tradeoff you saw earlier: deep trees = low bias, high variance; forests = still low bias, much lower variance.
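You can see this variance reduction directly. A minimal sketch, using a synthetic dataset from `make_classification` (the exact scores depend on the random seed and data):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier

# Synthetic binary classification problem (illustrative data, not a real dataset)
X, y = make_classification(n_samples=1000, n_features=20, n_informative=5,
                           random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# One fully grown tree: low bias, high variance
tree = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# A forest of 100 such trees: similar bias, much lower variance
forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

print('tree   train/test:', tree.score(X_train, y_train), tree.score(X_test, y_test))
print('forest train/test:', forest.score(X_train, y_train), forest.score(X_test, y_test))
```

Typically the single tree scores near 1.0 on the training set but noticeably lower on the test set, while the forest closes much of that gap.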
How trees split: impurity & information
Micro explanation: The split goal
At each node the algorithm picks a feature and threshold to best reduce impurity (classification) or reduce variance (regression).
- For classification: Gini impurity or entropy (information gain).
- For regression: reduction in mean squared error (MSE) or variance.
This is a direct application of your stats intuition: we're choosing splits that give the biggest drop in uncertainty.
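To make the split criterion concrete, here is a minimal sketch that computes Gini impurity and the weighted impurity drop for a candidate split (the function names are our own, not scikit-learn's):

```python
import numpy as np

def gini(labels):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def impurity_drop(labels, left_mask):
    """Decrease in Gini impurity when splitting labels by left_mask."""
    left, right = labels[left_mask], labels[~left_mask]
    n = len(labels)
    weighted = (len(left) / n) * gini(left) + (len(right) / n) * gini(right)
    return gini(labels) - weighted

y = np.array([0, 0, 0, 1, 1, 1])
# A perfect split sends all 0s left and all 1s right:
perfect = np.array([True, True, True, False, False, False])
print(impurity_drop(y, perfect))  # 0.5: parent Gini is 0.5, both children are pure
```

At each node, the tree-growing algorithm effectively evaluates this drop for many candidate feature/threshold pairs and keeps the best one.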
scikit-learn: quick recipe (classification and regression)
```python
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor, plot_tree
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
import matplotlib.pyplot as plt

# Assumes X_train, y_train, X_test, feature_names, class_names are already defined.

# Classification
clf = DecisionTreeClassifier(max_depth=5, random_state=42)
clf.fit(X_train, y_train)
probs = clf.predict_proba(X_test)  # class probabilities
preds = clf.predict(X_test)

# Random forest
rf = RandomForestClassifier(n_estimators=100, max_depth=8, random_state=42, oob_score=True)
rf.fit(X_train, y_train)
print('OOB score:', rf.oob_score_)

# Regression
reg = DecisionTreeRegressor(min_samples_leaf=5, random_state=42)
reg.fit(X_train, y_train)

# Visualize the top of the classification tree
plt.figure(figsize=(12, 8))
plot_tree(clf, feature_names=feature_names, class_names=class_names, filled=True, max_depth=3)
plt.show()
```
Notes:
- Use `max_depth`, `min_samples_split`, `min_samples_leaf`, or `ccp_alpha` (cost-complexity pruning) to control overfitting.
- `RandomForestClassifier(..., oob_score=True)` gives an out-of-bag estimate, a handy cross-validation-like score.
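Cost-complexity pruning can be explored with `cost_complexity_pruning_path`, which returns the sequence of effective alphas at which subtrees get pruned away. A sketch on the Iris data:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Compute the effective alphas along the pruning path
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)

# Refit one tree per alpha; larger alpha -> smaller, simpler tree
trees = [DecisionTreeClassifier(ccp_alpha=a, random_state=42).fit(X_train, y_train)
         for a in path.ccp_alphas]
leaf_counts = [t.get_n_leaves() for t in trees]
print(leaf_counts)  # leaf counts shrink as alpha grows
```

In practice you would pick the alpha whose tree scores best on a validation set or under cross-validation.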
Interpreting trees: the sweet spot of interpretability
- Single trees are highly interpretable: you can trace a path and see decision rules.
- Forests are less transparent, but give feature importances (mean decrease impurity) and you can use permutation importance for reliable ranking.
- For calibrated probabilities: forests output class frequencies (`predict_proba`), but these can be poorly calibrated; use `CalibratedClassifierCV` when you need trustworthy probabilities.
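Permutation importance lives in `sklearn.inspection`: it shuffles each feature on held-out data and measures the resulting score drop, which avoids the biases of impurity-based importances. A minimal sketch on Iris:

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(data.data, data.target,
                                                    random_state=42)

rf = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

# Shuffle each feature on the held-out set and measure the score drop
result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=42)
for name, mean, std in zip(data.feature_names,
                           result.importances_mean, result.importances_std):
    print(f'{name}: {mean:.3f} +/- {std:.3f}')
```

Features whose shuffling barely moves the score are ones the model does not actually rely on.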
Practical tips & caveats (stats-savvy)
- Don't rely on a deep tree without validation. Deep trees memorize — your regression metrics (MSE, R2) on train/test will reveal this. Use cross-validation.
- Ensembles reduce variance, not bias. If your model has systematic bias (bad features, wrong problem), forests won't fix it.
- Feature scaling not required. Trees split on thresholds, so they don't need standardization like linear models.
- Categorical features: scikit-learn historically requires one-hot encoding; newer versions handle categoricals natively in `HistGradientBoostingClassifier`/`Regressor` (via `categorical_features`). Be aware: high-cardinality categorical variables can bias impurity-based importances.
- Calibration & uncertainty: decision trees give class probability estimates from leaf counts — variance and class imbalance can harm calibration. Link back to your stats and probability lessons: treat these probabilities as estimates with sampling variability.
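To check whether calibration is actually a problem, you can compare Brier scores for a raw forest and a calibrated wrapper. A sketch on synthetic data (whether calibration helps depends on the dataset and split):

```python
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import brier_score_loss
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# Raw forest probabilities are leaf-frequency averages
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

# Wrap a fresh forest with isotonic calibration (fit via internal CV folds)
cal = CalibratedClassifierCV(RandomForestClassifier(n_estimators=100, random_state=0),
                             method='isotonic', cv=3).fit(X_train, y_train)

raw_brier = brier_score_loss(y_test, rf.predict_proba(X_test)[:, 1])
cal_brier = brier_score_loss(y_test, cal.predict_proba(X_test)[:, 1])
print('raw Brier:       ', raw_brier)
print('calibrated Brier:', cal_brier)
```

Lower Brier score means probabilities that are closer, on average, to the observed outcomes; `sklearn.calibration.calibration_curve` gives a visual version of the same check.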
When to pick tree models vs linear models
Choose trees/forests when:
- Relationships are nonlinear or contain complex interactions.
- Interpretability at rule-level matters (single shallow tree).
- You need a strong baseline that's robust without heavy feature engineering.
Prefer linear/logistic regression when:
- The relationship is roughly linear, you want coefficients for inference, or you need highly interpretable parametric effects.
- You care about uncertainty quantification tied to statistical inference (p-values, confidence intervals). Trees are less naturally suited for classical inferential statistics.
Example workflow (from data -> evaluation)
- Split data, keep test set separate.
- Train a shallow DecisionTree to inspect rules and features.
- Train a RandomForest for stable prediction; tune `n_estimators`, `max_depth`, `min_samples_leaf` via CV.
- Evaluate with appropriate metrics (accuracy, precision/recall, ROC-AUC for classification; MSE, R² for regression). You already practiced these in Regression Metrics.
- Check calibration and feature importance. If interpretability is needed, consider SHAP values or a single surrogate tree.
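The tuning step of this workflow can be sketched with `GridSearchCV` (the grid values here are illustrative starting points, not recommendations):

```python
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# Small illustrative grid over the knobs that matter most
param_grid = {
    'n_estimators': [100, 200],
    'max_depth': [None, 4, 8],
    'min_samples_leaf': [1, 5],
}
search = GridSearchCV(RandomForestClassifier(random_state=42),
                      param_grid, cv=5, scoring='accuracy')
search.fit(X_train, y_train)

print('best params:  ', search.best_params_)
print('test accuracy:', search.best_estimator_.score(X_test, y_test))
```

Note that the held-out test set is touched only once, at the very end; all tuning decisions come from the cross-validation folds inside `GridSearchCV`.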
Quick checklist: hyperparameters that actually matter
- `max_depth`: prevents runaway growth. Great first knob.
- `min_samples_leaf` / `min_samples_split`: smooth predictions, reduce overfitting.
- `n_estimators` (forest): more trees -> lower variance (diminishing returns).
- `max_features` (forest): controls tree diversity; a common default is sqrt(n_features) for classification.
- `bootstrap`: if False, you're building a forest without resampling (less variance reduction).
- `ccp_alpha`: cost-complexity pruning parameter to simplify trees.
Key takeaways
- Decision trees are intuitive, handle nonlinearities and interactions, and are easy to visualize.
- Random forests combine many trees to reduce variance and improve robustness.
- Use your regression metrics and CV discipline from earlier: monitor overfitting, compare on the test set, and validate probability estimates if you care about uncertainty.
Final memorable insight: a single tree is like a confident person who's often wrong; a forest is a committee that's much harder to fool into being spectacularly wrong.
Natural next steps: build a ready-to-run notebook example (Iris or Titanic), tune a forest with GridSearchCV, or explore permutation importance and SHAP explanations.