Update models.py (AI4Finance-Foundation#1189)
* Update models.py

Added support for ensembling with TD3 and SAC.
Streamlined the per-window training and consolidated the models' parameters and variables.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
tricao7 and pre-commit-ci[bot] authored Mar 23, 2024
1 parent 5fc7fee commit 8b57ea3
Showing 1 changed file with 119 additions and 202 deletions.
321 changes: 119 additions & 202 deletions finrl/agents/stablebaselines3/models.py
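The diff below consolidates the per-model training/validation code into a new _train_window helper and extends run_ensemble_strategy to take SAC and TD3 kwargs. For orientation, a minimal usage sketch of the updated call follows; `ensemble_agent` stands for an already-constructed DRLEnsembleAgent instance from this file, and the hyperparameter values are illustrative assumptions, not part of this commit.

    # Sketch of the updated run_ensemble_strategy call signature.
    # `ensemble_agent` is assumed to be a DRLEnsembleAgent built elsewhere;
    # all hyperparameter values below are placeholders.
    A2C_model_kwargs = {"n_steps": 5, "ent_coef": 0.005, "learning_rate": 0.0007}
    PPO_model_kwargs = {"ent_coef": 0.01, "n_steps": 2048, "learning_rate": 0.00025}
    DDPG_model_kwargs = {"buffer_size": 10_000, "learning_rate": 0.0005}
    SAC_model_kwargs = {"batch_size": 64, "buffer_size": 100_000, "learning_rate": 0.0001}
    TD3_model_kwargs = {"batch_size": 100, "buffer_size": 1_000_000, "learning_rate": 0.001}

    timesteps_dict = {"a2c": 10_000, "ppo": 10_000, "ddpg": 10_000,
                      "sac": 10_000, "td3": 10_000}

    df_summary = ensemble_agent.run_ensemble_strategy(
        A2C_model_kwargs=A2C_model_kwargs,
        PPO_model_kwargs=PPO_model_kwargs,
        DDPG_model_kwargs=DDPG_model_kwargs,
        SAC_model_kwargs=SAC_model_kwargs,
        TD3_model_kwargs=TD3_model_kwargs,
        timesteps_dict=timesteps_dict,
    )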
@@ -349,19 +349,101 @@ def DRL_prediction(
df_last_state.to_csv(f"results/last_state_{name}_{i}.csv", index=False)
return last_state

def _train_window(
self,
model_name,
model_kwargs,
sharpe_list,
validation_start_date,
validation_end_date,
timesteps_dict,
i,
validation,
turbulence_threshold,
):
"""
Train the model for a single window.
"""
if model_kwargs is None:
return None, sharpe_list, -1

print(f"======{model_name} Training========")
model = self.get_model(
model_name, self.train_env, policy="MlpPolicy", model_kwargs=model_kwargs
)
model = self.train_model(
model,
model_name,
tb_log_name=f"{model_name}_{i}",
iter_num=i,
total_timesteps=timesteps_dict[model_name],
) # 100_000
print(
f"======{model_name} Validation from: ",
validation_start_date,
"to ",
validation_end_date,
)
val_env = DummyVecEnv(
[
lambda: StockTradingEnv(
df=validation,
stock_dim=self.stock_dim,
hmax=self.hmax,
initial_amount=self.initial_amount,
num_stock_shares=[0] * self.stock_dim,
buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
reward_scaling=self.reward_scaling,
state_space=self.state_space,
action_space=self.action_space,
tech_indicator_list=self.tech_indicator_list,
turbulence_threshold=turbulence_threshold,
iteration=i,
model_name=model_name,
mode="validation",
print_verbosity=self.print_verbosity,
)
]
)
val_obs = val_env.reset()
self.DRL_validation(
model=model,
test_data=validation,
test_env=val_env,
test_obs=val_obs,
)
sharpe = self.get_validation_sharpe(i, model_name=model_name)
print(f"{model_name} Sharpe Ratio: ", sharpe)
sharpe_list.append(sharpe)
return model, sharpe_list, sharpe
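_train_window reads its score back through get_validation_sharpe, which lives elsewhere in this file and is not part of this hunk. Below is a hedged sketch of how such a validation-window Sharpe ratio can be computed from the saved account values; the CSV name, the "account_value" column, and the annualisation factor are assumptions, not code from this commit.

    import pandas as pd

    def validation_sharpe_sketch(iteration, model_name):
        # Sketch only: assumes the validation run wrote its account values
        # to this CSV, mirroring how this file names validation output.
        df_total_value = pd.read_csv(
            f"results/account_value_validation_{model_name}_{iteration}.csv"
        )
        df_total_value["daily_return"] = df_total_value["account_value"].pct_change(1)
        if df_total_value["daily_return"].std() == 0:
            return 0.0
        # 4 ** 0.5 annualises a quarterly (roughly 63-trading-day) validation window.
        return (
            (4**0.5)
            * df_total_value["daily_return"].mean()
            / df_total_value["daily_return"].std()
        )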

def run_ensemble_strategy(
self, A2C_model_kwargs, PPO_model_kwargs, DDPG_model_kwargs, timesteps_dict
self,
A2C_model_kwargs,
PPO_model_kwargs,
DDPG_model_kwargs,
SAC_model_kwargs,
TD3_model_kwargs,
timesteps_dict,
):
"""Ensemble Strategy that combines PPO, A2C and DDPG"""
# Model Parameters
kwargs = {
"a2c": A2C_model_kwargs,
"ppo": PPO_model_kwargs,
"ddpg": DDPG_model_kwargs,
"sac": SAC_model_kwargs,
"td3": TD3_model_kwargs,
}
# Model Sharpe Ratios
model_dct = {k: {"sharpe_list": [], "sharpe": -1} for k in MODELS.keys()}

"""Ensemble Strategy that combines A2C, PPO, DDPG, SAC, and TD3"""
print("============Start Ensemble Strategy============")
# for ensemble model, it's necessary to feed the last state
# of the previous model to the current model as the initial state
last_state_ensemble = []

ppo_sharpe_list = []
ddpg_sharpe_list = []
a2c_sharpe_list = []

model_use = []
validation_start_date_list = []
validation_end_date_list = []
@@ -489,159 +571,24 @@ def run_ensemble_strategy(
)
# print("training: ",len(data_split(df, start=20090000, end=test.datadate.unique()[i-rebalance_window]) ))
# print("==============Model Training===========")
print("======A2C Training========")
model_a2c = self.get_model(
"a2c", self.train_env, policy="MlpPolicy", model_kwargs=A2C_model_kwargs
)
model_a2c = self.train_model(
model_a2c,
"a2c",
tb_log_name=f"a2c_{i}",
iter_num=i,
total_timesteps=timesteps_dict["a2c"],
) # 100_000

print(
"======A2C Validation from: ",
validation_start_date,
"to ",
validation_end_date,
)
val_env_a2c = DummyVecEnv(
[
lambda: StockTradingEnv(
df=validation,
stock_dim=self.stock_dim,
hmax=self.hmax,
initial_amount=self.initial_amount,
num_stock_shares=[0] * self.stock_dim,
buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
reward_scaling=self.reward_scaling,
state_space=self.state_space,
action_space=self.action_space,
tech_indicator_list=self.tech_indicator_list,
turbulence_threshold=turbulence_threshold,
iteration=i,
model_name="A2C",
mode="validation",
print_verbosity=self.print_verbosity,
)
]
)
val_obs_a2c = val_env_a2c.reset()
self.DRL_validation(
model=model_a2c,
test_data=validation,
test_env=val_env_a2c,
test_obs=val_obs_a2c,
)
sharpe_a2c = self.get_validation_sharpe(i, model_name="A2C")
print("A2C Sharpe Ratio: ", sharpe_a2c)

print("======PPO Training========")
model_ppo = self.get_model(
"ppo", self.train_env, policy="MlpPolicy", model_kwargs=PPO_model_kwargs
)
model_ppo = self.train_model(
model_ppo,
"ppo",
tb_log_name=f"ppo_{i}",
iter_num=i,
total_timesteps=timesteps_dict["ppo"],
) # 100_000
print(
"======PPO Validation from: ",
validation_start_date,
"to ",
validation_end_date,
)
val_env_ppo = DummyVecEnv(
[
lambda: StockTradingEnv(
df=validation,
stock_dim=self.stock_dim,
hmax=self.hmax,
initial_amount=self.initial_amount,
num_stock_shares=[0] * self.stock_dim,
buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
reward_scaling=self.reward_scaling,
state_space=self.state_space,
action_space=self.action_space,
tech_indicator_list=self.tech_indicator_list,
turbulence_threshold=turbulence_threshold,
iteration=i,
model_name="PPO",
mode="validation",
print_verbosity=self.print_verbosity,
)
]
)
val_obs_ppo = val_env_ppo.reset()
self.DRL_validation(
model=model_ppo,
test_data=validation,
test_env=val_env_ppo,
test_obs=val_obs_ppo,
)
sharpe_ppo = self.get_validation_sharpe(i, model_name="PPO")
print("PPO Sharpe Ratio: ", sharpe_ppo)

print("======DDPG Training========")
model_ddpg = self.get_model(
"ddpg",
self.train_env,
policy="MlpPolicy",
model_kwargs=DDPG_model_kwargs,
)
model_ddpg = self.train_model(
model_ddpg,
"ddpg",
tb_log_name=f"ddpg_{i}",
iter_num=i,
total_timesteps=timesteps_dict["ddpg"],
) # 50_000
print(
"======DDPG Validation from: ",
validation_start_date,
"to ",
validation_end_date,
)
val_env_ddpg = DummyVecEnv(
[
lambda: StockTradingEnv(
df=validation,
stock_dim=self.stock_dim,
hmax=self.hmax,
initial_amount=self.initial_amount,
num_stock_shares=[0] * self.stock_dim,
buy_cost_pct=[self.buy_cost_pct] * self.stock_dim,
sell_cost_pct=[self.sell_cost_pct] * self.stock_dim,
reward_scaling=self.reward_scaling,
state_space=self.state_space,
action_space=self.action_space,
tech_indicator_list=self.tech_indicator_list,
turbulence_threshold=turbulence_threshold,
iteration=i,
model_name="DDPG",
mode="validation",
print_verbosity=self.print_verbosity,
)
]
)
val_obs_ddpg = val_env_ddpg.reset()
self.DRL_validation(
model=model_ddpg,
test_data=validation,
test_env=val_env_ddpg,
test_obs=val_obs_ddpg,
)
sharpe_ddpg = self.get_validation_sharpe(i, model_name="DDPG")

ppo_sharpe_list.append(sharpe_ppo)
a2c_sharpe_list.append(sharpe_a2c)
ddpg_sharpe_list.append(sharpe_ddpg)
# Train Each Model
for model_name in MODELS.keys():
# Train The Model
model, sharpe_list, sharpe = self._train_window(
model_name,
kwargs[model_name],
model_dct[model_name]["sharpe_list"],
validation_start_date,
validation_end_date,
timesteps_dict,
i,
validation,
turbulence_threshold,
)
# Save the model's sharpe ratios, and the model itself
model_dct[model_name]["sharpe_list"] = sharpe_list
model_dct[model_name]["model"] = model
model_dct[model_name]["sharpe"] = sharpe

print(
"======Best Model Retraining from: ",
@@ -665,46 +612,12 @@ def run_ensemble_strategy(
# print_verbosity=self.print_verbosity
# )])
# Model Selection based on sharpe ratio
if (sharpe_ppo >= sharpe_a2c) & (sharpe_ppo >= sharpe_ddpg):
model_use.append("PPO")
model_ensemble = model_ppo

# model_ensemble = self.get_model("ppo",
# self.train_full_env,
# policy="MlpPolicy",
# model_kwargs=PPO_model_kwargs)
# model_ensemble = self.train_model(model_ensemble,
# "ensemble",
# tb_log_name="ensemble_{}".format(i),
# iter_num = i,
# total_timesteps=timesteps_dict['ppo']) #100_000
elif (sharpe_a2c > sharpe_ppo) & (sharpe_a2c > sharpe_ddpg):
model_use.append("A2C")
model_ensemble = model_a2c

# model_ensemble = self.get_model("a2c",
# self.train_full_env,
# policy="MlpPolicy",
# model_kwargs=A2C_model_kwargs)
# model_ensemble = self.train_model(model_ensemble,
# "ensemble",
# tb_log_name="ensemble_{}".format(i),
# iter_num = i,
# total_timesteps=timesteps_dict['a2c']) #100_000
else:
model_use.append("DDPG")
model_ensemble = model_ddpg

# model_ensemble = self.get_model("ddpg",
# self.train_full_env,
# policy="MlpPolicy",
# model_kwargs=DDPG_model_kwargs)
# model_ensemble = self.train_model(model_ensemble,
# "ensemble",
# tb_log_name="ensemble_{}".format(i),
# iter_num = i,
# total_timesteps=timesteps_dict['ddpg']) #50_000

# Same order as MODELS: {"a2c": A2C, "ddpg": DDPG, "td3": TD3, "sac": SAC, "ppo": PPO}
sharpes = [model_dct[k]["sharpe"] for k in MODELS.keys()]
# Find the model with the highest sharpe ratio
max_mod = list(MODELS.keys())[np.argmax(sharpes)]
model_use.append(max_mod.upper())
model_ensemble = model_dct[max_mod]["model"]
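A tiny, self-contained illustration of the selection step above; the Sharpe values are invented, and np.argmax breaks ties by taking the first maximum.

    import numpy as np

    # Keys listed in the same order as the MODELS dict referenced above.
    model_keys = ["a2c", "ddpg", "td3", "sac", "ppo"]
    sharpes = [0.8, 1.1, 0.9, 1.4, 1.2]  # invented validation Sharpe ratios

    max_mod = model_keys[int(np.argmax(sharpes))]
    print(max_mod)  # "sac": the highest-Sharpe model is carried into the trading window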
# Training and Validation ends

# Trading starts
@@ -734,9 +647,11 @@ def run_ensemble_strategy(
validation_start_date_list,
validation_end_date_list,
model_use,
a2c_sharpe_list,
ppo_sharpe_list,
ddpg_sharpe_list,
model_dct["a2c"]["sharpe_list"],
model_dct["ppo"]["sharpe_list"],
model_dct["ddpg"]["sharpe_list"],
model_dct["sac"]["sharpe_list"],
model_dct["td3"]["sharpe_list"],
]
).T
df_summary.columns = [
@@ -747,6 +662,8 @@ def run_ensemble_strategy(
"A2C Sharpe",
"PPO Sharpe",
"DDPG Sharpe",
"SAC Sharpe",
"TD3 Sharpe",
]

return df_summary
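After the diff: a hedged sketch of inspecting the returned summary. Only the five Sharpe columns are visible in the last hunk; the first four column names, including "Model Used", are assumptions based on the surrounding code.

    # df_summary comes from the run_ensemble_strategy call sketched near the top.
    # Assumed full column set (first four names are not shown in this hunk):
    #   Iter, Val Start, Val End, Model Used,
    #   A2C Sharpe, PPO Sharpe, DDPG Sharpe, SAC Sharpe, TD3 Sharpe
    print(df_summary["Model Used"].value_counts())  # how often each model won a window
    print(df_summary[["Val Start", "Val End", "Model Used",
                      "SAC Sharpe", "TD3 Sharpe"]].tail())  # newest windows, incl. the new models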
