---
title: "Trees"
author: "Michael Mayer"
date: "`r Sys.Date()`"
output:
  html_document:
    toc: yes
    toc_float: yes
    number_sections: yes
    df_print: paged
    theme: paper
    code_folding: show
    math_method: katex
editor_options:
  chunk_output_type: console
knit: (function(input, ...) {rmarkdown::render(input, output_dir = "../docs")})
---
```{r setup, include=FALSE}
knitr::opts_chunk$set(
  echo = TRUE,
  warning = FALSE,
  message = FALSE
)
```
# Introduction
A decision tree is a simple, easy-to-interpret modeling technique for both regression and classification problems. Compared to other methods, decision trees usually do not perform very well. Their relevance lies in the fact that they are the building blocks of two of the most successful ML algorithms: random forests and gradient boosted trees. In this chapter, we will introduce these tree-based methods.
# Decision Trees
## How they work
On our journey to estimate the model $f$ by $\hat f$, we have considered mainly linear functions $f$ so far. We now move to a different function class: decision trees. They were introduced in 1984 by Leo Breiman, Jerome Friedman and others [1] and are sometimes called "Classification and Regression Trees" (CART).
(Binary) decision trees are calculated recursively by partitioning the data into two pieces. Partitions are chosen to optimize the given average loss by asking the best "yes/no" question about the covariates, e.g., "is carat < 1?" or "is color better than F?".
For regression problems, the most frequently used loss function is the squared error.
For classification, it is "information" (= cross-entropy = log loss = half the unit logistic deviance) or the very similar Gini impurity.
Predictions are calculated by sending an observation through the tree, starting with the question at the "trunk" and ending in a "leaf". The prediction is the value associated with the leaf. In regression settings, the leaf value typically equals the average response of all observations in the leaf. In classification settings, it may be the most frequent class in the leaf or all class probabilities.
The concept of a decision tree is best understood with an example.
## Example: decision tree
We will use the `dataCar` data set to predict the claim probability with a decision tree. As features, we use `veh_value`, `veh_body`, `veh_age`, `gender`, `area` and `agecat`.
```{r}
library(rpart)
library(rpart.plot)
library(insuranceData)

data(dataCar)

fit <- rpart(
  clm ~ veh_value + veh_body + veh_age + gender + area + agecat,
  data = dataCar,
  method = "class",
  parms = list(split = "information"),
  xval = 0,
  cp = -1,
  maxdepth = 3
)

prp(fit, type = 2, extra = 7, shadow.col = "gray",
    faclen = 0, box.palette = "auto", branch.type = 4,
    varlen = 0, cex = 0.9, digits = 3, split.cex = 0.8)

dataCar[1, c("agecat", "veh_value", "veh_body")]
predict(fit, dataCar[1, ])
```
**Comments**
- The first observation belongs to a person in age category 2 and has a $10'600 hatchback: the first question sends us to the right, the second to the left and the third to the right. This gives us a claim probability of 6.7%.
- How was, e.g., the first question (`agecat >= 5`) chosen? The algorithm scans all covariates and all possible split positions and picks the split with the best average loss improvement. In this case, splitting on the covariate `agecat` at the value 5 reduced the average loss the most (a minimal sketch of such a split search follows right after these comments).
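To make this concrete, here is a minimal sketch (not `rpart`'s actual implementation) of an exhaustive split search for a single numeric covariate under squared error loss. The final call on the `diamonds` data is purely illustrative.
```{r}
# Brute-force split search: try every split point "x <= s" and keep the one
# with the smallest total squared error of the two resulting groups
best_split <- function(x, y) {
  candidates <- sort(unique(x))
  candidates <- head(candidates, -1)  # both sides of "x <= s" must be non-empty
  sse <- function(z) sum((z - mean(z))^2)
  loss <- vapply(candidates, function(s) sse(y[x <= s]) + sse(y[x > s]), numeric(1))
  candidates[which.min(loss)]
}

# Illustration: best split point of carat when predicting diamond prices
library(ggplot2)  # provides the diamonds data
best_split(diamonds$carat, diamonds$price)
```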
**Properties of decision trees**
- **Outliers:** In contrast to linear models, outliers in covariates are not an issue because the algorithm only takes into account the sort order of the values. Similarly, taking logarithms of covariates has no effect. Neither statement holds for the response variable.
- **Missing values:** Some implementations can deal with missing values in the input. Alternatively, missing values are often replaced by a typical value or a value smaller/larger than the smallest non-missing value (such as -1 for a positive variable).
- **Categorical covariates:** Unordered categorical covariates are tricky to split because with $\ell$ levels, there are theoretically $2^{\ell - 1} - 1$ possible partitions into two groups. Try to lump small categories together or consider representing the levels by ordered categories (even if it does not make too much sense). One-hot-encoding is an option as well (see the short sketch below). Some algorithms offer ways to internally deal with unordered categoricals.
- **Greedy:** Partitions are made in a greedy way to optimize the objective in *one step*. Looking ahead more than one step would lead to better models but this is computationally too demanding in practice.
- **Interactions:** Thanks to their flexible structure, decision trees can automatically capture interaction effects of any order (and other non-linear effects), at least if the data set is large and the tree is deep enough. This is a big advantage of tree-based methods over linear models, where such elements have to be specified carefully and manually. In the next chapter, we will meet another model class with this advantage: neural nets.
- **Extrapolation:** By construction, a decision tree cannot extrapolate. So even a whopping 10 carat diamond cannot get a higher price prediction than the most expensive diamond in the training data.
These properties typically translate 1:1 to combinations of trees like random forests or boosted trees.
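To illustrate the one-hot-encoding option mentioned in the bullet on categorical covariates, here is a minimal sketch using base R's `model.matrix()` on the `dataCar` data loaded above (purely illustrative; `rpart` does not need it):
```{r}
# One-hot encode the unordered factor veh_body: each level becomes its own
# 0/1 column that a tree could split on individually
dummies <- model.matrix(~ veh_body - 1, data = dataCar)
head(dummies[, 1:4])
```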
# Random Forests
## How they work
In 2001, Leo Breiman introduced a very powerful tree-based algorithm called *random forest*, see [2]. A random forest consists of many decision trees. To ensure that the trees differ, two sources of randomness are injected:
1. Each tree is calculated on a bootstrap sample of the data, i.e., on $n$ observations selected with replacement from the original $n$ rows. This technique is called "bagging", from "**b**ootstrap **agg**regat**ing**".
2. Each split scans only a random selection "mtry" of the $m$ covariates to find the best split, usually about $\sqrt{m}$ or $m/3$. "mtry" is the main tuning parameter of a random forest.
Predictions are found by pooling the predictions of all trees, e.g., by averaging or majority voting.
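The following toy snippet illustrates the two sources of randomness (a sketch, not ranger's internals):
```{r}
# 1) Bagging: n rows drawn with replacement -- only about 63% of the original
#    rows end up in a given tree, the rest are "out-of-bag" (see below)
n <- 10000
boot <- sample(n, replace = TRUE)
mean(seq_len(n) %in% boot)

# 2) mtry: each split considers only a random subset of the m covariates
m <- 9
mtry <- floor(sqrt(m))
sample(m, mtry)
```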
**Comments about random forests**
- **Number of trees:** Usually, 100-1000 trees are grown. The more, the better. More trees also mean longer training time and larger models.
- **Diversification:** Single trees in a random forest are usually very deep. They overfit on the training data. It is the diversity across trees that produces a good and stable model, just as with a well-diversified stock portfolio.
- **Never trust performance on the training set for random forests.**
- **OOB validation**: In each tree, about 1/3 of all observations are not in the bootstrap sample just by chance (the probability that a given row is never drawn is $(1 - 1/n)^n \approx e^{-1} \approx 37\%$). Put differently: each observation is used in about 2/3 of all trees. If its prediction is calculated from the other 1/3 of the trees, we get an "out-of-sample" prediction, also called "out-of-bag" (OOB) prediction. If rows are independent, model performance derived from these OOB predictions is usually good enough to be used for model validation. Do not use OOB results when rows are dependent, such as for grouped samples.
- **Parameter tuning:** Random forests offer many tuning parameters. Since the results typically do not depend too much on their choice, untuned random forests are often great benchmark models.
## Example: random forest
Let us now fit a random forest for diamond prices with typical parameters and 500 trees. 80% of the data is used for training, the remaining 20% for evaluating the performance. (Throughout the rest of the lecture, we will ignore the problematic aspect of having repeated rows for some diamonds.)
```{r}
library(ggplot2)
library(withr)
library(ranger)
library(MetricsWeighted)
library(hstats)

# Train/test split
with_seed(
  9838,
  ix <- sample(nrow(diamonds), 0.8 * nrow(diamonds))
)

fit <- ranger(
  price ~ carat + color + cut + clarity,
  num.trees = 500,
  data = diamonds[ix, ],
  importance = "impurity",
  seed = 83
)
fit

# Performance on test data
pred <- predict(fit, diamonds[-ix, ])$predictions
rmse(diamonds$price[-ix], pred)  # 553 USD

train_mean <- mean(diamonds[["price"]][ix])
r_squared(diamonds$price[-ix], pred, reference_mean = train_mean)  # 0.9814
```
**Comments**
- Performance is excellent.
- The OOB estimate of performance is extremely close to the test set performance (see the snippet right after these comments).
- Interpretation?
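The OOB estimate mentioned above is stored on the fitted `ranger` object and can be extracted directly:
```{r}
# OOB performance stored by ranger (prediction.error is the OOB MSE for regression)
sqrt(fit$prediction.error)  # OOB RMSE, close to the test RMSE above
fit$r.squared               # OOB R-squared
```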
## Interpreting a "black box"
In contrast to a single decision tree or a linear model, a combination of many trees is not easy to interpret. It is good practice for any ML model to study at least *variable importance* and the strongest *effects*, not just its performance. A pure prediction machine is hardly of any interest and might even contain mistakes like using covariates derived from the response. Model interpretation helps to fight such problems and thus also to increase trust in a model.
### Variable importance
There are different approaches to measure the importance of a covariate. Since there is no general mathematical definition of "importance", the results of different approaches might be inconsistent with each other. For tree-based methods, a usual approach is to measure how many times a covariate $X$ was used in a split or how much total loss improvement came from splitting on $X$.
Approaches that work for *any* supervised model (including neural nets) include **permutation importance** and **SHAP importance**.
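As an illustration, here is a minimal sketch of permutation importance for the random forest fitted above (dedicated implementations add repetitions and proper scaling):
```{r}
# Permutation importance sketch: how much does the test RMSE deteriorate when the
# values of one covariate are shuffled, breaking its link to the response?
X_valid <- diamonds[-ix, c("carat", "color", "cut", "clarity")]
y_valid <- diamonds$price[-ix]
rmse_base <- rmse(y_valid, predict(fit, X_valid)$predictions)

sapply(c("carat", "color", "cut", "clarity"), function(v) {
  X_perm <- X_valid
  X_perm[[v]] <- sample(X_perm[[v]])
  rmse(y_valid, predict(fit, X_perm)$predictions) - rmse_base
})
```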
### Effects
One of the main reasons for the success of modern methods like random forests is the fact that they automatically learn interactions between two or more covariates. Thus, the effect of a covariate $X$ typically depends on the values of other covariates. In the extreme case, the effect of $X$ is different for each observation. The best we can do is to study the *average effect* of $X$ over many observations, i.e., to average the effects over the interactions. This leads us to **partial dependence plots**. They work for any supervised ML model and are constructed as follows: a set of observations is selected; then their average prediction is plotted against $X$ while sliding their value of $X$ over a reasonable grid of values, *keeping all other covariates fixed*. The more natural this Ceteris Paribus clause, the more reliable the partial dependence plot.
Remark: A partial dependence plot of a covariate in a linear regression is simply a visualization of its coefficient.
Alternatives to partial dependence plots include **accumulated local effect plots** and **SHAP dependence plots**. Both relax the Ceteris Paribus clause.
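Here is a manual sketch of the partial dependence recipe just described; the `partial_dep()` function from the `hstats` package used in the next example does the same, with refinements:
```{r}
# Manual partial dependence for carat: pick some training rows, slide their carat
# value over a grid while keeping the other covariates fixed, average the predictions
X_pd <- diamonds[sample(ix, 500), c("carat", "color", "cut", "clarity")]
carat_grid <- seq(0.3, 2.5, by = 0.1)
pd <- sapply(carat_grid, function(g) {
  X_pd$carat <- g
  mean(predict(fit, X_pd)$predictions)
})
plot(carat_grid, pd, type = "l", xlab = "carat", ylab = "Average prediction")
```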
## Example: random forest (continued)
Continuing the random forest example, we now look at variable importance and partial dependence plots.
```{r}
# Variable importance regarding MSE improvement
imp <- sort(importance(fit))
imp <- imp / sum(imp)
barplot(imp, horiz = TRUE, col = "chartreuse4")

# Partial dependence plots
for (v in c("carat", "color", "cut", "clarity")) {
  p <- partial_dep(fit, v = v, X = diamonds[ix, ]) |>
    plot() +
    ggtitle(paste("PDP for", v))
  print(p)
}
```
**Comments**
- As expected, `carat` is the most important predictor.
- All effects as assessed by partial dependence make sense.
## Exercises
1. In the example above, replace `carat` by its logarithm. Do the results change compared to the example without logs?
2. Fit a random forest on the claims data for the binary variable `clm` using covariates `veh_value`, `veh_body`, `veh_age`, `gender`, `area`, and `agecat`. Choose a suitable tree depth either by cross-validation or by minimizing OOB error on the training data. Make sure to fit a *probability random forest*, i.e., one predicting probabilities, not classes. Evaluate the final model on an independent test data set. (Note that the "ranger" package uses the "Brier score" as the evaluation metric for probabilistic predictions. In the binary case, it is the same as the MSE.) Interpret the results by split gain importance and partial dependence plots.
# Gradient Boosted Trees
The idea of *boosting* was introduced by Schapire in 1990 [3] and roughly works as follows: A simple model is fit to the data. Then, another simple model is added, trying to correct the errors from the first model. This process is repeated until some stopping criterion triggers. As simple models or *base learners*, usually **small decision trees** are used, an idea pushed by Jerome Friedman in his famous 2001 article on the very general framework of gradient boosting machines [4].
Modern variants of such *gradient boosted trees* are [XGBoost](https://xgboost.readthedocs.io/en/latest/), [LightGBM](https://lightgbm.readthedocs.io/en/latest/) and [CatBoost](https://catboost.ai/). These are the predominant algorithms in ML competitions on tabular data, check [this comparison](https://github.com/mayer79/gradient_boosting_comparison) for differences, with a screenshot as of Oct. 20, 2022:
![](../figs/comparison_boosting.PNG).
Predictions are calculated similarly to random forests, i.e., by combining the predictions of all trees. As loss/objective function, one can choose among many possibilities. Often, using the same loss function as a corresponding GLM is a good choice.
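Before turning to XGBoost, here is a bare-bones sketch of the boosting idea for the squared error loss: repeatedly fit a small tree to the current residuals and add a shrunken version of it to the ensemble (purely illustrative, not how XGBoost is implemented):
```{r}
library(rpart)
library(ggplot2)  # provides the diamonds data

lr <- 0.1                                # learning rate: weight of each added tree
dat <- diamonds[, c("price", "carat", "color", "cut", "clarity")]
pred <- rep(mean(dat$price), nrow(dat))  # start from the global mean

for (k in 1:50) {
  dat$resid <- dat$price - pred          # current residuals
  stump <- rpart(
    resid ~ carat + color + cut + clarity,
    data = dat, maxdepth = 2, cp = 0, xval = 0
  )
  pred <- pred + lr * predict(stump, dat)  # add a shrunken correction
}

sqrt(mean((dat$price - pred)^2))         # in-sample RMSE after 50 rounds
```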
## Example: XGBoost
As an initial example on gradient boosting and XGBoost, we fit a model for diamond prices with squared error as loss function. The number of rounds/trees is initially chosen by cross-validation and early stopping, i.e., trees are added until the CV validation (R)MSE stops improving for a couple of rounds. The learning rate (the weight of each tree) is chosen by trial and error in order to end up with a reasonably small/large number of trees, see the next section for more details.
```{r}
library(ggplot2)
library(withr)
library(xgboost)
library(MetricsWeighted)

y <- "price"
xvars <- c("carat", "color", "cut", "clarity")

# Split into train and test
with_seed(
  9838,
  ix <- sample(nrow(diamonds), 0.8 * nrow(diamonds))
)

y_train <- diamonds[[y]][ix]
X_train <- diamonds[ix, xvars]

y_test <- diamonds[[y]][-ix]
X_test <- diamonds[-ix, xvars]

# XGBoost data interface
dtrain <- xgb.DMatrix(data.matrix(X_train), label = y_train)

# Minimal set of parameters
params <- list(
  objective = "reg:squarederror",
  learning_rate = 0.02
)

# Add trees until CV validation MSE stops improving over the last 20 rounds
cvm <- xgb.cv(
  params = params,
  data = dtrain,
  nrounds = 5000,
  nfold = 5,
  early_stopping_rounds = 20,
  showsd = FALSE,
  print_every_n = 50
)

# Fit model on full training data using the optimal number of boosting rounds
fit <- xgb.train(
  params = params, data = dtrain, print_every_n = 50, nrounds = cvm$best_iteration
)

# Test performance
rmse(y_test, predict(fit, data.matrix(X_test)))  # 541.2
```
**Comments:**
- More trees would mean better training performance, but worse CV performance (see the plot after these comments).
- Test performance is slightly better than the random forest. Can we do even better?
- In the next example, we will also interpret such a model.
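To visualize the first comment, we can plot the training and CV validation RMSE per boosting round from the log returned by `xgb.cv()` (the column names follow its `<data>_rmse_mean` convention):
```{r}
# Train vs. CV validation RMSE per boosting round from the xgb.cv() run above
ev <- cvm$evaluation_log
plot(ev$iter, ev$test_rmse_mean, type = "l",
     xlab = "Boosting round", ylab = "RMSE", col = "chartreuse4")
lines(ev$iter, ev$train_rmse_mean, lty = 2)
legend("topright", c("CV validation", "Training"), lty = 1:2, col = c("chartreuse4", "black"))
```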
## Parameters of gradient boosted trees
Gradient boosted trees offer quite a lot of parameters. Unlike random forests, they need to be tuned to achieve good results. It would be naive to use an algorithm like XGBoost without parameter tuning. Here is a selection:
- **Number of boosting rounds:** In contrast to random forests, more trees/rounds is not always beneficial because the model begins to overfit after some time. The optimal number of rounds is usually found by early stopping, i.e., one lets the algorithm stop as soon as the (cross-)validation performance stops improving, see the example above.
- **Learning rate:** The learning rate determines the training speed and the impact of each tree on the final model. Typical values are between 0.01 and 0.5. In practical applications, it is set to a value that leads to a reasonable number of trees (100-1000). Usually, halving the learning rate requires about twice as many boosting rounds for comparable performance.
- **Regularization parameters:** Depending on the implementation, additional parameters are
    - the tree depth (often 3-7) or the number of leaves (often 7-63),
    - the strength of the L1 and/or L2 penalties added to the objective function (often between 0 and 5),
    - the row subsampling rate (often between 0.8 and 1),
    - the column subsampling rate (often between 0.6 and 1),
    - ...
Reasonable regularization parameters are chosen by trial and error or systematically by randomized or grid search CV. Usually, it takes a couple of iterations until the ranges of the parameter values have been set appropriately.
Overall, the modelling strategy is as follows:
1. Using default regularization parameters, set the learning rate to get a reasonable number of trees with CV-based early stopping.
2. Use randomized search CV with early stopping to tune regularization parameters such as tree depth. Iterate if needed.
3. Use the best parameter combination (incl. number of trees) to fit the model on the training data. "Best" typically means in terms of CV performance. As mentioned in the last chapter and depending on the situation, it can also mean "good CV performance and not too heavy an overfit compared to in-sample performance" or some other reasonable criterion.
Note: Since the learning rate, the number of boosting rounds and the regularization parameters are heavily interdependent, a "big" randomized grid search CV to choose all of them at once is often not ideal. The above suggestion (fix the learning rate, select the number of rounds by early stopping, and run the grid search only on the regularization parameters) is more focused; see also the example below.
## Example: XGBoost (fully tuned)
We will use XGBoost to fit diamond prices using the squared error as loss function and RMSE as evaluation metric, now using the tuning strategy outlined above.
```{r}
library(ggplot2)
library(withr)
library(xgboost)
library(MetricsWeighted)
library(hstats)

y <- "price"
xvars <- c("carat", "color", "cut", "clarity")

# Split into train and test
with_seed(
  9838,
  ix <- sample(nrow(diamonds), 0.8 * nrow(diamonds))
)

y_train <- diamonds[[y]][ix]
X_train <- diamonds[ix, xvars]

y_test <- diamonds[[y]][-ix]
X_test <- diamonds[-ix, xvars]

# XGBoost data interface
dtrain <- xgb.DMatrix(data.matrix(X_train), label = y_train)

# If grid search is to be run again, set tune <- TRUE
# Note that if run as rmarkdown, the path to the grid is "gridsearch",
# otherwise it is "r/gridsearch"
tune <- FALSE

if (tune) {
  # Use default parameters to set learning rate with suitable number of rounds
  params <- list(
    objective = "reg:squarederror",
    learning_rate = 0.02
  )

  # Cross-validation
  cvm <- xgb.cv(
    params = params,
    data = dtrain,
    nrounds = 5000,
    nfold = 5,
    early_stopping_rounds = 20,
    showsd = FALSE,
    print_every_n = 50
  )
  cvm  # -> a lr of 0.02 provides about 370 trees, which is a convenient amount

  # Final grid search after some iterations
  grid <- expand.grid(
    iteration = NA,
    cv_score = NA,
    train_score = NA,
    objective = "reg:squarederror",
    learning_rate = 0.02,
    max_depth = 6:7,
    reg_lambda = c(0, 2.5, 5, 7.5),
    reg_alpha = c(0, 4),
    colsample_bynode = c(0.8, 1),
    subsample = c(0.8, 1),
    min_split_loss = c(0, 1e-04),
    min_child_weight = c(1, 10)
  )

  # Grid search or randomized search if grid is too large
  max_size <- 32
  grid_size <- nrow(grid)
  if (grid_size > max_size) {
    grid <- grid[sample(grid_size, max_size), ]
    grid_size <- max_size
  }

  # Loop over grid and fit XGBoost with five-fold CV and early stopping
  pb <- txtProgressBar(0, grid_size, style = 3)
  for (i in seq_len(grid_size)) {
    cvm <- xgb.cv(
      params = as.list(grid[i, -(1:3)]),  # drop the three result placeholder columns
      data = dtrain,
      nrounds = 5000,
      nfold = 5,
      early_stopping_rounds = 20,
      verbose = 0
    )

    # Store result
    grid[i, 1] <- cvm$best_iteration
    grid[i, 2:3] <- cvm$evaluation_log[, c(4, 2)][cvm$best_iteration]
    setTxtProgressBar(pb, i)

    # Save grid to survive hard crashes
    saveRDS(grid, file = "gridsearch/diamonds_xgb.rds")
  }
}

# Load grid and select best iteration
grid <- readRDS("gridsearch/diamonds_xgb.rds")
grid <- grid[order(grid$cv_score), ]
head(grid)

# Fit final, tuned model
fit <- xgb.train(
  params = as.list(grid[1, -(1:3)]),
  data = dtrain,
  nrounds = grid[1, "iteration"]
)
```
Now, the model is ready to be inspected by evaluating
- test performance,
- split gain importance and
- partial dependence plots.
```{r}
# Performance on test data
pred <- predict(fit, data.matrix(X_test))
rmse(y_test, pred)  # 539.9
r_squared(y_test, pred, reference_mean = mean(y_train))  # 0.9823

# Variable importance regarding MSE improvement
imp <- xgb.importance(model = fit)
xgb.plot.importance(imp)

# Partial dependence plots
pred_fun <- function(m, X) predict(m, data.matrix(X))

for (v in xvars) {
  p <- partial_dep(fit, v = v, X = X_train, pred_fun = pred_fun) |>
    plot() +
    ggtitle(paste("PDP for", v))
  print(p)
}
```
**Comment**: The resulting model seems comparable to the random forest with slightly better performance. The grid search did not improve the results in this case.
## Exercises
1. Study the online documentation of XGBoost to figure out how to make the model monotonically increasing in carat. Test your insights without rerunning the grid search in our last example, i.e., just by refitting the final model. How does the partial dependence plot for `carat` look now?
2. Develop an XGBoost model for the claims data set with binary response `clm`, and covariates `veh_value`, `veh_body`, `veh_age`, `gender`, `area`, and `agecat`. Use a clean cross-validation/test approach. Use log loss as loss function and evaluation metric. Interpret its results. You don't need to write all the code from scratch, but rather modify the XGBoost code from the lecture notes.
3. Optional. Study the documentation of [LightGBM](https://lightgbm.readthedocs.io/en/latest/). Use LightGBM to develop a competitor to the XGBoost claims model from Exercise 2. The XGBoost code needs to be slightly adapted. Compare grid search time.
# Chapter Summary
In this chapter, we have met decision trees, random forests and tree boosting. Single decision trees are very easy to interpret but do not perform too well. On the other hand, tree ensembles like the random forest or gradient boosted trees usually perform very well but are tricky to interpret. We have introduced interpretation tools to look into such "black boxes". The main reason why random forests and boosted trees often provide more accurate models than a linear model lies in their ability to automatically learn interactions and other non-linear effects.
# Chapter References
[1] L. Breiman, J. Friedman, R. Olshen, and C. Stone, "Classification and Regression Trees", Wadsworth, Belmont, CA, 1984.
[2] L. Breiman, "Random Forests", Machine Learning, Vol. 45, Nr. 1, 2001.
[3] R. Schapire, "The Strength of Weak Learnability", Machine Learning, Vol. 5, Nr. 2, 1990.
[4] J. Friedman, "Greedy Function Approximation: A Gradient Boosting Machine", The Annals of Statistics, Vol. 29, Nr. 5, 2001.