Skip to content

Commit

Permalink
Fix the issue facebook#1814
Browse files Browse the repository at this point in the history
I did as PyStanBackend. And now when we use the method fit of Prophet, we can do like in the documentation:

https://facebook.github.io/prophet/docs/additional_topics.html#updating-fitted-models

def stan_init(m):
    """Retrieve parameters from a trained model.

    Retrieve parameters from a trained model in the format
    used to initialize a new Stan model.

    Parameters
    ----------
    m: A trained model of the Prophet class.

    Returns
    -------
    A Dictionary containing retrieved parameters of m.

    """
    res = {}
    for pname in ['k', 'm', 'sigma_obs']:
        res[pname] = m.params[pname][0][0]
    for pname in ['delta', 'beta']:
        res[pname] = m.params[pname][0]
    return res

df = pd.read_csv('../examples/example_wp_log_peyton_manning.csv')
df1 = df.loc[df['ds'] < '2016-01-19', :]  # All data except the last day
m1 = Prophet().fit(df1) # A model fit to all data except the last day

%timeit m2 = Prophet().fit(df)  # Adding the last day, fitting from scratch
%timeit m2 = Prophet().fit(df, init=stan_init(m1))  # Adding the last day, warm-starting from m1

Update models.py

Update models.py

Update models.py

Update models.py

Update models.py

Update models.py

Update models.py

Test

Test2

Test4

Test4

Test are fixed
  • Loading branch information
loulo1 committed Mar 9, 2021
1 parent 8882c6a commit fc8fa49
Showing 1 changed file with 35 additions and 19 deletions.
54 changes: 35 additions & 19 deletions python/fbprophet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,25 +87,30 @@ def load_model(self):

def fit(self, stan_init, stan_data, **kwargs):
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'algorithm' not in kwargs:
kwargs['algorithm'] = 'Newton' if stan_data['T'] < 100 else 'LBFGS'
iterations = int(1e4)
if 'init' in kwargs:
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]

args = dict(
data=stan_data,
init=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
iter=int(1e4),
)
args.update(kwargs)

args['inits'] = args['init']
del args['init']

try:
self.stan_fit = self.model.optimize(data=stan_data,
inits=stan_init,
iter=iterations,
**kwargs)
self.stan_fit = self.model.optimize(**args)
except RuntimeError as e:
# Fall back on Newton
if self.newton_fallback and kwargs['algorithm'] != 'Newton':
if self.newton_fallback and args['algorithm'] != 'Newton':
logger.warning(
'Optimization terminated abnormally. Falling back to Newton.'
)
kwargs['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(data=stan_data,
inits=stan_init,
iter=iterations,
**kwargs)
args['algorithm'] = 'Newton'
self.stan_fit = self.model.optimize(**args)
else:
raise e

Expand All @@ -117,17 +122,28 @@ def fit(self, stan_init, stan_data, **kwargs):

def sampling(self, stan_init, stan_data, samples, **kwargs) -> dict:
(stan_init, stan_data) = self.prepare_data(stan_init, stan_data)
if 'init' in kwargs:
kwargs['init'] = self.prepare_data(kwargs['init'], stan_data)[0]

args = dict(
data=stan_data,
init=stan_init,
algorithm='Newton' if stan_data['T'] < 100 else 'LBFGS',
)

if 'chains' not in kwargs:
kwargs['chains'] = 4
iter_half = samples // 2
kwargs['iter_sampling'] = iter_half
if 'iter_warmup' not in kwargs:
kwargs['iter_warmup'] = iter_half

args.update(kwargs)

self.stan_fit = self.model.sample(data=stan_data,
inits=stan_init,
iter_sampling=iter_half,
**kwargs)
args['inits'] = args['init']
del args['init']

self.stan_fit = self.model.sample(**args)
res = self.stan_fit.sample
(samples, c, columns) = res.shape
res = res.reshape((samples * c, columns))
Expand Down Expand Up @@ -166,10 +182,10 @@ def prepare_data(init, data) -> Tuple[dict, dict]:
'm': init['m'],
'delta': init['delta'].tolist(),
'beta': init['beta'].tolist(),
'sigma_obs': 1
'sigma_obs': init['sigma_obs']
}
return (cmdstanpy_init, cmdstanpy_data)

@staticmethod
def stan_to_dict_numpy(column_names: Tuple[str, ...], data: 'np.array'):
import numpy as np
Expand Down

0 comments on commit fc8fa49

Please sign in to comment.