Skip to content

Commit 6c2c6c8

Browse files
committedDec 2, 2020
Chapter 5: Solving RL
1 parent f3024be commit 6c2c6c8

File tree

1 file changed

+2117
-0
lines changed

1 file changed

+2117
-0
lines changed
 

‎Chapter05/Solving RL.ipynb

+2,117
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2117 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np\n",
10+
"import gym"
11+
]
12+
},
13+
{
14+
"cell_type": "markdown",
15+
"metadata": {},
16+
"source": [
17+
"# Dynamic Programming"
18+
]
19+
},
20+
{
21+
"cell_type": "code",
22+
"execution_count": 18,
23+
"metadata": {},
24+
"outputs": [],
25+
"source": [
26+
"class FoodTruck(gym.Env):\n",
27+
" def __init__(self):\n",
28+
" self.v_demand = [100, 200, 300, 400]\n",
29+
" self.p_demand = [0.3, 0.4, 0.2, 0.1]\n",
30+
" self.capacity = self.v_demand[-1]\n",
31+
" self.days = ['Mon', 'Tue', 'Wed', \n",
32+
" 'Thu', 'Fri', \"Weekend\"]\n",
33+
" self.unit_cost = 4\n",
34+
" self.net_revenue = 7\n",
35+
" self.action_space = [0, 100, 200, 300, 400]\n",
36+
" self.state_space = [(\"Mon\", 0)] \\\n",
37+
" + [(d, i) for d in self.days[1:] \n",
38+
" for i in [0, 100, 200, 300]]\n",
39+
" \n",
40+
" def get_next_state_reward(self, state, action, demand):\n",
41+
" day, inventory = state\n",
42+
" result = {}\n",
43+
" result['next_day'] = self.days[self.days.index(day) \\\n",
44+
" + 1]\n",
45+
" result['starting_inventory'] = min(self.capacity, \n",
46+
" inventory \n",
47+
" + action)\n",
48+
" result['cost'] = self.unit_cost * action \n",
49+
" result['sales'] = min(result['starting_inventory'], \n",
50+
" demand)\n",
51+
" result['revenue'] = self.net_revenue * result['sales']\n",
52+
" result['next_inventory'] \\\n",
53+
" = result['starting_inventory'] - result['sales']\n",
54+
" result['reward'] = result['revenue'] - result['cost']\n",
55+
" return result\n",
56+
" \n",
57+
" def get_transition_prob(self, state, action):\n",
58+
" next_s_r_prob = {}\n",
59+
" for ix, demand in enumerate(self.v_demand):\n",
60+
" result = self.get_next_state_reward(state, \n",
61+
" action, \n",
62+
" demand)\n",
63+
" next_s = (result['next_day'],\n",
64+
" result['next_inventory'])\n",
65+
" reward = result['reward']\n",
66+
" prob = self.p_demand[ix]\n",
67+
" if (next_s, reward) not in next_s_r_prob:\n",
68+
" next_s_r_prob[next_s, reward] = prob\n",
69+
" else:\n",
70+
" next_s_r_prob[next_s, reward] += prob\n",
71+
" return next_s_r_prob\n",
72+
" \n",
73+
" def reset(self):\n",
74+
" self.day = \"Mon\"\n",
75+
" self.inventory = 0\n",
76+
" state = (self.day, self.inventory)\n",
77+
" return state\n",
78+
" \n",
79+
" def is_terminal(self, state):\n",
80+
" day, inventory = state\n",
81+
" if day == \"Weekend\":\n",
82+
" return True\n",
83+
" else:\n",
84+
" return False\n",
85+
" \n",
86+
" def step(self, action):\n",
87+
" demand = np.random.choice(self.v_demand, \n",
88+
" p=self.p_demand)\n",
89+
" result = self.get_next_state_reward((self.day, \n",
90+
" self.inventory), \n",
91+
" action, \n",
92+
" demand)\n",
93+
" self.day = result['next_day']\n",
94+
" self.inventory = result['next_inventory']\n",
95+
" state = (self.day, self.inventory)\n",
96+
" reward = result['reward']\n",
97+
" done = self.is_terminal(state)\n",
98+
" info = {'demand': demand, 'sales': result['sales']}\n",
99+
" return state, reward, done, info"
100+
]
101+
},
102+
{
103+
"cell_type": "code",
104+
"execution_count": 19,
105+
"metadata": {},
106+
"outputs": [
107+
{
108+
"data": {
109+
"text/plain": [
110+
"2590.83"
111+
]
112+
},
113+
"execution_count": 19,
114+
"metadata": {},
115+
"output_type": "execute_result"
116+
}
117+
],
118+
"source": [
119+
"# Simulating an arbitrary policy\n",
120+
"np.random.seed(0)\n",
121+
"foodtruck = FoodTruck()\n",
122+
"rewards = []\n",
123+
"for i_episode in range(10000):\n",
124+
" state = foodtruck.reset()\n",
125+
" done = False\n",
126+
" ep_reward = 0\n",
127+
" while not done:\n",
128+
" day, inventory = state\n",
129+
" action = max(0, 300 - inventory)\n",
130+
" state, reward, done, info = foodtruck.step(action) \n",
131+
" ep_reward += reward\n",
132+
" rewards.append(ep_reward)\n",
133+
"np.mean(rewards)"
134+
]
135+
},
136+
{
137+
"cell_type": "code",
138+
"execution_count": null,
139+
"metadata": {},
140+
"outputs": [],
141+
"source": [
142+
"# Single day expected reward\n",
143+
"ucost = 4\n",
144+
"uprice = 7\n",
145+
"v_demand = [100, 200, 300, 400]\n",
146+
"p_demand = [0.3, 0.4, 0.2, 0.1]\n",
147+
"inv = 400\n",
148+
"profit = uprice*np.sum([p_demand[i]*min(v_demand[i], inv) for i in range(4)]) - inv*ucost\n",
149+
"print(profit)"
150+
]
151+
},
152+
{
153+
"cell_type": "markdown",
154+
"metadata": {},
155+
"source": [
156+
"## Policy Evaluation"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": 21,
162+
"metadata": {},
163+
"outputs": [],
164+
"source": [
165+
"def base_policy(states):\n",
166+
" policy = {}\n",
167+
" for s in states:\n",
168+
" day, inventory = s\n",
169+
" prob_a = {} \n",
170+
" if inventory >= 300:\n",
171+
" prob_a[0] = 1\n",
172+
" else:\n",
173+
" prob_a[200 - inventory] = 0.5\n",
174+
" prob_a[300 - inventory] = 0.5\n",
175+
" policy[s] = prob_a\n",
176+
" return policy"
177+
]
178+
},
179+
{
180+
"cell_type": "code",
181+
"execution_count": 22,
182+
"metadata": {},
183+
"outputs": [],
184+
"source": [
185+
"def expected_update(env, v, s, prob_a, gamma):\n",
186+
" expected_value = 0\n",
187+
" for a in prob_a:\n",
188+
" prob_next_s_r = env.get_transition_prob(s, a)\n",
189+
" for next_s, r in prob_next_s_r:\n",
190+
" expected_value += prob_a[a] \\\n",
191+
" * prob_next_s_r[next_s, r] \\\n",
192+
" * (r + gamma * v[next_s])\n",
193+
" return expected_value"
194+
]
195+
},
196+
{
197+
"cell_type": "code",
198+
"execution_count": 26,
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
202+
"def policy_evaluation(env, policy, max_iter=100, \n",
203+
" v = None, eps=0.1, gamma=1):\n",
204+
" if not v:\n",
205+
" v = {s: 0 for s in env.state_space}\n",
206+
" k = 0\n",
207+
" while True:\n",
208+
" max_delta = 0\n",
209+
" for s in v:\n",
210+
" if not env.is_terminal(s):\n",
211+
" v_old = v[s]\n",
212+
" prob_a = policy[s]\n",
213+
" v[s] = expected_update(env, v, \n",
214+
" s, prob_a, \n",
215+
" gamma)\n",
216+
" max_delta = max(max_delta, \n",
217+
" abs(v[s] - v_old))\n",
218+
" k += 1\n",
219+
" if max_delta < eps:\n",
220+
" print(\"Converged in\", k, \"iterations.\")\n",
221+
" break\n",
222+
" elif k == max_iter:\n",
223+
" print(\"Terminating after\", k, \"iterations.\")\n",
224+
" break\n",
225+
" return v"
226+
]
227+
},
228+
{
229+
"cell_type": "code",
230+
"execution_count": 52,
231+
"metadata": {},
232+
"outputs": [],
233+
"source": [
234+
"foodtruck = FoodTruck()\n",
235+
"policy = base_policy(foodtruck.state_space)"
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": 53,
241+
"metadata": {},
242+
"outputs": [
243+
{
244+
"name": "stdout",
245+
"output_type": "stream",
246+
"text": [
247+
"Converged in 6 iterations.\n",
248+
"Expected weekly profit: 2515.0\n"
249+
]
250+
}
251+
],
252+
"source": [
253+
"v = policy_evaluation(foodtruck, policy)\n",
254+
"print(\"Expected weekly profit:\", v[\"Mon\", 0])"
255+
]
256+
},
257+
{
258+
"cell_type": "code",
259+
"execution_count": 54,
260+
"metadata": {},
261+
"outputs": [
262+
{
263+
"name": "stdout",
264+
"output_type": "stream",
265+
"text": [
266+
"The state values:\n"
267+
]
268+
},
269+
{
270+
"data": {
271+
"text/plain": [
272+
"{('Mon', 0): 2515.0,\n",
273+
" ('Tue', 0): 1960.0,\n",
274+
" ('Tue', 100): 2360.0,\n",
275+
" ('Tue', 200): 2760.0,\n",
276+
" ('Tue', 300): 3205.0,\n",
277+
" ('Wed', 0): 1405.0,\n",
278+
" ('Wed', 100): 1805.0,\n",
279+
" ('Wed', 200): 2205.0,\n",
280+
" ('Wed', 300): 2650.0,\n",
281+
" ('Thu', 0): 850.0000000000001,\n",
282+
" ('Thu', 100): 1250.0,\n",
283+
" ('Thu', 200): 1650.0,\n",
284+
" ('Thu', 300): 2095.0,\n",
285+
" ('Fri', 0): 295.00000000000006,\n",
286+
" ('Fri', 100): 695.0000000000001,\n",
287+
" ('Fri', 200): 1095.0,\n",
288+
" ('Fri', 300): 1400.0,\n",
289+
" ('Weekend', 0): 0,\n",
290+
" ('Weekend', 100): 0,\n",
291+
" ('Weekend', 200): 0,\n",
292+
" ('Weekend', 300): 0}"
293+
]
294+
},
295+
"execution_count": 54,
296+
"metadata": {},
297+
"output_type": "execute_result"
298+
}
299+
],
300+
"source": [
301+
"print(\"The state values:\")\n",
302+
"v"
303+
]
304+
},
305+
{
306+
"cell_type": "code",
307+
"execution_count": 57,
308+
"metadata": {},
309+
"outputs": [],
310+
"source": [
311+
"def choose_action(state, policy):\n",
312+
" prob_a = policy[state]\n",
313+
" action = np.random.choice(a=list(prob_a.keys()), \n",
314+
" p=list(prob_a.values()))\n",
315+
" return action\n",
316+
"\n",
317+
"def simulate_policy(policy, n_episodes):\n",
318+
" np.random.seed(0)\n",
319+
" foodtruck = FoodTruck()\n",
320+
" rewards = []\n",
321+
" for i_episode in range(n_episodes):\n",
322+
" state = foodtruck.reset()\n",
323+
" done = False\n",
324+
" ep_reward = 0\n",
325+
" while not done:\n",
326+
" action = choose_action(state, policy)\n",
327+
" state, reward, done, info = foodtruck.step(action) \n",
328+
" ep_reward += reward\n",
329+
" rewards.append(ep_reward)\n",
330+
" print(\"Expected weekly profit:\", np.mean(rewards))"
331+
]
332+
},
333+
{
334+
"cell_type": "code",
335+
"execution_count": 58,
336+
"metadata": {},
337+
"outputs": [
338+
{
339+
"name": "stdout",
340+
"output_type": "stream",
341+
"text": [
342+
"Expected weekly profit: 2518.1\n"
343+
]
344+
}
345+
],
346+
"source": [
347+
"simulate_policy(policy, 1000)"
348+
]
349+
},
350+
{
351+
"cell_type": "markdown",
352+
"metadata": {},
353+
"source": [
354+
"## Policy Iteration"
355+
]
356+
},
357+
{
358+
"cell_type": "code",
359+
"execution_count": 39,
360+
"metadata": {},
361+
"outputs": [],
362+
"source": [
363+
"def policy_improvement(env, v, s, actions, gamma):\n",
364+
" prob_a = {}\n",
365+
" if not env.is_terminal(s):\n",
366+
" max_q = np.NINF\n",
367+
" best_a = None\n",
368+
" for a in actions:\n",
369+
" q_sa = expected_update(env, v, s, {a: 1}, gamma)\n",
370+
" if q_sa >= max_q:\n",
371+
" max_q = q_sa\n",
372+
" best_a = a\n",
373+
" prob_a[best_a] = 1\n",
374+
" else:\n",
375+
" max_q = 0\n",
376+
" return prob_a, max_q"
377+
]
378+
},
379+
{
380+
"cell_type": "code",
381+
"execution_count": 42,
382+
"metadata": {},
383+
"outputs": [],
384+
"source": [
385+
"def policy_iteration(env, eps=0.1, gamma=1):\n",
386+
" np.random.seed(1)\n",
387+
" states = env.state_space\n",
388+
" actions = env.action_space\n",
389+
" policy = {s: {np.random.choice(actions): 1}\n",
390+
" for s in states}\n",
391+
" v = {s: 0 for s in states}\n",
392+
" while True:\n",
393+
" v = policy_evaluation(env, policy, v=v, \n",
394+
" eps=eps, gamma=gamma)\n",
395+
" old_policy = policy\n",
396+
" policy = {}\n",
397+
" for s in states:\n",
398+
" policy[s], _ = policy_improvement(env, v, s, \n",
399+
" actions, gamma)\n",
400+
" if old_policy == policy:\n",
401+
" break\n",
402+
" print(\"Optimal policy found!\")\n",
403+
" return policy, v"
404+
]
405+
},
406+
{
407+
"cell_type": "code",
408+
"execution_count": 43,
409+
"metadata": {},
410+
"outputs": [
411+
{
412+
"name": "stdout",
413+
"output_type": "stream",
414+
"text": [
415+
"Converged in 6 iterations.\n",
416+
"Converged in 6 iterations.\n",
417+
"Converged in 5 iterations.\n",
418+
"Optimal policy found!\n",
419+
"Expected weekly profit: 2880.0\n"
420+
]
421+
}
422+
],
423+
"source": [
424+
"policy, v = policy_iteration(foodtruck)\n",
425+
"print(\"Expected weekly profit:\", v[\"Mon\", 0])"
426+
]
427+
},
428+
{
429+
"cell_type": "code",
430+
"execution_count": 44,
431+
"metadata": {},
432+
"outputs": [
433+
{
434+
"name": "stdout",
435+
"output_type": "stream",
436+
"text": [
437+
"{('Mon', 0): {400: 1}, ('Tue', 0): {400: 1}, ('Tue', 100): {300: 1}, ('Tue', 200): {200: 1}, ('Tue', 300): {100: 1}, ('Wed', 0): {400: 1}, ('Wed', 100): {300: 1}, ('Wed', 200): {200: 1}, ('Wed', 300): {100: 1}, ('Thu', 0): {300: 1}, ('Thu', 100): {200: 1}, ('Thu', 200): {100: 1}, ('Thu', 300): {0: 1}, ('Fri', 0): {200: 1}, ('Fri', 100): {100: 1}, ('Fri', 200): {0: 1}, ('Fri', 300): {0: 1}, ('Weekend', 0): {}, ('Weekend', 100): {}, ('Weekend', 200): {}, ('Weekend', 300): {}}\n"
438+
]
439+
}
440+
],
441+
"source": [
442+
"print(policy)"
443+
]
444+
},
445+
{
446+
"cell_type": "code",
447+
"execution_count": null,
448+
"metadata": {},
449+
"outputs": [],
450+
"source": []
451+
},
452+
{
453+
"cell_type": "markdown",
454+
"metadata": {},
455+
"source": [
456+
"## Value Iteration"
457+
]
458+
},
459+
{
460+
"cell_type": "code",
461+
"execution_count": 49,
462+
"metadata": {},
463+
"outputs": [],
464+
"source": [
465+
"def value_iteration(env, max_iter=100, eps=0.1, gamma=1):\n",
466+
" states = env.state_space\n",
467+
" actions = env.action_space\n",
468+
" v = {s: 0 for s in states}\n",
469+
" policy = {}\n",
470+
" k = 0\n",
471+
" while True:\n",
472+
" max_delta = 0\n",
473+
" for s in states:\n",
474+
" old_v = v[s]\n",
475+
" policy[s], v[s] = policy_improvement(env, \n",
476+
" v, \n",
477+
" s, \n",
478+
" actions, \n",
479+
" gamma)\n",
480+
" max_delta = max(max_delta, abs(v[s] - old_v))\n",
481+
" k += 1\n",
482+
" if max_delta < eps:\n",
483+
" print(\"Converged in\", k, \"iterations.\")\n",
484+
" break\n",
485+
" elif k == max_iter:\n",
486+
" print(\"Terminating after\", k, \"iterations.\")\n",
487+
" break\n",
488+
" return policy, v"
489+
]
490+
},
491+
{
492+
"cell_type": "code",
493+
"execution_count": 48,
494+
"metadata": {},
495+
"outputs": [
496+
{
497+
"name": "stdout",
498+
"output_type": "stream",
499+
"text": [
500+
"Converged in 6 iterations.\n",
501+
"6\n",
502+
"Expected weekly profit: 2880.0\n"
503+
]
504+
}
505+
],
506+
"source": [
507+
"policy, v = value_iteration(foodtruck)\n",
508+
"print(\"Expected weekly profit:\", v[\"Mon\", 0])"
509+
]
510+
},
511+
{
512+
"cell_type": "code",
513+
"execution_count": null,
514+
"metadata": {},
515+
"outputs": [],
516+
"source": [
517+
"print(policy)"
518+
]
519+
},
520+
{
521+
"cell_type": "code",
522+
"execution_count": null,
523+
"metadata": {},
524+
"outputs": [],
525+
"source": [
526+
"def generalized_policy_iteration(env, max_iter=2, eps=0.1, gamma=1):\n",
527+
" np.random.seed(1)\n",
528+
" states = env.observation_space\n",
529+
" actions = env.action_space\n",
530+
" policy = {s: {np.random.choice(actions): 1}\n",
531+
" for s in states}\n",
532+
" v = {s: 0 for s in states}\n",
533+
" k = 0\n",
534+
" while True:\n",
535+
" v_old = v.copy()\n",
536+
" policy = {}\n",
537+
" for s in states:\n",
538+
" policy[s], v[s] = policy_improvement(env, v, s, \n",
539+
" actions, gamma)\n",
540+
" v = policy_evaluation(env, policy, \n",
541+
" max_iter=max_iter, v=v, \n",
542+
" eps=eps, gamma=gamma)\n",
543+
" max_delta = np.amax([abs(v[s] - v_old[s]) for s in v])\n",
544+
" k += 1\n",
545+
" if max_delta < eps:\n",
546+
" print(\"GPI converged in\", k, \"iterations.\")\n",
547+
" print([abs(v[s] - v_old[s]) for s in v])\n",
548+
" break\n",
549+
" \n",
550+
" print(\"Optimal policy found!\")\n",
551+
" return policy, v"
552+
]
553+
},
554+
{
555+
"cell_type": "code",
556+
"execution_count": null,
557+
"metadata": {},
558+
"outputs": [],
559+
"source": [
560+
"policy, v = generalized_policy_iteration(foodtruck, max_iter=2, eps=0.1, gamma=1)"
561+
]
562+
},
563+
{
564+
"cell_type": "code",
565+
"execution_count": null,
566+
"metadata": {},
567+
"outputs": [],
568+
"source": [
569+
"print(\"Expected weekly profit:\", v[\"Mon\", 0])\n",
570+
"print(policy)"
571+
]
572+
},
573+
{
574+
"cell_type": "code",
575+
"execution_count": null,
576+
"metadata": {},
577+
"outputs": [],
578+
"source": [
579+
"v"
580+
]
581+
},
582+
{
583+
"cell_type": "markdown",
584+
"metadata": {},
585+
"source": [
586+
"# Monte Carlo Methods"
587+
]
588+
},
589+
{
590+
"cell_type": "markdown",
591+
"metadata": {},
592+
"source": [
593+
"## MC Prediction"
594+
]
595+
},
596+
{
597+
"cell_type": "code",
598+
"execution_count": 71,
599+
"metadata": {},
600+
"outputs": [],
601+
"source": [
602+
"def first_visit_return(returns, trajectory, gamma):\n",
603+
" G = 0\n",
604+
" T = len(trajectory) - 1\n",
605+
" for t, sar in enumerate(reversed(trajectory)):\n",
606+
" s, a, r = sar\n",
607+
" G = r + gamma * G\n",
608+
" first_visit = True\n",
609+
" for j in range(T - t):\n",
610+
" if s == trajectory[j][0]:\n",
611+
" first_visit = False\n",
612+
" if first_visit:\n",
613+
" if s in returns:\n",
614+
" returns[s].append(G)\n",
615+
" else:\n",
616+
" returns[s] = [G]\n",
617+
" return returns"
618+
]
619+
},
620+
{
621+
"cell_type": "code",
622+
"execution_count": 74,
623+
"metadata": {},
624+
"outputs": [],
625+
"source": [
626+
"def get_trajectory(env, policy):\n",
627+
" trajectory = []\n",
628+
" state = env.reset()\n",
629+
" done = False\n",
630+
" sar = [state]\n",
631+
" while not done:\n",
632+
" action = choose_action(state, policy)\n",
633+
" state, reward, done, info = env.step(action)\n",
634+
" sar.append(action)\n",
635+
" sar.append(reward)\n",
636+
" trajectory.append(sar)\n",
637+
" sar = [state]\n",
638+
" return trajectory"
639+
]
640+
},
641+
{
642+
"cell_type": "code",
643+
"execution_count": 75,
644+
"metadata": {},
645+
"outputs": [],
646+
"source": [
647+
"def first_visit_mc(env, policy, gamma, n_trajectories):\n",
648+
" np.random.seed(0)\n",
649+
" returns = {}\n",
650+
" v = {}\n",
651+
" for i in range(n_trajectories):\n",
652+
" trajectory = get_trajectory(env, policy)\n",
653+
" returns = first_visit_return(returns, \n",
654+
" trajectory, \n",
655+
" gamma)\n",
656+
" for s in env.state_space:\n",
657+
" if s in returns:\n",
658+
" v[s] = np.round(np.mean(returns[s]), 1)\n",
659+
" return v"
660+
]
661+
},
662+
{
663+
"cell_type": "code",
664+
"execution_count": 76,
665+
"metadata": {},
666+
"outputs": [],
667+
"source": [
668+
"foodtruck = FoodTruck()\n",
669+
"policy = base_policy(foodtruck.state_space)"
670+
]
671+
},
672+
{
673+
"cell_type": "code",
674+
"execution_count": 77,
675+
"metadata": {},
676+
"outputs": [
677+
{
678+
"data": {
679+
"text/plain": [
680+
"{('Mon', 0): 2515.9,\n",
681+
" ('Tue', 0): 1959.1,\n",
682+
" ('Tue', 100): 2362.2,\n",
683+
" ('Tue', 200): 2765.2,\n",
684+
" ('Wed', 0): 1411.3,\n",
685+
" ('Wed', 100): 1804.2,\n",
686+
" ('Wed', 200): 2198.9,\n",
687+
" ('Thu', 0): 852.9,\n",
688+
" ('Thu', 100): 1265.4,\n",
689+
" ('Thu', 200): 1644.4,\n",
690+
" ('Fri', 0): 301.1,\n",
691+
" ('Fri', 100): 696.5,\n",
692+
" ('Fri', 200): 1097.2}"
693+
]
694+
},
695+
"execution_count": 77,
696+
"metadata": {},
697+
"output_type": "execute_result"
698+
}
699+
],
700+
"source": [
701+
"v_est = first_visit_mc(foodtruck, policy, 1, 10000)\n",
702+
"v_est"
703+
]
704+
},
705+
{
706+
"cell_type": "code",
707+
"execution_count": 78,
708+
"metadata": {},
709+
"outputs": [
710+
{
711+
"name": "stdout",
712+
"output_type": "stream",
713+
"text": [
714+
"Converged in 6 iterations.\n"
715+
]
716+
}
717+
],
718+
"source": [
719+
"v_true = policy_evaluation(foodtruck, policy)"
720+
]
721+
},
722+
{
723+
"cell_type": "code",
724+
"execution_count": 63,
725+
"metadata": {},
726+
"outputs": [
727+
{
728+
"data": {
729+
"text/plain": [
730+
"{('Mon', 0): 2515.0,\n",
731+
" ('Tue', 0): 1960.0,\n",
732+
" ('Tue', 100): 2360.0,\n",
733+
" ('Tue', 200): 2760.0,\n",
734+
" ('Tue', 300): 3205.0,\n",
735+
" ('Wed', 0): 1405.0,\n",
736+
" ('Wed', 100): 1805.0,\n",
737+
" ('Wed', 200): 2205.0,\n",
738+
" ('Wed', 300): 2650.0,\n",
739+
" ('Thu', 0): 850.0000000000001,\n",
740+
" ('Thu', 100): 1250.0,\n",
741+
" ('Thu', 200): 1650.0,\n",
742+
" ('Thu', 300): 2095.0,\n",
743+
" ('Fri', 0): 295.00000000000006,\n",
744+
" ('Fri', 100): 695.0000000000001,\n",
745+
" ('Fri', 200): 1095.0,\n",
746+
" ('Fri', 300): 1400.0,\n",
747+
" ('Weekend', 0): 0,\n",
748+
" ('Weekend', 100): 0,\n",
749+
" ('Weekend', 200): 0,\n",
750+
" ('Weekend', 300): 0}"
751+
]
752+
},
753+
"execution_count": 63,
754+
"metadata": {},
755+
"output_type": "execute_result"
756+
}
757+
],
758+
"source": [
759+
"v_true"
760+
]
761+
},
762+
{
763+
"cell_type": "code",
764+
"execution_count": null,
765+
"metadata": {},
766+
"outputs": [],
767+
"source": [
768+
"# v_est = first_visit_mc(foodtruck, policy, 1, 5)\n",
769+
"# {s: v_est[s] for s in sorted(v_est)}"
770+
]
771+
},
772+
{
773+
"cell_type": "code",
774+
"execution_count": null,
775+
"metadata": {},
776+
"outputs": [],
777+
"source": [
778+
"# v_est = first_visit_mc(foodtruck, policy, 1, 10)\n",
779+
"# {s: v_est[s] for s in sorted(v_est)}"
780+
]
781+
},
782+
{
783+
"cell_type": "code",
784+
"execution_count": null,
785+
"metadata": {},
786+
"outputs": [],
787+
"source": [
788+
"# v_est = first_visit_mc(foodtruck, policy, 1, 100)\n",
789+
"# {s: v_est[s] for s in sorted(v_est)}"
790+
]
791+
},
792+
{
793+
"cell_type": "code",
794+
"execution_count": null,
795+
"metadata": {},
796+
"outputs": [],
797+
"source": [
798+
"# v_est = first_visit_mc(foodtruck, policy, 1, 1000)\n",
799+
"# {s: v_est[s] for s in sorted(v_est)}"
800+
]
801+
},
802+
{
803+
"cell_type": "code",
804+
"execution_count": null,
805+
"metadata": {},
806+
"outputs": [],
807+
"source": [
808+
"# v_est = first_visit_mc(foodtruck, policy, 1, 10000)\n",
809+
"# {s: v_est[s] for s in sorted(v_est)}"
810+
]
811+
},
812+
{
813+
"cell_type": "markdown",
814+
"metadata": {},
815+
"source": [
816+
"## On-policy Monte Carlo Control"
817+
]
818+
},
819+
{
820+
"cell_type": "code",
821+
"execution_count": 85,
822+
"metadata": {},
823+
"outputs": [],
824+
"source": [
825+
"import operator"
826+
]
827+
},
828+
{
829+
"cell_type": "code",
830+
"execution_count": 91,
831+
"metadata": {},
832+
"outputs": [],
833+
"source": [
834+
"def get_eps_greedy(actions, eps, a_best):\n",
835+
" prob_a = {}\n",
836+
" n_a = len(actions)\n",
837+
" for a in actions:\n",
838+
" if a == a_best:\n",
839+
" prob_a[a] = 1 - eps + eps/n_a\n",
840+
" else:\n",
841+
" prob_a[a] = eps/n_a\n",
842+
" return prob_a"
843+
]
844+
},
845+
{
846+
"cell_type": "code",
847+
"execution_count": null,
848+
"metadata": {},
849+
"outputs": [],
850+
"source": []
851+
},
852+
{
853+
"cell_type": "code",
854+
"execution_count": 92,
855+
"metadata": {},
856+
"outputs": [],
857+
"source": [
858+
"def get_random_policy(states, actions):\n",
859+
" policy = {}\n",
860+
" n_a = len(actions)\n",
861+
" for s in states:\n",
862+
" policy[s] = {a: 1/n_a for a in actions}\n",
863+
" return policy"
864+
]
865+
},
866+
{
867+
"cell_type": "code",
868+
"execution_count": 93,
869+
"metadata": {},
870+
"outputs": [],
871+
"source": [
872+
"def on_policy_first_visit_mc(env, n_iter, eps, gamma):\n",
873+
" np.random.seed(0)\n",
874+
" states = env.state_space\n",
875+
" actions = env.action_space\n",
876+
" policy = get_random_policy(states, actions)\n",
877+
" Q = {s: {a: 0 for a in actions} for s in states}\n",
878+
" Q_n = {s: {a: 0 for a in actions} for s in states}\n",
879+
" for i in range(n_iter):\n",
880+
" if i % 10000 == 0:\n",
881+
" print(\"Iteration:\", i)\n",
882+
" trajectory = get_trajectory(env, policy)\n",
883+
" G = 0\n",
884+
" T = len(trajectory) - 1\n",
885+
" for t, sar in enumerate(reversed(trajectory)):\n",
886+
" s, a, r = sar\n",
887+
" G = r + gamma * G\n",
888+
" first_visit = True\n",
889+
" for j in range(T - t):\n",
890+
" s_j = trajectory[j][0]\n",
891+
" a_j = trajectory[j][1]\n",
892+
" if (s, a) == (s_j, a_j):\n",
893+
" first_visit = False\n",
894+
" if first_visit:\n",
895+
" Q[s][a] = Q_n[s][a] * Q[s][a] + G\n",
896+
" Q_n[s][a] += 1\n",
897+
" Q[s][a] /= Q_n[s][a]\n",
898+
" a_best = max(Q[s].items(), \n",
899+
" key=operator.itemgetter(1))[0]\n",
900+
" policy[s] = get_eps_greedy(actions, \n",
901+
" eps, \n",
902+
" a_best)\n",
903+
" return policy, Q, Q_n"
904+
]
905+
},
906+
{
907+
"cell_type": "code",
908+
"execution_count": 94,
909+
"metadata": {},
910+
"outputs": [
911+
{
912+
"name": "stdout",
913+
"output_type": "stream",
914+
"text": [
915+
"Iteration: 0\n",
916+
"Iteration: 10000\n",
917+
"Iteration: 20000\n",
918+
"Iteration: 30000\n",
919+
"Iteration: 40000\n",
920+
"Iteration: 50000\n",
921+
"Iteration: 60000\n",
922+
"Iteration: 70000\n",
923+
"Iteration: 80000\n",
924+
"Iteration: 90000\n",
925+
"Iteration: 100000\n",
926+
"Iteration: 110000\n",
927+
"Iteration: 120000\n",
928+
"Iteration: 130000\n",
929+
"Iteration: 140000\n",
930+
"Iteration: 150000\n",
931+
"Iteration: 160000\n",
932+
"Iteration: 170000\n",
933+
"Iteration: 180000\n",
934+
"Iteration: 190000\n",
935+
"Iteration: 200000\n",
936+
"Iteration: 210000\n",
937+
"Iteration: 220000\n",
938+
"Iteration: 230000\n",
939+
"Iteration: 240000\n",
940+
"Iteration: 250000\n",
941+
"Iteration: 260000\n",
942+
"Iteration: 270000\n",
943+
"Iteration: 280000\n",
944+
"Iteration: 290000\n"
945+
]
946+
}
947+
],
948+
"source": [
949+
"policy, Q, Q_n = on_policy_first_visit_mc(foodtruck, \n",
950+
" 300000, \n",
951+
" 0.05, \n",
952+
" 1)"
953+
]
954+
},
955+
{
956+
"cell_type": "code",
957+
"execution_count": 90,
958+
"metadata": {},
959+
"outputs": [
960+
{
961+
"data": {
962+
"text/plain": [
963+
"{('Mon', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.96},\n",
964+
" ('Tue', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.96},\n",
965+
" ('Tue', 100): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.96, 400: 0.01},\n",
966+
" ('Tue', 200): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n",
967+
" ('Tue', 300): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n",
968+
" ('Wed', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.96},\n",
969+
" ('Wed', 100): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.96, 400: 0.01},\n",
970+
" ('Wed', 200): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n",
971+
" ('Wed', 300): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n",
972+
" ('Thu', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.96, 400: 0.01},\n",
973+
" ('Thu', 100): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n",
974+
" ('Thu', 200): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n",
975+
" ('Thu', 300): {0: 0.96, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.01},\n",
976+
" ('Fri', 0): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n",
977+
" ('Fri', 100): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n",
978+
" ('Fri', 200): {0: 0.96, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.01},\n",
979+
" ('Fri', 300): {0: 0.96, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.01},\n",
980+
" ('Weekend', 0): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2},\n",
981+
" ('Weekend', 100): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2},\n",
982+
" ('Weekend', 200): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2},\n",
983+
" ('Weekend', 300): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2}}"
984+
]
985+
},
986+
"execution_count": 90,
987+
"metadata": {},
988+
"output_type": "execute_result"
989+
}
990+
],
991+
"source": [
992+
"policy"
993+
]
994+
},
995+
{
996+
"cell_type": "code",
997+
"execution_count": 95,
998+
"metadata": {},
999+
"outputs": [
1000+
{
1001+
"data": {
1002+
"text/plain": [
1003+
"{('Mon', 0): {0: 2162.733333333329,\n",
1004+
" 100: 2468.4210526315796,\n",
1005+
" 200: 2668.7695190505888,\n",
1006+
" 300: 2739.300098231826,\n",
1007+
" 400: 2809.1632287569414},\n",
1008+
" ('Tue', 0): {0: 1539.1011235955057,\n",
1009+
" 100: 1857.630979498861,\n",
1010+
" 200: 2018.3222958057395,\n",
1011+
" 300: 2101.97486535009,\n",
1012+
" 400: 2181.249139237035},\n",
1013+
" ('Tue', 100): {0: 2243.7967115097176,\n",
1014+
" 100: 2410.7182940516295,\n",
1015+
" 200: 2537.853107344635,\n",
1016+
" 300: 2587.222441722628,\n",
1017+
" 400: 2170.4049844236765},\n",
1018+
" ('Tue', 200): {0: 2828.295819935689,\n",
1019+
" 100: 2953.6330631123433,\n",
1020+
" 200: 2996.437255166801,\n",
1021+
" 300: 2623.82297551789,\n",
1022+
" 400: 2224.710080285464},\n",
1023+
" ('Tue', 300): {0: 3383.880037488284,\n",
1024+
" 100: 3395.720002238628,\n",
1025+
" 200: 2939.4218134034168,\n",
1026+
" 300: 2572.2506393861877,\n",
1027+
" 400: 2162.3395149786},\n",
1028+
" ('Wed', 0): {0: 935.7142857142857,\n",
1029+
" 100: 1256.8720379146928,\n",
1030+
" 200: 1400.5025125628129,\n",
1031+
" 300: 1547.1040492055338,\n",
1032+
" 400: 1579.8683874265244},\n",
1033+
" ('Wed', 100): {0: 1639.7689768976904,\n",
1034+
" 100: 1868.1431005110733,\n",
1035+
" 200: 1908.107074569789,\n",
1036+
" 300: 1989.5285532259934,\n",
1037+
" 400: 1605.021520803444},\n",
1038+
" ('Wed', 200): {0: 2250.352733686064,\n",
1039+
" 100: 2341.068532900906,\n",
1040+
" 200: 2383.0059803588124,\n",
1041+
" 300: 1962.005277044855,\n",
1042+
" 400: 1573.4144222415298},\n",
1043+
" ('Wed', 300): {0: 2758.00389203214,\n",
1044+
" 100: 2778.022627490717,\n",
1045+
" 200: 2393.5081148564277,\n",
1046+
" 300: 1985.8374384236454,\n",
1047+
" 400: 1614.6220570012397},\n",
1048+
" ('Thu', 0): {0: 369.36619718309856,\n",
1049+
" 100: 684.2803030303028,\n",
1050+
" 200: 903.1539888682744,\n",
1051+
" 300: 972.1787871266652,\n",
1052+
" 400: 930.1247771836006},\n",
1053+
" ('Thu', 100): {0: 1084.478371501272,\n",
1054+
" 100: 1289.5073754522657,\n",
1055+
" 200: 1372.1298508969842,\n",
1056+
" 300: 1332.386447699365,\n",
1057+
" 400: 953.6523929471032},\n",
1058+
" ('Thu', 200): {0: 1677.668161434978,\n",
1059+
" 100: 1769.2753842946279,\n",
1060+
" 200: 1733.8299737072743,\n",
1061+
" 300: 1325.3393665158371,\n",
1062+
" 400: 919.6219621962197},\n",
1063+
" ('Thu', 300): {0: 2169.691663233083,\n",
1064+
" 100: 2166.585956416466,\n",
1065+
" 200: 1757.9545454545455,\n",
1066+
" 300: 1333.6569579288014,\n",
1067+
" 400: 953.3227848101266},\n",
1068+
" ('Fri', 0): {0: 0.0,\n",
1069+
" 100: 300.0,\n",
1070+
" 200: 388.81505831705283,\n",
1071+
" 300: 186.4516129032258,\n",
1072+
" 400: -142.74809160305333},\n",
1073+
" ('Fri', 100): {0: 700.0,\n",
1074+
" 100: 790.5049146968516,\n",
1075+
" 200: 607.3234524847425,\n",
1076+
" 300: 267.3796791443842,\n",
1077+
" 400: -110.91954022988506},\n",
1078+
" ('Fri', 200): {0: 1190.3311990960892,\n",
1079+
" 100: 988.7775551102206,\n",
1080+
" 200: 640.5092592592597,\n",
1081+
" 300: 267.4418604651163,\n",
1082+
" 400: -112.39263803680979},\n",
1083+
" ('Fri', 300): {0: 1399.4254760341432,\n",
1084+
" 100: 1152.7272727272725,\n",
1085+
" 200: 742.1875,\n",
1086+
" 300: 284.4827586206896,\n",
1087+
" 400: -120.0},\n",
1088+
" ('Weekend', 0): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
1089+
" ('Weekend', 100): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
1090+
" ('Weekend', 200): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
1091+
" ('Weekend', 300): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0}}"
1092+
]
1093+
},
1094+
"execution_count": 95,
1095+
"metadata": {},
1096+
"output_type": "execute_result"
1097+
}
1098+
],
1099+
"source": [
1100+
"Q"
1101+
]
1102+
},
1103+
{
1104+
"cell_type": "markdown",
1105+
"metadata": {},
1106+
"source": [
1107+
"## Off-policy Monte Carlo Control"
1108+
]
1109+
},
1110+
{
1111+
"cell_type": "code",
1112+
"execution_count": null,
1113+
"metadata": {},
1114+
"outputs": [],
1115+
"source": []
1116+
},
1117+
{
1118+
"cell_type": "code",
1119+
"execution_count": 108,
1120+
"metadata": {},
1121+
"outputs": [],
1122+
"source": [
1123+
"def off_policy_mc(env, n_iter, eps, gamma):\n",
1124+
" np.random.seed(0)\n",
1125+
" states = env.state_space\n",
1126+
" actions = env.action_space\n",
1127+
" Q = {s: {a: 0 for a in actions} for s in states}\n",
1128+
" C = {s: {a: 0 for a in actions} for s in states}\n",
1129+
" target_policy = {}\n",
1130+
" behavior_policy = get_random_policy(states, \n",
1131+
" actions)\n",
1132+
" for i in range(n_iter):\n",
1133+
" if i % 10000 == 0:\n",
1134+
" print(\"Iteration:\", i)\n",
1135+
" trajectory = get_trajectory(env, \n",
1136+
" behavior_policy)\n",
1137+
" G = 0\n",
1138+
" W = 1\n",
1139+
" T = len(trajectory) - 1\n",
1140+
" for t, sar in enumerate(reversed(trajectory)):\n",
1141+
" s, a, r = sar\n",
1142+
" G = r + gamma * G\n",
1143+
" C[s][a] += W\n",
1144+
" Q[s][a] += (W/C[s][a]) * (G - Q[s][a])\n",
1145+
" a_best = max(Q[s].items(), \n",
1146+
" key=operator.itemgetter(1))[0]\n",
1147+
" target_policy[s] = a_best\n",
1148+
" behavior_policy[s] = get_eps_greedy(actions, \n",
1149+
" eps, \n",
1150+
" a_best)\n",
1151+
" if a != target_policy[s]:\n",
1152+
" break\n",
1153+
" W = W / behavior_policy[s][a]\n",
1154+
" target_policy = {s: target_policy[s] for s in states\n",
1155+
" if s in target_policy}\n",
1156+
" return target_policy, Q"
1157+
]
1158+
},
1159+
{
1160+
"cell_type": "code",
1161+
"execution_count": 109,
1162+
"metadata": {},
1163+
"outputs": [
1164+
{
1165+
"name": "stdout",
1166+
"output_type": "stream",
1167+
"text": [
1168+
"Iteration: 0\n",
1169+
"Iteration: 10000\n",
1170+
"Iteration: 20000\n",
1171+
"Iteration: 30000\n",
1172+
"Iteration: 40000\n",
1173+
"Iteration: 50000\n",
1174+
"Iteration: 60000\n",
1175+
"Iteration: 70000\n",
1176+
"Iteration: 80000\n",
1177+
"Iteration: 90000\n",
1178+
"Iteration: 100000\n",
1179+
"Iteration: 110000\n",
1180+
"Iteration: 120000\n",
1181+
"Iteration: 130000\n",
1182+
"Iteration: 140000\n",
1183+
"Iteration: 150000\n",
1184+
"Iteration: 160000\n",
1185+
"Iteration: 170000\n",
1186+
"Iteration: 180000\n",
1187+
"Iteration: 190000\n",
1188+
"Iteration: 200000\n",
1189+
"Iteration: 210000\n",
1190+
"Iteration: 220000\n",
1191+
"Iteration: 230000\n",
1192+
"Iteration: 240000\n",
1193+
"Iteration: 250000\n",
1194+
"Iteration: 260000\n",
1195+
"Iteration: 270000\n",
1196+
"Iteration: 280000\n",
1197+
"Iteration: 290000\n"
1198+
]
1199+
}
1200+
],
1201+
"source": [
1202+
"policy, Q = off_policy_mc(foodtruck, 300000, 0.05, 1)"
1203+
]
1204+
},
1205+
{
1206+
"cell_type": "code",
1207+
"execution_count": 110,
1208+
"metadata": {},
1209+
"outputs": [
1210+
{
1211+
"data": {
1212+
"text/plain": [
1213+
"{('Mon', 0): 400,\n",
1214+
" ('Tue', 0): 400,\n",
1215+
" ('Tue', 100): 300,\n",
1216+
" ('Tue', 200): 200,\n",
1217+
" ('Tue', 300): 100,\n",
1218+
" ('Wed', 0): 400,\n",
1219+
" ('Wed', 100): 300,\n",
1220+
" ('Wed', 200): 200,\n",
1221+
" ('Wed', 300): 100,\n",
1222+
" ('Thu', 0): 300,\n",
1223+
" ('Thu', 100): 200,\n",
1224+
" ('Thu', 200): 100,\n",
1225+
" ('Thu', 300): 0,\n",
1226+
" ('Fri', 0): 200,\n",
1227+
" ('Fri', 100): 100,\n",
1228+
" ('Fri', 200): 0,\n",
1229+
" ('Fri', 300): 0}"
1230+
]
1231+
},
1232+
"execution_count": 110,
1233+
"metadata": {},
1234+
"output_type": "execute_result"
1235+
}
1236+
],
1237+
"source": [
1238+
"policy"
1239+
]
1240+
},
1241+
{
1242+
"cell_type": "code",
1243+
"execution_count": 111,
1244+
"metadata": {},
1245+
"outputs": [
1246+
{
1247+
"data": {
1248+
"text/plain": [
1249+
"{('Mon', 0): {0: 2232.674050632915,\n",
1250+
" 100: 2539.364696421396,\n",
1251+
" 200: 2725.681570338065,\n",
1252+
" 300: 2822.8136882129284,\n",
1253+
" 400: 2878.458190025779},\n",
1254+
" ('Tue', 0): {0: 1594.8051948051952,\n",
1255+
" 100: 1928.976034858388,\n",
1256+
" 200: 2067.4576271186465,\n",
1257+
" 300: 2207.8512396694205,\n",
1258+
" 400: 2239.8886329583893},\n",
1259+
" ('Tue', 100): {0: 2318.9435336976317,\n",
1260+
" 100: 2536.8012422360302,\n",
1261+
" 200: 2549.486301369862,\n",
1262+
" 300: 2650.193090274893,\n",
1263+
" 400: 2256.120527306967},\n",
1264+
" ('Tue', 200): {0: 2922.175290390706,\n",
1265+
" 100: 3012.8990770161868,\n",
1266+
" 200: 3052.4769607403373,\n",
1267+
" 300: 2689.515219842163,\n",
1268+
" 400: 2293.305439330548},\n",
1269+
" ('Tue', 300): {0: 3420.032031538755,\n",
1270+
" 100: 3453.749726573689,\n",
1271+
" 200: 3014.1210374639763,\n",
1272+
" 300: 2635.802469135803,\n",
1273+
" 400: 2233.3333333333344},\n",
1274+
" ('Wed', 0): {0: 927.9702970297026,\n",
1275+
" 100: 1303.1026252983302,\n",
1276+
" 200: 1428.831168831168,\n",
1277+
" 300: 1566.1498708010329,\n",
1278+
" 400: 1616.5133331502423},\n",
1279+
" ('Wed', 100): {0: 1683.8652482269495,\n",
1280+
" 100: 1896.0360360360366,\n",
1281+
" 200: 1976.8450184501858,\n",
1282+
" 300: 2024.3386976631361,\n",
1283+
" 400: 1650.87440381558},\n",
1284+
" ('Wed', 200): {0: 2277.8664007976076,\n",
1285+
" 100: 2405.7504873294333,\n",
1286+
" 200: 2419.006699098848,\n",
1287+
" 300: 2000.3857280617174,\n",
1288+
" 400: 1608.8068181818178},\n",
1289+
" ('Wed', 300): {0: 2779.4180573384715,\n",
1290+
" 100: 2818.4754229486366,\n",
1291+
" 200: 2422.7878787878817,\n",
1292+
" 300: 2017.7989130434773,\n",
1293+
" 400: 1660.6602475928496},\n",
1294+
" ('Thu', 0): {0: 369.164265129683,\n",
1295+
" 100: 684.9275362318838,\n",
1296+
" 200: 912.9056047197645,\n",
1297+
" 300: 988.2722582352171,\n",
1298+
" 400: 926.3157894736842},\n",
1299+
" ('Thu', 100): {0: 1090.7738095238096,\n",
1300+
" 100: 1329.5566502463064,\n",
1301+
" 200: 1392.0507055220148,\n",
1302+
" 300: 1373.6577181208052,\n",
1303+
" 400: 961.7241379310348},\n",
1304+
" ('Thu', 200): {0: 1699.1087344028504,\n",
1305+
" 100: 1789.1752957897363,\n",
1306+
" 200: 1760.222222222221,\n",
1307+
" 300: 1342.7149321266952,\n",
1308+
" 400: 965.2557319223995},\n",
1309+
" ('Thu', 300): {0: 2190.6271182185546,\n",
1310+
" 100: 2176.451612903226,\n",
1311+
" 200: 1780.9290953545235,\n",
1312+
" 300: 1360.7017543859658,\n",
1313+
" 400: 964.203233256352},\n",
1314+
" ('Fri', 0): {0: 0.0,\n",
1315+
" 100: 300.0,\n",
1316+
" 200: 388.8413403310466,\n",
1317+
" 300: 189.6405919661735,\n",
1318+
" 400: -146.61016949152557},\n",
1319+
" ('Fri', 100): {0: 700.0,\n",
1320+
" 100: 790.0458861880747,\n",
1321+
" 200: 608.920985556499,\n",
1322+
" 300: 265.5866900175128,\n",
1323+
" 400: -103.34967320261451},\n",
1324+
" ('Fri', 200): {0: 1190.388245916431,\n",
1325+
" 100: 1009.6551724137929,\n",
1326+
" 200: 651.612903225807,\n",
1327+
" 300: 266.99669966996686,\n",
1328+
" 400: -116.64641555285537},\n",
1329+
" ('Fri', 300): {0: 1404.084014002334,\n",
1330+
" 100: 1116.6666666666667,\n",
1331+
" 200: 702.9411764705883,\n",
1332+
" 300: 282.3529411764706,\n",
1333+
" 400: -175.86206896551724},\n",
1334+
" ('Weekend', 0): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
1335+
" ('Weekend', 100): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
1336+
" ('Weekend', 200): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
1337+
" ('Weekend', 300): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0}}"
1338+
]
1339+
},
1340+
"execution_count": 111,
1341+
"metadata": {},
1342+
"output_type": "execute_result"
1343+
}
1344+
],
1345+
"source": [
1346+
"Q"
1347+
]
1348+
},
1349+
{
1350+
"cell_type": "code",
1351+
"execution_count": null,
1352+
"metadata": {},
1353+
"outputs": [],
1354+
"source": []
1355+
},
1356+
{
1357+
"cell_type": "code",
1358+
"execution_count": null,
1359+
"metadata": {},
1360+
"outputs": [],
1361+
"source": []
1362+
},
1363+
{
1364+
"cell_type": "markdown",
1365+
"metadata": {},
1366+
"source": [
1367+
"# TD Learning"
1368+
]
1369+
},
1370+
{
1371+
"cell_type": "markdown",
1372+
"metadata": {},
1373+
"source": [
1374+
"## TD Prediction"
1375+
]
1376+
},
1377+
{
1378+
"cell_type": "code",
1379+
"execution_count": 116,
1380+
"metadata": {},
1381+
"outputs": [],
1382+
"source": [
1383+
"def one_step_td_prediction(env, policy, gamma, alpha, n_iter):\n",
1384+
" np.random.seed(0)\n",
1385+
" states = env.state_space\n",
1386+
" v = {s: 0 for s in states}\n",
1387+
" s = env.reset()\n",
1388+
" for i in range(n_iter):\n",
1389+
" a = choose_action(s, policy)\n",
1390+
" s_next, reward, done, info = env.step(a)\n",
1391+
" v[s] += alpha * (reward + gamma * v[s_next] - v[s])\n",
1392+
" if done:\n",
1393+
" s = env.reset()\n",
1394+
" else:\n",
1395+
" s = s_next\n",
1396+
" return v"
1397+
]
1398+
},
1399+
{
1400+
"cell_type": "code",
1401+
"execution_count": null,
1402+
"metadata": {},
1403+
"outputs": [],
1404+
"source": []
1405+
},
1406+
{
1407+
"cell_type": "code",
1408+
"execution_count": 117,
1409+
"metadata": {},
1410+
"outputs": [
1411+
{
1412+
"data": {
1413+
"text/plain": [
1414+
"{('Mon', 0): 2506.576417395407,\n",
1415+
" ('Tue', 0): 1956.077876400167,\n",
1416+
" ('Tue', 100): 2368.7400039407535,\n",
1417+
" ('Tue', 200): 2767.5069659225423,\n",
1418+
" ('Tue', 300): 0,\n",
1419+
" ('Wed', 0): 1413.0055559001296,\n",
1420+
" ('Wed', 100): 1813.546186490315,\n",
1421+
" ('Wed', 200): 2200.8873259700867,\n",
1422+
" ('Wed', 300): 0,\n",
1423+
" ('Thu', 0): 828.2915189850011,\n",
1424+
" ('Thu', 100): 1280.424626614422,\n",
1425+
" ('Thu', 200): 1675.8661846955831,\n",
1426+
" ('Thu', 300): 0,\n",
1427+
" ('Fri', 0): 345.52991944823583,\n",
1428+
" ('Fri', 100): 677.4358179389413,\n",
1429+
" ('Fri', 200): 1094.8263154150825,\n",
1430+
" ('Fri', 300): 0,\n",
1431+
" ('Weekend', 0): 0,\n",
1432+
" ('Weekend', 100): 0,\n",
1433+
" ('Weekend', 200): 0,\n",
1434+
" ('Weekend', 300): 0}"
1435+
]
1436+
},
1437+
"execution_count": 117,
1438+
"metadata": {},
1439+
"output_type": "execute_result"
1440+
}
1441+
],
1442+
"source": [
1443+
"policy = base_policy(foodtruck.state_space)\n",
1444+
"v = one_step_td_prediction(foodtruck, policy, 1, 0.01, 100000)\n",
1445+
"v"
1446+
]
1447+
},
1448+
{
1449+
"cell_type": "code",
1450+
"execution_count": null,
1451+
"metadata": {},
1452+
"outputs": [],
1453+
"source": [
1454+
"print({s: np.round(v[s]) for s in v})"
1455+
]
1456+
},
1457+
{
1458+
"cell_type": "markdown",
1459+
"metadata": {},
1460+
"source": [
1461+
"True values\n",
1462+
"{('Mon', 0): 2515.0,\n",
1463+
" ('Tue', 0): 1960.0,\n",
1464+
" ('Tue', 100): 2360.0,\n",
1465+
" ('Tue', 200): 2760.0,\n",
1466+
" ('Tue', 300): 3205.0,\n",
1467+
" ('Wed', 0): 1405.0,\n",
1468+
" ('Wed', 100): 1805.0,\n",
1469+
" ('Wed', 200): 2205.0,\n",
1470+
" ('Wed', 300): 2650.0,\n",
1471+
" ('Thu', 0): 850.0000000000001,\n",
1472+
" ('Thu', 100): 1250.0,\n",
1473+
" ('Thu', 200): 1650.0,\n",
1474+
" ('Thu', 300): 2095.0,\n",
1475+
" ('Fri', 0): 295.00000000000006,\n",
1476+
" ('Fri', 100): 695.0000000000001,\n",
1477+
" ('Fri', 200): 1095.0,\n",
1478+
" ('Fri', 300): 1400.0,\n",
1479+
" ('Weekend', 0): 0,\n",
1480+
" ('Weekend', 100): 0,\n",
1481+
" ('Weekend', 200): 0,\n",
1482+
" ('Weekend', 300): 0}"
1483+
]
1484+
},
1485+
{
1486+
"cell_type": "code",
1487+
"execution_count": 118,
1488+
"metadata": {},
1489+
"outputs": [],
1490+
"source": [
1491+
"def sarsa(env, gamma, eps, alpha, n_iter):\n",
1492+
" np.random.seed(0)\n",
1493+
" states = env.state_space\n",
1494+
" actions = env.action_space\n",
1495+
" Q = {s: {a: 0 for a in actions} for s in states}\n",
1496+
" policy = get_random_policy(states, actions)\n",
1497+
" s = env.reset()\n",
1498+
" a = choose_action(s, policy)\n",
1499+
" for i in range(n_iter):\n",
1500+
" if i % 100000 == 0:\n",
1501+
" print(\"Iteration:\", i)\n",
1502+
" s_next, reward, done, info = env.step(a)\n",
1503+
" a_best = max(Q[s_next].items(), \n",
1504+
" key=operator.itemgetter(1))[0]\n",
1505+
" policy[s_next] = get_eps_greedy(actions, eps, a_best)\n",
1506+
" a_next = choose_action(s_next, policy)\n",
1507+
" Q[s][a] += alpha * (reward \n",
1508+
" + gamma * Q[s_next][a_next] \n",
1509+
" - Q[s][a])\n",
1510+
" if done:\n",
1511+
" s = env.reset()\n",
1512+
" a_best = max(Q[s].items(), \n",
1513+
" key=operator.itemgetter(1))[0]\n",
1514+
" policy[s] = get_eps_greedy(actions, eps, a_best)\n",
1515+
" a = choose_action(s, policy)\n",
1516+
" else:\n",
1517+
" s = s_next\n",
1518+
" a = a_next\n",
1519+
" return policy, Q"
1520+
]
1521+
},
1522+
{
1523+
"cell_type": "code",
1524+
"execution_count": 119,
1525+
"metadata": {},
1526+
"outputs": [
1527+
{
1528+
"name": "stdout",
1529+
"output_type": "stream",
1530+
"text": [
1531+
"Iteration: 0\n",
1532+
"Iteration: 100000\n",
1533+
"Iteration: 200000\n",
1534+
"Iteration: 300000\n",
1535+
"Iteration: 400000\n",
1536+
"Iteration: 500000\n",
1537+
"Iteration: 600000\n",
1538+
"Iteration: 700000\n",
1539+
"Iteration: 800000\n",
1540+
"Iteration: 900000\n"
1541+
]
1542+
}
1543+
],
1544+
"source": [
1545+
"policy, Q = sarsa(foodtruck, 1, 0.1, 0.01, 1000000)"
1546+
]
1547+
},
1548+
{
1549+
"cell_type": "code",
1550+
"execution_count": 120,
1551+
"metadata": {},
1552+
"outputs": [
1553+
{
1554+
"data": {
1555+
"text/plain": [
1556+
"{('Mon', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n",
1557+
" ('Tue', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n",
1558+
" ('Tue', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n",
1559+
" ('Tue', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1560+
" ('Tue', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1561+
" ('Wed', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n",
1562+
" ('Wed', 100): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n",
1563+
" ('Wed', 200): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n",
1564+
" ('Wed', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1565+
" ('Thu', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n",
1566+
" ('Thu', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n",
1567+
" ('Thu', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1568+
" ('Thu', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1569+
" ('Fri', 0): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n",
1570+
" ('Fri', 100): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1571+
" ('Fri', 200): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1572+
" ('Fri', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1573+
" ('Weekend', 0): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1574+
" ('Weekend', 100): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1575+
" ('Weekend', 200): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n",
1576+
" ('Weekend', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02}}"
1577+
]
1578+
},
1579+
"execution_count": 120,
1580+
"metadata": {},
1581+
"output_type": "execute_result"
1582+
}
1583+
],
1584+
"source": [
1585+
"policy"
1586+
]
1587+
},
1588+
{
1589+
"cell_type": "code",
1590+
"execution_count": 121,
1591+
"metadata": {},
1592+
"outputs": [
1593+
{
1594+
"data": {
1595+
"text/plain": [
1596+
"{0: 2099.8661156763687,\n",
1597+
" 100: 2399.8190742726747,\n",
1598+
" 200: 2604.6629056622382,\n",
1599+
" 300: 2670.098987213351,\n",
1600+
" 400: 2632.8387133517112}"
1601+
]
1602+
},
1603+
"execution_count": 121,
1604+
"metadata": {},
1605+
"output_type": "execute_result"
1606+
}
1607+
],
1608+
"source": [
1609+
"Q[('Mon', 0)]"
1610+
]
1611+
},
1612+
{
1613+
"cell_type": "code",
1614+
"execution_count": null,
1615+
"metadata": {},
1616+
"outputs": [],
1617+
"source": []
1618+
},
1619+
{
1620+
"cell_type": "code",
1621+
"execution_count": null,
1622+
"metadata": {},
1623+
"outputs": [],
1624+
"source": []
1625+
},
1626+
{
1627+
"cell_type": "code",
1628+
"execution_count": 122,
1629+
"metadata": {},
1630+
"outputs": [],
1631+
"source": [
1632+
"def q_learning(env, gamma, eps, alpha, n_iter):\n",
1633+
" np.random.seed(0)\n",
1634+
" states = env.state_space\n",
1635+
" actions = env.action_space\n",
1636+
" Q = {s: {a: 0 for a in actions} for s in states}\n",
1637+
" policy = get_random_policy(states, actions)\n",
1638+
" s = env.reset()\n",
1639+
" for i in range(n_iter):\n",
1640+
" if i % 100000 == 0:\n",
1641+
" print(\"Iteration:\", i)\n",
1642+
" a_best = max(Q[s].items(), \n",
1643+
" key=operator.itemgetter(1))[0]\n",
1644+
" policy[s] = get_eps_greedy(actions, eps, a_best)\n",
1645+
" a = choose_action(s, policy)\n",
1646+
" s_next, reward, done, info = env.step(a)\n",
1647+
" Q[s][a] += alpha * (reward \n",
1648+
" + gamma * max(Q[s_next].values()) \n",
1649+
" - Q[s][a])\n",
1650+
" if done:\n",
1651+
" s = env.reset()\n",
1652+
" else:\n",
1653+
" s = s_next\n",
1654+
" policy = {s: {max(policy[s].items(), \n",
1655+
" key=operator.itemgetter(1))[0]: 1}\n",
1656+
" for s in states}\n",
1657+
" return policy, Q"
1658+
]
1659+
},
1660+
{
1661+
"cell_type": "code",
1662+
"execution_count": 123,
1663+
"metadata": {},
1664+
"outputs": [
1665+
{
1666+
"name": "stdout",
1667+
"output_type": "stream",
1668+
"text": [
1669+
"Iteration: 0\n",
1670+
"Iteration: 100000\n",
1671+
"Iteration: 200000\n",
1672+
"Iteration: 300000\n",
1673+
"Iteration: 400000\n",
1674+
"Iteration: 500000\n",
1675+
"Iteration: 600000\n",
1676+
"Iteration: 700000\n",
1677+
"Iteration: 800000\n",
1678+
"Iteration: 900000\n"
1679+
]
1680+
},
1681+
{
1682+
"data": {
1683+
"text/plain": [
1684+
"{('Mon', 0): {400: 1},\n",
1685+
" ('Tue', 0): {400: 1},\n",
1686+
" ('Tue', 100): {300: 1},\n",
1687+
" ('Tue', 200): {200: 1},\n",
1688+
" ('Tue', 300): {100: 1},\n",
1689+
" ('Wed', 0): {400: 1},\n",
1690+
" ('Wed', 100): {300: 1},\n",
1691+
" ('Wed', 200): {200: 1},\n",
1692+
" ('Wed', 300): {100: 1},\n",
1693+
" ('Thu', 0): {300: 1},\n",
1694+
" ('Thu', 100): {200: 1},\n",
1695+
" ('Thu', 200): {100: 1},\n",
1696+
" ('Thu', 300): {0: 1},\n",
1697+
" ('Fri', 0): {200: 1},\n",
1698+
" ('Fri', 100): {100: 1},\n",
1699+
" ('Fri', 200): {0: 1},\n",
1700+
" ('Fri', 300): {0: 1},\n",
1701+
" ('Weekend', 0): {0: 1},\n",
1702+
" ('Weekend', 100): {0: 1},\n",
1703+
" ('Weekend', 200): {0: 1},\n",
1704+
" ('Weekend', 300): {0: 1}}"
1705+
]
1706+
},
1707+
"execution_count": 123,
1708+
"metadata": {},
1709+
"output_type": "execute_result"
1710+
}
1711+
],
1712+
"source": [
1713+
"policy, Q = q_learning(foodtruck, 1, 0.1, 0.01, 1000000)\n",
1714+
"policy"
1715+
]
1716+
},
1717+
{
1718+
"cell_type": "code",
1719+
"execution_count": 124,
1720+
"metadata": {},
1721+
"outputs": [
1722+
{
1723+
"name": "stdout",
1724+
"output_type": "stream",
1725+
"text": [
1726+
"Iteration: 0\n",
1727+
"Iteration: 100000\n",
1728+
"Iteration: 200000\n",
1729+
"Iteration: 300000\n",
1730+
"Iteration: 400000\n",
1731+
"Iteration: 500000\n",
1732+
"Iteration: 600000\n",
1733+
"Iteration: 700000\n",
1734+
"Iteration: 800000\n",
1735+
"Iteration: 900000\n",
1736+
"Iteration: 1000000\n",
1737+
"Iteration: 1100000\n",
1738+
"Iteration: 1200000\n",
1739+
"Iteration: 1300000\n",
1740+
"Iteration: 1400000\n",
1741+
"Iteration: 1500000\n",
1742+
"Iteration: 1600000\n",
1743+
"Iteration: 1700000\n",
1744+
"Iteration: 1800000\n",
1745+
"Iteration: 1900000\n",
1746+
"Iteration: 2000000\n",
1747+
"Iteration: 2100000\n",
1748+
"Iteration: 2200000\n",
1749+
"Iteration: 2300000\n",
1750+
"Iteration: 2400000\n",
1751+
"Iteration: 2500000\n",
1752+
"Iteration: 2600000\n",
1753+
"Iteration: 2700000\n",
1754+
"Iteration: 2800000\n",
1755+
"Iteration: 2900000\n",
1756+
"Iteration: 3000000\n",
1757+
"Iteration: 3100000\n",
1758+
"Iteration: 3200000\n",
1759+
"Iteration: 3300000\n",
1760+
"Iteration: 3400000\n",
1761+
"Iteration: 3500000\n",
1762+
"Iteration: 3600000\n",
1763+
"Iteration: 3700000\n",
1764+
"Iteration: 3800000\n",
1765+
"Iteration: 3900000\n",
1766+
"Iteration: 4000000\n",
1767+
"Iteration: 4100000\n",
1768+
"Iteration: 4200000\n",
1769+
"Iteration: 4300000\n",
1770+
"Iteration: 4400000\n",
1771+
"Iteration: 4500000\n",
1772+
"Iteration: 4600000\n",
1773+
"Iteration: 4700000\n",
1774+
"Iteration: 4800000\n",
1775+
"Iteration: 4900000\n",
1776+
"Iteration: 5000000\n",
1777+
"Iteration: 5100000\n",
1778+
"Iteration: 5200000\n",
1779+
"Iteration: 5300000\n",
1780+
"Iteration: 5400000\n",
1781+
"Iteration: 5500000\n",
1782+
"Iteration: 5600000\n",
1783+
"Iteration: 5700000\n",
1784+
"Iteration: 5800000\n",
1785+
"Iteration: 5900000\n",
1786+
"Iteration: 6000000\n",
1787+
"Iteration: 6100000\n",
1788+
"Iteration: 6200000\n",
1789+
"Iteration: 6300000\n",
1790+
"Iteration: 6400000\n",
1791+
"Iteration: 6500000\n",
1792+
"Iteration: 6600000\n",
1793+
"Iteration: 6700000\n",
1794+
"Iteration: 6800000\n",
1795+
"Iteration: 6900000\n",
1796+
"Iteration: 7000000\n",
1797+
"Iteration: 7100000\n",
1798+
"Iteration: 7200000\n",
1799+
"Iteration: 7300000\n",
1800+
"Iteration: 7400000\n",
1801+
"Iteration: 7500000\n",
1802+
"Iteration: 7600000\n",
1803+
"Iteration: 7700000\n",
1804+
"Iteration: 7800000\n",
1805+
"Iteration: 7900000\n",
1806+
"Iteration: 8000000\n",
1807+
"Iteration: 8100000\n",
1808+
"Iteration: 8200000\n",
1809+
"Iteration: 8300000\n",
1810+
"Iteration: 8400000\n",
1811+
"Iteration: 8500000\n",
1812+
"Iteration: 8600000\n",
1813+
"Iteration: 8700000\n",
1814+
"Iteration: 8800000\n",
1815+
"Iteration: 8900000\n",
1816+
"Iteration: 9000000\n",
1817+
"Iteration: 9100000\n",
1818+
"Iteration: 9200000\n",
1819+
"Iteration: 9300000\n",
1820+
"Iteration: 9400000\n",
1821+
"Iteration: 9500000\n",
1822+
"Iteration: 9600000\n",
1823+
"Iteration: 9700000\n",
1824+
"Iteration: 9800000\n",
1825+
"Iteration: 9900000\n",
1826+
"Iteration: 10000000\n",
1827+
"Iteration: 10100000\n",
1828+
"Iteration: 10200000\n",
1829+
"Iteration: 10300000\n",
1830+
"Iteration: 10400000\n",
1831+
"Iteration: 10500000\n",
1832+
"Iteration: 10600000\n",
1833+
"Iteration: 10700000\n",
1834+
"Iteration: 10800000\n",
1835+
"Iteration: 10900000\n",
1836+
"Iteration: 11000000\n",
1837+
"Iteration: 11100000\n",
1838+
"Iteration: 11200000\n",
1839+
"Iteration: 11300000\n",
1840+
"Iteration: 11400000\n",
1841+
"Iteration: 11500000\n",
1842+
"Iteration: 11600000\n",
1843+
"Iteration: 11700000\n",
1844+
"Iteration: 11800000\n",
1845+
"Iteration: 11900000\n",
1846+
"Iteration: 12000000\n",
1847+
"Iteration: 12100000\n",
1848+
"Iteration: 12200000\n",
1849+
"Iteration: 12300000\n",
1850+
"Iteration: 12400000\n",
1851+
"Iteration: 12500000\n",
1852+
"Iteration: 12600000\n",
1853+
"Iteration: 12700000\n",
1854+
"Iteration: 12800000\n",
1855+
"Iteration: 12900000\n",
1856+
"Iteration: 13000000\n",
1857+
"Iteration: 13100000\n",
1858+
"Iteration: 13200000\n",
1859+
"Iteration: 13300000\n",
1860+
"Iteration: 13400000\n",
1861+
"Iteration: 13500000\n",
1862+
"Iteration: 13600000\n",
1863+
"Iteration: 13700000\n",
1864+
"Iteration: 13800000\n",
1865+
"Iteration: 13900000\n",
1866+
"Iteration: 14000000\n",
1867+
"Iteration: 14100000\n",
1868+
"Iteration: 14200000\n",
1869+
"Iteration: 14300000\n",
1870+
"Iteration: 14400000\n",
1871+
"Iteration: 14500000\n",
1872+
"Iteration: 14600000\n",
1873+
"Iteration: 14700000\n",
1874+
"Iteration: 14800000\n",
1875+
"Iteration: 14900000\n",
1876+
"Iteration: 15000000\n",
1877+
"Iteration: 15100000\n",
1878+
"Iteration: 15200000\n",
1879+
"Iteration: 15300000\n",
1880+
"Iteration: 15400000\n",
1881+
"Iteration: 15500000\n",
1882+
"Iteration: 15600000\n",
1883+
"Iteration: 15700000\n",
1884+
"Iteration: 15800000\n",
1885+
"Iteration: 15900000\n",
1886+
"Iteration: 16000000\n",
1887+
"Iteration: 16100000\n",
1888+
"Iteration: 16200000\n",
1889+
"Iteration: 16300000\n",
1890+
"Iteration: 16400000\n",
1891+
"Iteration: 16500000\n",
1892+
"Iteration: 16600000\n",
1893+
"Iteration: 16700000\n",
1894+
"Iteration: 16800000\n",
1895+
"Iteration: 16900000\n",
1896+
"Iteration: 17000000\n",
1897+
"Iteration: 17100000\n",
1898+
"Iteration: 17200000\n",
1899+
"Iteration: 17300000\n",
1900+
"Iteration: 17400000\n",
1901+
"Iteration: 17500000\n",
1902+
"Iteration: 17600000\n",
1903+
"Iteration: 17700000\n",
1904+
"Iteration: 17800000\n",
1905+
"Iteration: 17900000\n",
1906+
"Iteration: 18000000\n",
1907+
"Iteration: 18100000\n",
1908+
"Iteration: 18200000\n",
1909+
"Iteration: 18300000\n",
1910+
"Iteration: 18400000\n",
1911+
"Iteration: 18500000\n",
1912+
"Iteration: 18600000\n",
1913+
"Iteration: 18700000\n",
1914+
"Iteration: 18800000\n",
1915+
"Iteration: 18900000\n",
1916+
"Iteration: 19000000\n",
1917+
"Iteration: 19100000\n",
1918+
"Iteration: 19200000\n",
1919+
"Iteration: 19300000\n",
1920+
"Iteration: 19400000\n",
1921+
"Iteration: 19500000\n",
1922+
"Iteration: 19600000\n",
1923+
"Iteration: 19700000\n",
1924+
"Iteration: 19800000\n",
1925+
"Iteration: 19900000\n"
1926+
]
1927+
},
1928+
{
1929+
"data": {
1930+
"text/plain": [
1931+
"({('Mon', 0): {400: 1},\n",
1932+
" ('Tue', 0): {300: 1},\n",
1933+
" ('Tue', 100): {300: 1},\n",
1934+
" ('Tue', 200): {200: 1},\n",
1935+
" ('Tue', 300): {100: 1},\n",
1936+
" ('Wed', 0): {400: 1},\n",
1937+
" ('Wed', 100): {300: 1},\n",
1938+
" ('Wed', 200): {200: 1},\n",
1939+
" ('Wed', 300): {100: 1},\n",
1940+
" ('Thu', 0): {300: 1},\n",
1941+
" ('Thu', 100): {200: 1},\n",
1942+
" ('Thu', 200): {100: 1},\n",
1943+
" ('Thu', 300): {0: 1},\n",
1944+
" ('Fri', 0): {200: 1},\n",
1945+
" ('Fri', 100): {100: 1},\n",
1946+
" ('Fri', 200): {0: 1},\n",
1947+
" ('Fri', 300): {0: 1},\n",
1948+
" ('Weekend', 0): {0: 1},\n",
1949+
" ('Weekend', 100): {0: 1},\n",
1950+
" ('Weekend', 200): {0: 1},\n",
1951+
" ('Weekend', 300): {0: 1}},\n",
1952+
" {('Mon', 0): {0: 2225.749496682385,\n",
1953+
" 100: 2528.178263359892,\n",
1954+
" 200: 2752.245336408776,\n",
1955+
" 300: 2833.598086662411,\n",
1956+
" 400: 2865.5080973287336},\n",
1957+
" ('Tue', 0): {0: 1627.1674196319675,\n",
1958+
" 100: 1926.185189822399,\n",
1959+
" 200: 2130.63971600556,\n",
1960+
" 300: 2235.794644930646,\n",
1961+
" 400: 2202.700921685597},\n",
1962+
" ('Tue', 100): {0: 2323.642847941712,\n",
1963+
" 100: 2546.146008882256,\n",
1964+
" 200: 2622.2014944709003,\n",
1965+
" 300: 2704.7958165719538,\n",
1966+
" 400: 2254.9865435917945},\n",
1967+
" ('Tue', 200): {0: 2938.47212708529,\n",
1968+
" 100: 2985.6763069672907,\n",
1969+
" 200: 3045.55602709444,\n",
1970+
" 300: 2660.793750889116,\n",
1971+
" 400: 2244.116679273476},\n",
1972+
" ('Tue', 300): {0: 3397.5493842636856,\n",
1973+
" 100: 3431.1584693328227,\n",
1974+
" 200: 3047.963831661575,\n",
1975+
" 300: 2689.0905262507554,\n",
1976+
" 400: 2246.0842993310807},\n",
1977+
" ('Wed', 0): {0: 991.8337704155527,\n",
1978+
" 100: 1294.3979155570473,\n",
1979+
" 200: 1499.2384682910836,\n",
1980+
" 300: 1560.5737610953374,\n",
1981+
" 400: 1656.9354742311311},\n",
1982+
" ('Wed', 100): {0: 1693.4281528809047,\n",
1983+
" 100: 1890.5014859779849,\n",
1984+
" 200: 1967.5195056845337,\n",
1985+
" 300: 2030.2875396109434,\n",
1986+
" 400: 1624.7053132984788},\n",
1987+
" ('Wed', 200): {0: 2307.350730239611,\n",
1988+
" 100: 2368.121947542663,\n",
1989+
" 200: 2439.4451003135055,\n",
1990+
" 300: 2028.5567501077871,\n",
1991+
" 400: 1602.4490693607893},\n",
1992+
" ('Wed', 300): {0: 2766.5347359553866,\n",
1993+
" 100: 2817.6576462084945,\n",
1994+
" 200: 2397.5480427541106,\n",
1995+
" 300: 2028.9023931048234,\n",
1996+
" 400: 1610.7042344178608},\n",
1997+
" ('Thu', 0): {0: 388.45762020693445,\n",
1998+
" 100: 689.9741046142422,\n",
1999+
" 200: 886.8292737374425,\n",
2000+
" 300: 1008.8462346115972,\n",
2001+
" 400: 970.599703355806},\n",
2002+
" ('Thu', 100): {0: 1086.9408520148895,\n",
2003+
" 100: 1301.332777514599,\n",
2004+
" 200: 1405.1825937805977,\n",
2005+
" 300: 1348.6418726014172,\n",
2006+
" 400: 992.3726336890564},\n",
2007+
" ('Thu', 200): {0: 1715.4114166813265,\n",
2008+
" 100: 1833.5195722683234,\n",
2009+
" 200: 1741.2757203880324,\n",
2010+
" 300: 1376.7551643904483,\n",
2011+
" 400: 957.9607707339657},\n",
2012+
" ('Thu', 300): {0: 2190.19649451877,\n",
2013+
" 100: 2125.740810274669,\n",
2014+
" 200: 1776.8132567876999,\n",
2015+
" 300: 1408.5495730824664,\n",
2016+
" 400: 990.4018172404869},\n",
2017+
" ('Fri', 0): {0: 0.0,\n",
2018+
" 100: 299.99999999999716,\n",
2019+
" 200: 406.0476593550646,\n",
2020+
" 300: 170.46122765548887,\n",
2021+
" 400: -153.64846976857817},\n",
2022+
" ('Fri', 100): {0: 699.9999999999943,\n",
2023+
" 100: 842.0141106267022,\n",
2024+
" 200: 610.5115569281422,\n",
2025+
" 300: 292.160669622827,\n",
2026+
" 400: -113.12406224669776},\n",
2027+
" ('Fri', 200): {0: 1172.1819744094662,\n",
2028+
" 100: 1070.907334160906,\n",
2029+
" 200: 687.7773470555264,\n",
2030+
" 300: 330.44001007014674,\n",
2031+
" 400: -75.1010831966216},\n",
2032+
" ('Fri', 300): {0: 1427.1955441761681,\n",
2033+
" 100: 1007.1503466766485,\n",
2034+
" 200: 674.4671172275836,\n",
2035+
" 300: 278.1467797475504,\n",
2036+
" 400: -99.78074377598806},\n",
2037+
" ('Weekend', 0): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
2038+
" ('Weekend', 100): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
2039+
" ('Weekend', 200): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n",
2040+
" ('Weekend', 300): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0}})"
2041+
]
2042+
},
2043+
"execution_count": 124,
2044+
"metadata": {},
2045+
"output_type": "execute_result"
2046+
}
2047+
],
2048+
"source": [
2049+
"q_learning(foodtruck, 1, 0.1, 0.01, 20000000)\n"
2050+
]
2051+
},
2052+
{
2053+
"cell_type": "code",
2054+
"execution_count": null,
2055+
"metadata": {},
2056+
"outputs": [],
2057+
"source": []
2058+
},
2059+
{
2060+
"cell_type": "code",
2061+
"execution_count": null,
2062+
"metadata": {},
2063+
"outputs": [],
2064+
"source": [
2065+
"Q"
2066+
]
2067+
},
2068+
{
2069+
"cell_type": "markdown",
2070+
"metadata": {},
2071+
"source": [
2072+
"{('Mon', 0): 2880.0,\n",
2073+
" ('Tue', 0): 2250.0,\n",
2074+
" ('Tue', 100): 2650.0,\n",
2075+
" ('Tue', 200): 3050.0,\n",
2076+
" ('Tue', 300): 3450.0,\n",
2077+
" ('Wed', 0): 1620.0,\n",
2078+
" ('Wed', 100): 2020.0,\n",
2079+
" ('Wed', 200): 2420.0,\n",
2080+
" ('Wed', 300): 2820.0,\n",
2081+
" ('Thu', 0): 990.0,\n",
2082+
" ('Thu', 100): 1390.0,\n",
2083+
" ('Thu', 200): 1790.0,\n",
2084+
" ('Thu', 300): 2190.0,\n",
2085+
" ('Fri', 0): 390.00000000000006,\n",
2086+
" ('Fri', 100): 790.0000000000001,\n",
2087+
" ('Fri', 200): 1190.0,\n",
2088+
" ('Fri', 300): 1400.0,\n",
2089+
" ('Weekend', 0): 0,\n",
2090+
" ('Weekend', 100): 0,\n",
2091+
" ('Weekend', 200): 0,\n",
2092+
" ('Weekend', 300): 0}"
2093+
]
2094+
}
2095+
],
2096+
"metadata": {
2097+
"kernelspec": {
2098+
"display_name": "py37ml",
2099+
"language": "python",
2100+
"name": "py37ml"
2101+
},
2102+
"language_info": {
2103+
"codemirror_mode": {
2104+
"name": "ipython",
2105+
"version": 3
2106+
},
2107+
"file_extension": ".py",
2108+
"mimetype": "text/x-python",
2109+
"name": "python",
2110+
"nbconvert_exporter": "python",
2111+
"pygments_lexer": "ipython3",
2112+
"version": "3.7.4"
2113+
}
2114+
},
2115+
"nbformat": 4,
2116+
"nbformat_minor": 2
2117+
}

0 commit comments

Comments
 (0)
Please sign in to comment.