Skip to content

Commit

Permalink
Enable creation of custom performance metrics (facebook#2599)
Browse files Browse the repository at this point in the history
  • Loading branch information
hub-bla authored Oct 18, 2024
1 parent 2a57e9d commit 691a464
Show file tree
Hide file tree
Showing 8 changed files with 624 additions and 240 deletions.
242 changes: 171 additions & 71 deletions docs/_docs/diagnostics.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,45 +86,45 @@ df_cv.head()
<tr>
<th>0</th>
<td>2010-02-16</td>
<td>8.959678</td>
<td>8.470035</td>
<td>9.451618</td>
<td>8.954582</td>
<td>8.462876</td>
<td>9.452305</td>
<td>8.242493</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>1</th>
<td>2010-02-17</td>
<td>8.726195</td>
<td>8.236734</td>
<td>9.219616</td>
<td>8.720932</td>
<td>8.222682</td>
<td>9.242788</td>
<td>8.008033</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>2</th>
<td>2010-02-18</td>
<td>8.610011</td>
<td>8.104834</td>
<td>9.125484</td>
<td>8.604608</td>
<td>8.066920</td>
<td>9.144968</td>
<td>8.045268</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>3</th>
<td>2010-02-19</td>
<td>8.532004</td>
<td>7.985031</td>
<td>9.041575</td>
<td>8.526379</td>
<td>8.029189</td>
<td>9.043045</td>
<td>7.928766</td>
<td>2010-02-15</td>
</tr>
<tr>
<th>4</th>
<td>2010-02-20</td>
<td>8.274090</td>
<td>7.779034</td>
<td>8.745627</td>
<td>8.268247</td>
<td>7.749520</td>
<td>8.741847</td>
<td>7.745003</td>
<td>2010-02-15</td>
</tr>
Expand Down Expand Up @@ -154,6 +154,20 @@ df_cv2 = cross_validation(m, cutoffs=cutoffs, horizon='365 days')
The `performance_metrics` utility can be used to compute some useful statistics of the prediction performance (`yhat`, `yhat_lower`, and `yhat_upper` compared to `y`), as a function of the distance from the cutoff (how far into the future the prediction was). The statistics computed are mean squared error (MSE), root mean squared error (RMSE), mean absolute error (MAE), mean absolute percent error (MAPE), median absolute percent error (MDAPE) and coverage of the `yhat_lower` and `yhat_upper` estimates. These are computed on a rolling window of the predictions in `df_cv` after sorting by horizon (`ds` minus `cutoff`). By default 10% of the predictions will be included in each window, but this can be changed with the `rolling_window` argument.



In Python, you can also create custom performance metric using the `register_performance_metric` decorator. Created metric should contain following arguments:

- df: Cross-validation results dataframe.

- w: Aggregation window size.



and return:

- Dataframe with columns horizon and metric.


```R
# R
df.p <- performance_metrics(df.cv)
Expand Down Expand Up @@ -200,57 +214,143 @@ df_p.head()
<tr>
<th>0</th>
<td>37 days</td>
<td>0.493764</td>
<td>0.702683</td>
<td>0.504754</td>
<td>0.058485</td>
<td>0.049922</td>
<td>0.058774</td>
<td>0.674052</td>
<td>0.493358</td>
<td>0.702395</td>
<td>0.503977</td>
<td>0.058376</td>
<td>0.049365</td>
<td>0.058677</td>
<td>0.676565</td>
</tr>
<tr>
<th>1</th>
<td>38 days</td>
<td>0.499522</td>
<td>0.706769</td>
<td>0.509723</td>
<td>0.059060</td>
<td>0.049389</td>
<td>0.059409</td>
<td>0.672910</td>
<td>0.499112</td>
<td>0.706478</td>
<td>0.508946</td>
<td>0.058951</td>
<td>0.049135</td>
<td>0.059312</td>
<td>0.675423</td>
</tr>
<tr>
<th>2</th>
<td>39 days</td>
<td>0.521614</td>
<td>0.722229</td>
<td>0.515793</td>
<td>0.059657</td>
<td>0.049540</td>
<td>0.060131</td>
<td>0.670169</td>
<td>0.521344</td>
<td>0.722042</td>
<td>0.515016</td>
<td>0.059547</td>
<td>0.049225</td>
<td>0.060034</td>
<td>0.672682</td>
</tr>
<tr>
<th>3</th>
<td>40 days</td>
<td>0.528760</td>
<td>0.727159</td>
<td>0.518634</td>
<td>0.059961</td>
<td>0.049232</td>
<td>0.060504</td>
<td>0.671311</td>
<td>0.528651</td>
<td>0.727084</td>
<td>0.517873</td>
<td>0.059852</td>
<td>0.049072</td>
<td>0.060409</td>
<td>0.676336</td>
</tr>
<tr>
<th>4</th>
<td>41 days</td>
<td>0.536078</td>
<td>0.732174</td>
<td>0.519585</td>
<td>0.060036</td>
<td>0.049389</td>
<td>0.060641</td>
<td>0.678849</td>
<td>0.536149</td>
<td>0.732222</td>
<td>0.518843</td>
<td>0.059927</td>
<td>0.049135</td>
<td>0.060548</td>
<td>0.681361</td>
</tr>
</tbody>
</table>
</div>



```python
# Python
from prophet.diagnostics import register_performance_metric, rolling_mean_by_h
import numpy as np
@register_performance_metric
def mase(df, w):
"""Mean absolute scale error
Parameters
----------
df: Cross-validation results dataframe.
w: Aggregation window size.
Returns
-------
Dataframe with columns horizon and mase.
"""
e = (df['y'] - df['yhat'])
d = np.abs(np.diff(df['y'])).sum()/(df['y'].shape[0]-1)
se = np.abs(e/d)
if w < 0:
return pd.DataFrame({'horizon': df['horizon'], 'mase': se})
return rolling_mean_by_h(
x=se.values, h=df['horizon'].values, w=w, name='mase'
)

df_mase = performance_metrics(df_cv, metrics=['mase'])
df_mase.head()
```



<div>
<style scoped>
.dataframe tbody tr th:only-of-type {
vertical-align: middle;
}

.dataframe tbody tr th {
vertical-align: top;
}

.dataframe thead th {
text-align: right;
}
</style>
<table border="1" class="dataframe">
<thead>
<tr style="text-align: right;">
<th></th>
<th>horizon</th>
<th>mase</th>
</tr>
</thead>
<tbody>
<tr>
<th>0</th>
<td>37 days</td>
<td>0.522946</td>
</tr>
<tr>
<th>1</th>
<td>38 days</td>
<td>0.528102</td>
</tr>
<tr>
<th>2</th>
<td>39 days</td>
<td>0.534401</td>
</tr>
<tr>
<th>3</th>
<td>40 days</td>
<td>0.537365</td>
</tr>
<tr>
<th>4</th>
<td>41 days</td>
<td>0.538372</td>
</tr>
</tbody>
</table>
Expand All @@ -271,7 +371,7 @@ from prophet.plot import plot_cross_validation_metric
fig = plot_cross_validation_metric(df_cv, metric='mape')
```

![png](/prophet/static/diagnostics_files/diagnostics_17_0.png)
![png](/prophet/static/diagnostics_files/diagnostics_18_0.png)


The size of the rolling window in the figure can be changed with the optional argument `rolling_window`, which specifies the proportion of forecasts to use in each rolling window. The default is 0.1, corresponding to 10% of rows from `df_cv` included in each window; increasing this will lead to a smoother average curve in the figure. The `initial` period should be long enough to capture all of the components of the model, in particular seasonalities and extra regressors: at least a year for yearly seasonality, at least a week for weekly seasonality, etc.
Expand Down Expand Up @@ -355,33 +455,33 @@ for params in all_params:
tuning_results = pd.DataFrame(all_params)
tuning_results['rmse'] = rmses
print(tuning_results)
```
changepoint_prior_scale seasonality_prior_scale rmse
0 0.001 0.01 0.757694
1 0.001 0.10 0.743399
2 0.001 1.00 0.753387
3 0.001 10.00 0.762890
4 0.010 0.01 0.542315
5 0.010 0.10 0.535546
6 0.010 1.00 0.527008
7 0.010 10.00 0.541544
8 0.100 0.01 0.524835
9 0.100 0.10 0.516061
10 0.100 1.00 0.521406
11 0.100 10.00 0.518580
12 0.500 0.01 0.532140
13 0.500 0.10 0.524668
14 0.500 1.00 0.521130
15 0.500 10.00 0.522980

```

changepoint_prior_scale seasonality_prior_scale rmse
0 0.001 0.01 0.757694
1 0.001 0.10 0.743399
2 0.001 1.00 0.753387
3 0.001 10.00 0.762890
4 0.010 0.01 0.542315
5 0.010 0.10 0.535546
6 0.010 1.00 0.527008
7 0.010 10.00 0.541544
8 0.100 0.01 0.524835
9 0.100 0.10 0.516061
10 0.100 1.00 0.521406
11 0.100 10.00 0.518580
12 0.500 0.01 0.532140
13 0.500 0.10 0.524668
14 0.500 1.00 0.521130
15 0.500 10.00 0.522980

```python
# Python
best_params = all_params[np.argmin(rmses)]
print(best_params)
```
{'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.1}

{'changepoint_prior_scale': 0.1, 'seasonality_prior_scale': 0.1}

Alternatively, parallelization could be done across parameter combinations by parallelizing the loop above.

Expand Down
Binary file removed docs/static/diagnostics_files/diagnostics_16_0.png
Binary file not shown.
Binary file modified docs/static/diagnostics_files/diagnostics_17_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/static/diagnostics_files/diagnostics_4_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit 691a464

Please sign in to comment.