# live.py
import numpy as np
import torch
import matplotlib.pyplot as plt
import functools


def live(agent, environment, num_episodes, max_timesteps,
         verbose=False, print_every=10):
    """
    Logic for operating over episodes.
    max_timesteps is the maximum number of time steps per episode.
    """
    observation_data = []   # per episode: list of (timestamp, state, price_record, done)
    action_data = []
    rewards = []

    if verbose:
        print("agent: %s, number of episodes: %d" % (str(agent), num_episodes))

    for episode in range(num_episodes):
        agent.reset_cumulative_reward()
        new_env = environment.reset()
        observation_history = [(new_env[0], new_env[1], new_env[2], False)]
        action_history = []

        t = 0
        done = False
        while not done:
            action = agent.act(observation_history, action_history)
            timestamp, state, price_record, done = environment.step(action)
            action_history.append(action)
            observation_history.append((timestamp, state, price_record, done))
            t += 1
            done = done or (t == max_timesteps)

        # store the finished episode and let the agent learn from its buffer
        agent.update_buffer(observation_history, action_history)
        agent.learn_from_buffer()

        observation_data.append(observation_history)
        action_data.append(action_history)
        rewards.append(agent.cummulative_reward)

        if verbose and (episode % print_every == 0):
            print("ep %d, reward %.5f" % (episode, agent.cummulative_reward))
            print('short', action_history.count(0))
            print('neutral', action_history.count(1))
            print('long', action_history.count(2))

        # periodically evaluate on the fixed evaluation episodes
        if episode % (5 * print_every) == 0:
            test(agent, environment, 10, max_timesteps)

    return observation_data, action_data, rewards


def test(agent, environment, num_episodes, max_timesteps, verbose=True, print_every=10):
    """
    Evaluate the agent on fixed episodes from the evaluation set without learning.
    """
    observation_data = []   # per episode: list of (timestamp, state, price_record, done)
    action_data = []
    rewards = []

    agent.test_mode = True
    print("agent: %s, number of episodes: %d" % (str(agent), num_episodes))
    print(20 * "=", "start testing on eval set", 20 * "=")

    for episode in range(num_episodes):
        agent.reset_cumulative_reward()
        new_env = environment.reset_fixed(episode * 3600)
        observation_history = [(new_env[0], new_env[1], new_env[2], False)]
        action_history = []

        t = 0
        done = False
        while not done:
            action = agent.act(observation_history, action_history)
            timestamp, state, price_record, done = environment.step(action)
            action_history.append(action)
            observation_history.append((timestamp, state, price_record, done))
            t += 1
            done = done or (t == max_timesteps)

        # store the episode but do not learn during evaluation
        agent.update_buffer(observation_history, action_history)

        observation_data.append(observation_history)
        action_data.append(action_history)
        rewards.append(agent.cummulative_reward)

        if verbose and (episode % print_every == 0):
            print("ep %d, reward %.5f" % (episode, agent.cummulative_reward))
            print('short', action_history.count(0))
            print('neutral', action_history.count(1))
            print('long', action_history.count(2))

    print('The sum of the reward is {}'.format(np.sum(np.array(rewards))))
    print(20 * "=", "finishing testing on eval set...", 20 * "=")
    agent.test_mode = False   # restore training behaviour after evaluation
    return observation_data, action_data, rewards
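

# The per-episode rewards returned by live() and test() are plain Python lists, so a
# small plotting helper is a natural companion. The sketch below is an illustrative
# addition (the name plot_rewards and its arguments are not part of the original
# script) that uses the matplotlib import at the top of the file.
def plot_rewards(rewards, title='Cumulative reward per episode'):
    """Plot per-episode cumulative rewards as a simple line chart."""
    plt.figure()
    plt.plot(np.arange(len(rewards)), rewards)
    plt.xlabel('episode')
    plt.ylabel('cumulative reward')
    plt.title(title)
    plt.show()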


### Example of usage
from environment import ForexEnv
from agents import RandomAgent
from agents import DQNAgent
from agents import Forex_reward_function
from feature import ForexIdentityFeature

if __name__ == '__main__':
    np.random.seed(0)
    torch.manual_seed(0)

    env = ForexEnv()
    agent = DQNAgent(
        action_set=[0, 1, 2],
        reward_function=functools.partial(Forex_reward_function),
        feature_extractor=ForexIdentityFeature(),
        hidden_dims=[50, 50],
        learning_rate=5e-4,
        buffer_size=5000,
        batch_size=12,
        num_batches=100,
        starts_learning=5000,
        final_epsilon=0.02,
        discount=0.99,
        target_freq=10,
        verbose=False,
        print_every=10)

    observation_data, action_data, rewards = live(
        agent=agent,
        environment=env,
        num_episodes=5,
        max_timesteps=5,
        verbose=True,
        print_every=50)

    agent.save('./dqn.pt')
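
    # Illustrative follow-up (an addition, not part of the original script): visualise
    # the per-episode training rewards returned by live() using the helper defined above.
    plot_rewards(rewards)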