|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 1, |
| 6 | + "metadata": {}, |
| 7 | + "outputs": [], |
| 8 | + "source": [ |
| 9 | + "import numpy as np\n", |
| 10 | + "import gym" |
| 11 | + ] |
| 12 | + }, |
| 13 | + { |
| 14 | + "cell_type": "markdown", |
| 15 | + "metadata": {}, |
| 16 | + "source": [ |
| 17 | + "# Dynamic Programming" |
| 18 | + ] |
| 19 | + }, |
| 20 | + { |
| 21 | + "cell_type": "code", |
| 22 | + "execution_count": 18, |
| 23 | + "metadata": {}, |
| 24 | + "outputs": [], |
| 25 | + "source": [ |
| 26 | + "class FoodTruck(gym.Env):\n", |
| 27 | + " def __init__(self):\n", |
| 28 | + " self.v_demand = [100, 200, 300, 400]\n", |
| 29 | + " self.p_demand = [0.3, 0.4, 0.2, 0.1]\n", |
| 30 | + " self.capacity = self.v_demand[-1]\n", |
| 31 | + " self.days = ['Mon', 'Tue', 'Wed', \n", |
| 32 | + " 'Thu', 'Fri', \"Weekend\"]\n", |
| 33 | + " self.unit_cost = 4\n", |
| 34 | + " self.net_revenue = 7\n", |
| 35 | + " self.action_space = [0, 100, 200, 300, 400]\n", |
| 36 | + " self.state_space = [(\"Mon\", 0)] \\\n", |
| 37 | + " + [(d, i) for d in self.days[1:] \n", |
| 38 | + " for i in [0, 100, 200, 300]]\n", |
| 39 | + " \n", |
| 40 | + " def get_next_state_reward(self, state, action, demand):\n", |
| 41 | + " day, inventory = state\n", |
| 42 | + " result = {}\n", |
| 43 | + " result['next_day'] = self.days[self.days.index(day) \\\n", |
| 44 | + " + 1]\n", |
| 45 | + " result['starting_inventory'] = min(self.capacity, \n", |
| 46 | + " inventory \n", |
| 47 | + " + action)\n", |
| 48 | + " result['cost'] = self.unit_cost * action \n", |
| 49 | + " result['sales'] = min(result['starting_inventory'], \n", |
| 50 | + " demand)\n", |
| 51 | + " result['revenue'] = self.net_revenue * result['sales']\n", |
| 52 | + " result['next_inventory'] \\\n", |
| 53 | + " = result['starting_inventory'] - result['sales']\n", |
| 54 | + " result['reward'] = result['revenue'] - result['cost']\n", |
| 55 | + " return result\n", |
| 56 | + " \n", |
| 57 | + " def get_transition_prob(self, state, action):\n", |
| 58 | + " next_s_r_prob = {}\n", |
| 59 | + " for ix, demand in enumerate(self.v_demand):\n", |
| 60 | + " result = self.get_next_state_reward(state, \n", |
| 61 | + " action, \n", |
| 62 | + " demand)\n", |
| 63 | + " next_s = (result['next_day'],\n", |
| 64 | + " result['next_inventory'])\n", |
| 65 | + " reward = result['reward']\n", |
| 66 | + " prob = self.p_demand[ix]\n", |
| 67 | + " if (next_s, reward) not in next_s_r_prob:\n", |
| 68 | + " next_s_r_prob[next_s, reward] = prob\n", |
| 69 | + " else:\n", |
| 70 | + " next_s_r_prob[next_s, reward] += prob\n", |
| 71 | + " return next_s_r_prob\n", |
| 72 | + " \n", |
| 73 | + " def reset(self):\n", |
| 74 | + " self.day = \"Mon\"\n", |
| 75 | + " self.inventory = 0\n", |
| 76 | + " state = (self.day, self.inventory)\n", |
| 77 | + " return state\n", |
| 78 | + " \n", |
| 79 | + " def is_terminal(self, state):\n", |
| 80 | + " day, inventory = state\n", |
| 81 | + " if day == \"Weekend\":\n", |
| 82 | + " return True\n", |
| 83 | + " else:\n", |
| 84 | + " return False\n", |
| 85 | + " \n", |
| 86 | + " def step(self, action):\n", |
| 87 | + " demand = np.random.choice(self.v_demand, \n", |
| 88 | + " p=self.p_demand)\n", |
| 89 | + " result = self.get_next_state_reward((self.day, \n", |
| 90 | + " self.inventory), \n", |
| 91 | + " action, \n", |
| 92 | + " demand)\n", |
| 93 | + " self.day = result['next_day']\n", |
| 94 | + " self.inventory = result['next_inventory']\n", |
| 95 | + " state = (self.day, self.inventory)\n", |
| 96 | + " reward = result['reward']\n", |
| 97 | + " done = self.is_terminal(state)\n", |
| 98 | + " info = {'demand': demand, 'sales': result['sales']}\n", |
| 99 | + " return state, reward, done, info" |
| 100 | + ] |
| 101 | + }, |
| 102 | + { |
| 103 | + "cell_type": "code", |
| 104 | + "execution_count": 19, |
| 105 | + "metadata": {}, |
| 106 | + "outputs": [ |
| 107 | + { |
| 108 | + "data": { |
| 109 | + "text/plain": [ |
| 110 | + "2590.83" |
| 111 | + ] |
| 112 | + }, |
| 113 | + "execution_count": 19, |
| 114 | + "metadata": {}, |
| 115 | + "output_type": "execute_result" |
| 116 | + } |
| 117 | + ], |
| 118 | + "source": [ |
| 119 | + "# Simulating an arbitrary policy\n", |
| 120 | + "np.random.seed(0)\n", |
| 121 | + "foodtruck = FoodTruck()\n", |
| 122 | + "rewards = []\n", |
| 123 | + "for i_episode in range(10000):\n", |
| 124 | + " state = foodtruck.reset()\n", |
| 125 | + " done = False\n", |
| 126 | + " ep_reward = 0\n", |
| 127 | + " while not done:\n", |
| 128 | + " day, inventory = state\n", |
| 129 | + " action = max(0, 300 - inventory)\n", |
| 130 | + " state, reward, done, info = foodtruck.step(action) \n", |
| 131 | + " ep_reward += reward\n", |
| 132 | + " rewards.append(ep_reward)\n", |
| 133 | + "np.mean(rewards)" |
| 134 | + ] |
| 135 | + }, |
| 136 | + { |
| 137 | + "cell_type": "code", |
| 138 | + "execution_count": null, |
| 139 | + "metadata": {}, |
| 140 | + "outputs": [], |
| 141 | + "source": [ |
| 142 | + "# Single day expected reward\n", |
| 143 | + "ucost = 4\n", |
| 144 | + "uprice = 7\n", |
| 145 | + "v_demand = [100, 200, 300, 400]\n", |
| 146 | + "p_demand = [0.3, 0.4, 0.2, 0.1]\n", |
| 147 | + "inv = 400\n", |
| 148 | + "profit = uprice*np.sum([p_demand[i]*min(v_demand[i], inv) for i in range(4)]) - inv*ucost\n", |
| 149 | + "print(profit)" |
| 150 | + ] |
| 151 | + }, |
| 152 | + { |
| 153 | + "cell_type": "markdown", |
| 154 | + "metadata": {}, |
| 155 | + "source": [ |
| 156 | + "## Policy Evaluation" |
| 157 | + ] |
| 158 | + }, |
| 159 | + { |
| 160 | + "cell_type": "code", |
| 161 | + "execution_count": 21, |
| 162 | + "metadata": {}, |
| 163 | + "outputs": [], |
| 164 | + "source": [ |
| 165 | + "def base_policy(states):\n", |
| 166 | + " policy = {}\n", |
| 167 | + " for s in states:\n", |
| 168 | + " day, inventory = s\n", |
| 169 | + " prob_a = {} \n", |
| 170 | + " if inventory >= 300:\n", |
| 171 | + " prob_a[0] = 1\n", |
| 172 | + " else:\n", |
| 173 | + " prob_a[200 - inventory] = 0.5\n", |
| 174 | + " prob_a[300 - inventory] = 0.5\n", |
| 175 | + " policy[s] = prob_a\n", |
| 176 | + " return policy" |
| 177 | + ] |
| 178 | + }, |
| 179 | + { |
| 180 | + "cell_type": "code", |
| 181 | + "execution_count": 22, |
| 182 | + "metadata": {}, |
| 183 | + "outputs": [], |
| 184 | + "source": [ |
| 185 | + "def expected_update(env, v, s, prob_a, gamma):\n", |
| 186 | + " expected_value = 0\n", |
| 187 | + " for a in prob_a:\n", |
| 188 | + " prob_next_s_r = env.get_transition_prob(s, a)\n", |
| 189 | + " for next_s, r in prob_next_s_r:\n", |
| 190 | + " expected_value += prob_a[a] \\\n", |
| 191 | + " * prob_next_s_r[next_s, r] \\\n", |
| 192 | + " * (r + gamma * v[next_s])\n", |
| 193 | + " return expected_value" |
| 194 | + ] |
| 195 | + }, |
| 196 | + { |
| 197 | + "cell_type": "code", |
| 198 | + "execution_count": 26, |
| 199 | + "metadata": {}, |
| 200 | + "outputs": [], |
| 201 | + "source": [ |
| 202 | + "def policy_evaluation(env, policy, max_iter=100, \n", |
| 203 | + " v = None, eps=0.1, gamma=1):\n", |
| 204 | + " if not v:\n", |
| 205 | + " v = {s: 0 for s in env.state_space}\n", |
| 206 | + " k = 0\n", |
| 207 | + " while True:\n", |
| 208 | + " max_delta = 0\n", |
| 209 | + " for s in v:\n", |
| 210 | + " if not env.is_terminal(s):\n", |
| 211 | + " v_old = v[s]\n", |
| 212 | + " prob_a = policy[s]\n", |
| 213 | + " v[s] = expected_update(env, v, \n", |
| 214 | + " s, prob_a, \n", |
| 215 | + " gamma)\n", |
| 216 | + " max_delta = max(max_delta, \n", |
| 217 | + " abs(v[s] - v_old))\n", |
| 218 | + " k += 1\n", |
| 219 | + " if max_delta < eps:\n", |
| 220 | + " print(\"Converged in\", k, \"iterations.\")\n", |
| 221 | + " break\n", |
| 222 | + " elif k == max_iter:\n", |
| 223 | + " print(\"Terminating after\", k, \"iterations.\")\n", |
| 224 | + " break\n", |
| 225 | + " return v" |
| 226 | + ] |
| 227 | + }, |
| 228 | + { |
| 229 | + "cell_type": "code", |
| 230 | + "execution_count": 52, |
| 231 | + "metadata": {}, |
| 232 | + "outputs": [], |
| 233 | + "source": [ |
| 234 | + "foodtruck = FoodTruck()\n", |
| 235 | + "policy = base_policy(foodtruck.state_space)" |
| 236 | + ] |
| 237 | + }, |
| 238 | + { |
| 239 | + "cell_type": "code", |
| 240 | + "execution_count": 53, |
| 241 | + "metadata": {}, |
| 242 | + "outputs": [ |
| 243 | + { |
| 244 | + "name": "stdout", |
| 245 | + "output_type": "stream", |
| 246 | + "text": [ |
| 247 | + "Converged in 6 iterations.\n", |
| 248 | + "Expected weekly profit: 2515.0\n" |
| 249 | + ] |
| 250 | + } |
| 251 | + ], |
| 252 | + "source": [ |
| 253 | + "v = policy_evaluation(foodtruck, policy)\n", |
| 254 | + "print(\"Expected weekly profit:\", v[\"Mon\", 0])" |
| 255 | + ] |
| 256 | + }, |
| 257 | + { |
| 258 | + "cell_type": "code", |
| 259 | + "execution_count": 54, |
| 260 | + "metadata": {}, |
| 261 | + "outputs": [ |
| 262 | + { |
| 263 | + "name": "stdout", |
| 264 | + "output_type": "stream", |
| 265 | + "text": [ |
| 266 | + "The state values:\n" |
| 267 | + ] |
| 268 | + }, |
| 269 | + { |
| 270 | + "data": { |
| 271 | + "text/plain": [ |
| 272 | + "{('Mon', 0): 2515.0,\n", |
| 273 | + " ('Tue', 0): 1960.0,\n", |
| 274 | + " ('Tue', 100): 2360.0,\n", |
| 275 | + " ('Tue', 200): 2760.0,\n", |
| 276 | + " ('Tue', 300): 3205.0,\n", |
| 277 | + " ('Wed', 0): 1405.0,\n", |
| 278 | + " ('Wed', 100): 1805.0,\n", |
| 279 | + " ('Wed', 200): 2205.0,\n", |
| 280 | + " ('Wed', 300): 2650.0,\n", |
| 281 | + " ('Thu', 0): 850.0000000000001,\n", |
| 282 | + " ('Thu', 100): 1250.0,\n", |
| 283 | + " ('Thu', 200): 1650.0,\n", |
| 284 | + " ('Thu', 300): 2095.0,\n", |
| 285 | + " ('Fri', 0): 295.00000000000006,\n", |
| 286 | + " ('Fri', 100): 695.0000000000001,\n", |
| 287 | + " ('Fri', 200): 1095.0,\n", |
| 288 | + " ('Fri', 300): 1400.0,\n", |
| 289 | + " ('Weekend', 0): 0,\n", |
| 290 | + " ('Weekend', 100): 0,\n", |
| 291 | + " ('Weekend', 200): 0,\n", |
| 292 | + " ('Weekend', 300): 0}" |
| 293 | + ] |
| 294 | + }, |
| 295 | + "execution_count": 54, |
| 296 | + "metadata": {}, |
| 297 | + "output_type": "execute_result" |
| 298 | + } |
| 299 | + ], |
| 300 | + "source": [ |
| 301 | + "print(\"The state values:\")\n", |
| 302 | + "v" |
| 303 | + ] |
| 304 | + }, |
| 305 | + { |
| 306 | + "cell_type": "code", |
| 307 | + "execution_count": 57, |
| 308 | + "metadata": {}, |
| 309 | + "outputs": [], |
| 310 | + "source": [ |
| 311 | + "def choose_action(state, policy):\n", |
| 312 | + " prob_a = policy[state]\n", |
| 313 | + " action = np.random.choice(a=list(prob_a.keys()), \n", |
| 314 | + " p=list(prob_a.values()))\n", |
| 315 | + " return action\n", |
| 316 | + "\n", |
| 317 | + "def simulate_policy(policy, n_episodes):\n", |
| 318 | + " np.random.seed(0)\n", |
| 319 | + " foodtruck = FoodTruck()\n", |
| 320 | + " rewards = []\n", |
| 321 | + " for i_episode in range(n_episodes):\n", |
| 322 | + " state = foodtruck.reset()\n", |
| 323 | + " done = False\n", |
| 324 | + " ep_reward = 0\n", |
| 325 | + " while not done:\n", |
| 326 | + " action = choose_action(state, policy)\n", |
| 327 | + " state, reward, done, info = foodtruck.step(action) \n", |
| 328 | + " ep_reward += reward\n", |
| 329 | + " rewards.append(ep_reward)\n", |
| 330 | + " print(\"Expected weekly profit:\", np.mean(rewards))" |
| 331 | + ] |
| 332 | + }, |
| 333 | + { |
| 334 | + "cell_type": "code", |
| 335 | + "execution_count": 58, |
| 336 | + "metadata": {}, |
| 337 | + "outputs": [ |
| 338 | + { |
| 339 | + "name": "stdout", |
| 340 | + "output_type": "stream", |
| 341 | + "text": [ |
| 342 | + "Expected weekly profit: 2518.1\n" |
| 343 | + ] |
| 344 | + } |
| 345 | + ], |
| 346 | + "source": [ |
| 347 | + "simulate_policy(policy, 1000)" |
| 348 | + ] |
| 349 | + }, |
| 350 | + { |
| 351 | + "cell_type": "markdown", |
| 352 | + "metadata": {}, |
| 353 | + "source": [ |
| 354 | + "## Policy Iteration" |
| 355 | + ] |
| 356 | + }, |
| 357 | + { |
| 358 | + "cell_type": "code", |
| 359 | + "execution_count": 39, |
| 360 | + "metadata": {}, |
| 361 | + "outputs": [], |
| 362 | + "source": [ |
| 363 | + "def policy_improvement(env, v, s, actions, gamma):\n", |
| 364 | + " prob_a = {}\n", |
| 365 | + " if not env.is_terminal(s):\n", |
| 366 | + " max_q = np.NINF\n", |
| 367 | + " best_a = None\n", |
| 368 | + " for a in actions:\n", |
| 369 | + " q_sa = expected_update(env, v, s, {a: 1}, gamma)\n", |
| 370 | + " if q_sa >= max_q:\n", |
| 371 | + " max_q = q_sa\n", |
| 372 | + " best_a = a\n", |
| 373 | + " prob_a[best_a] = 1\n", |
| 374 | + " else:\n", |
| 375 | + " max_q = 0\n", |
| 376 | + " return prob_a, max_q" |
| 377 | + ] |
| 378 | + }, |
| 379 | + { |
| 380 | + "cell_type": "code", |
| 381 | + "execution_count": 42, |
| 382 | + "metadata": {}, |
| 383 | + "outputs": [], |
| 384 | + "source": [ |
| 385 | + "def policy_iteration(env, eps=0.1, gamma=1):\n", |
| 386 | + " np.random.seed(1)\n", |
| 387 | + " states = env.state_space\n", |
| 388 | + " actions = env.action_space\n", |
| 389 | + " policy = {s: {np.random.choice(actions): 1}\n", |
| 390 | + " for s in states}\n", |
| 391 | + " v = {s: 0 for s in states}\n", |
| 392 | + " while True:\n", |
| 393 | + " v = policy_evaluation(env, policy, v=v, \n", |
| 394 | + " eps=eps, gamma=gamma)\n", |
| 395 | + " old_policy = policy\n", |
| 396 | + " policy = {}\n", |
| 397 | + " for s in states:\n", |
| 398 | + " policy[s], _ = policy_improvement(env, v, s, \n", |
| 399 | + " actions, gamma)\n", |
| 400 | + " if old_policy == policy:\n", |
| 401 | + " break\n", |
| 402 | + " print(\"Optimal policy found!\")\n", |
| 403 | + " return policy, v" |
| 404 | + ] |
| 405 | + }, |
| 406 | + { |
| 407 | + "cell_type": "code", |
| 408 | + "execution_count": 43, |
| 409 | + "metadata": {}, |
| 410 | + "outputs": [ |
| 411 | + { |
| 412 | + "name": "stdout", |
| 413 | + "output_type": "stream", |
| 414 | + "text": [ |
| 415 | + "Converged in 6 iterations.\n", |
| 416 | + "Converged in 6 iterations.\n", |
| 417 | + "Converged in 5 iterations.\n", |
| 418 | + "Optimal policy found!\n", |
| 419 | + "Expected weekly profit: 2880.0\n" |
| 420 | + ] |
| 421 | + } |
| 422 | + ], |
| 423 | + "source": [ |
| 424 | + "policy, v = policy_iteration(foodtruck)\n", |
| 425 | + "print(\"Expected weekly profit:\", v[\"Mon\", 0])" |
| 426 | + ] |
| 427 | + }, |
| 428 | + { |
| 429 | + "cell_type": "code", |
| 430 | + "execution_count": 44, |
| 431 | + "metadata": {}, |
| 432 | + "outputs": [ |
| 433 | + { |
| 434 | + "name": "stdout", |
| 435 | + "output_type": "stream", |
| 436 | + "text": [ |
| 437 | + "{('Mon', 0): {400: 1}, ('Tue', 0): {400: 1}, ('Tue', 100): {300: 1}, ('Tue', 200): {200: 1}, ('Tue', 300): {100: 1}, ('Wed', 0): {400: 1}, ('Wed', 100): {300: 1}, ('Wed', 200): {200: 1}, ('Wed', 300): {100: 1}, ('Thu', 0): {300: 1}, ('Thu', 100): {200: 1}, ('Thu', 200): {100: 1}, ('Thu', 300): {0: 1}, ('Fri', 0): {200: 1}, ('Fri', 100): {100: 1}, ('Fri', 200): {0: 1}, ('Fri', 300): {0: 1}, ('Weekend', 0): {}, ('Weekend', 100): {}, ('Weekend', 200): {}, ('Weekend', 300): {}}\n" |
| 438 | + ] |
| 439 | + } |
| 440 | + ], |
| 441 | + "source": [ |
| 442 | + "print(policy)" |
| 443 | + ] |
| 444 | + }, |
| 445 | + { |
| 446 | + "cell_type": "code", |
| 447 | + "execution_count": null, |
| 448 | + "metadata": {}, |
| 449 | + "outputs": [], |
| 450 | + "source": [] |
| 451 | + }, |
| 452 | + { |
| 453 | + "cell_type": "markdown", |
| 454 | + "metadata": {}, |
| 455 | + "source": [ |
| 456 | + "## Value Iteration" |
| 457 | + ] |
| 458 | + }, |
| 459 | + { |
| 460 | + "cell_type": "code", |
| 461 | + "execution_count": 49, |
| 462 | + "metadata": {}, |
| 463 | + "outputs": [], |
| 464 | + "source": [ |
| 465 | + "def value_iteration(env, max_iter=100, eps=0.1, gamma=1):\n", |
| 466 | + " states = env.state_space\n", |
| 467 | + " actions = env.action_space\n", |
| 468 | + " v = {s: 0 for s in states}\n", |
| 469 | + " policy = {}\n", |
| 470 | + " k = 0\n", |
| 471 | + " while True:\n", |
| 472 | + " max_delta = 0\n", |
| 473 | + " for s in states:\n", |
| 474 | + " old_v = v[s]\n", |
| 475 | + " policy[s], v[s] = policy_improvement(env, \n", |
| 476 | + " v, \n", |
| 477 | + " s, \n", |
| 478 | + " actions, \n", |
| 479 | + " gamma)\n", |
| 480 | + " max_delta = max(max_delta, abs(v[s] - old_v))\n", |
| 481 | + " k += 1\n", |
| 482 | + " if max_delta < eps:\n", |
| 483 | + " print(\"Converged in\", k, \"iterations.\")\n", |
| 484 | + " break\n", |
| 485 | + " elif k == max_iter:\n", |
| 486 | + " print(\"Terminating after\", k, \"iterations.\")\n", |
| 487 | + " break\n", |
| 488 | + " return policy, v" |
| 489 | + ] |
| 490 | + }, |
| 491 | + { |
| 492 | + "cell_type": "code", |
| 493 | + "execution_count": 48, |
| 494 | + "metadata": {}, |
| 495 | + "outputs": [ |
| 496 | + { |
| 497 | + "name": "stdout", |
| 498 | + "output_type": "stream", |
| 499 | + "text": [ |
| 500 | + "Converged in 6 iterations.\n", |
| 501 | + "6\n", |
| 502 | + "Expected weekly profit: 2880.0\n" |
| 503 | + ] |
| 504 | + } |
| 505 | + ], |
| 506 | + "source": [ |
| 507 | + "policy, v = value_iteration(foodtruck)\n", |
| 508 | + "print(\"Expected weekly profit:\", v[\"Mon\", 0])" |
| 509 | + ] |
| 510 | + }, |
| 511 | + { |
| 512 | + "cell_type": "code", |
| 513 | + "execution_count": null, |
| 514 | + "metadata": {}, |
| 515 | + "outputs": [], |
| 516 | + "source": [ |
| 517 | + "print(policy)" |
| 518 | + ] |
| 519 | + }, |
| 520 | + { |
| 521 | + "cell_type": "code", |
| 522 | + "execution_count": null, |
| 523 | + "metadata": {}, |
| 524 | + "outputs": [], |
| 525 | + "source": [ |
| 526 | + "def generalized_policy_iteration(env, max_iter=2, eps=0.1, gamma=1):\n", |
| 527 | + " np.random.seed(1)\n", |
| 528 | + " states = env.observation_space\n", |
| 529 | + " actions = env.action_space\n", |
| 530 | + " policy = {s: {np.random.choice(actions): 1}\n", |
| 531 | + " for s in states}\n", |
| 532 | + " v = {s: 0 for s in states}\n", |
| 533 | + " k = 0\n", |
| 534 | + " while True:\n", |
| 535 | + " v_old = v.copy()\n", |
| 536 | + " policy = {}\n", |
| 537 | + " for s in states:\n", |
| 538 | + " policy[s], v[s] = policy_improvement(env, v, s, \n", |
| 539 | + " actions, gamma)\n", |
| 540 | + " v = policy_evaluation(env, policy, \n", |
| 541 | + " max_iter=max_iter, v=v, \n", |
| 542 | + " eps=eps, gamma=gamma)\n", |
| 543 | + " max_delta = np.amax([abs(v[s] - v_old[s]) for s in v])\n", |
| 544 | + " k += 1\n", |
| 545 | + " if max_delta < eps:\n", |
| 546 | + " print(\"GPI converged in\", k, \"iterations.\")\n", |
| 547 | + " print([abs(v[s] - v_old[s]) for s in v])\n", |
| 548 | + " break\n", |
| 549 | + " \n", |
| 550 | + " print(\"Optimal policy found!\")\n", |
| 551 | + " return policy, v" |
| 552 | + ] |
| 553 | + }, |
| 554 | + { |
| 555 | + "cell_type": "code", |
| 556 | + "execution_count": null, |
| 557 | + "metadata": {}, |
| 558 | + "outputs": [], |
| 559 | + "source": [ |
| 560 | + "policy, v = generalized_policy_iteration(foodtruck, max_iter=2, eps=0.1, gamma=1)" |
| 561 | + ] |
| 562 | + }, |
| 563 | + { |
| 564 | + "cell_type": "code", |
| 565 | + "execution_count": null, |
| 566 | + "metadata": {}, |
| 567 | + "outputs": [], |
| 568 | + "source": [ |
| 569 | + "print(\"Expected weekly profit:\", v[\"Mon\", 0])\n", |
| 570 | + "print(policy)" |
| 571 | + ] |
| 572 | + }, |
| 573 | + { |
| 574 | + "cell_type": "code", |
| 575 | + "execution_count": null, |
| 576 | + "metadata": {}, |
| 577 | + "outputs": [], |
| 578 | + "source": [ |
| 579 | + "v" |
| 580 | + ] |
| 581 | + }, |
| 582 | + { |
| 583 | + "cell_type": "markdown", |
| 584 | + "metadata": {}, |
| 585 | + "source": [ |
| 586 | + "# Monte Carlo Methods" |
| 587 | + ] |
| 588 | + }, |
| 589 | + { |
| 590 | + "cell_type": "markdown", |
| 591 | + "metadata": {}, |
| 592 | + "source": [ |
| 593 | + "## MC Prediction" |
| 594 | + ] |
| 595 | + }, |
| 596 | + { |
| 597 | + "cell_type": "code", |
| 598 | + "execution_count": 71, |
| 599 | + "metadata": {}, |
| 600 | + "outputs": [], |
| 601 | + "source": [ |
| 602 | + "def first_visit_return(returns, trajectory, gamma):\n", |
| 603 | + " G = 0\n", |
| 604 | + " T = len(trajectory) - 1\n", |
| 605 | + " for t, sar in enumerate(reversed(trajectory)):\n", |
| 606 | + " s, a, r = sar\n", |
| 607 | + " G = r + gamma * G\n", |
| 608 | + " first_visit = True\n", |
| 609 | + " for j in range(T - t):\n", |
| 610 | + " if s == trajectory[j][0]:\n", |
| 611 | + " first_visit = False\n", |
| 612 | + " if first_visit:\n", |
| 613 | + " if s in returns:\n", |
| 614 | + " returns[s].append(G)\n", |
| 615 | + " else:\n", |
| 616 | + " returns[s] = [G]\n", |
| 617 | + " return returns" |
| 618 | + ] |
| 619 | + }, |
| 620 | + { |
| 621 | + "cell_type": "code", |
| 622 | + "execution_count": 74, |
| 623 | + "metadata": {}, |
| 624 | + "outputs": [], |
| 625 | + "source": [ |
| 626 | + "def get_trajectory(env, policy):\n", |
| 627 | + " trajectory = []\n", |
| 628 | + " state = env.reset()\n", |
| 629 | + " done = False\n", |
| 630 | + " sar = [state]\n", |
| 631 | + " while not done:\n", |
| 632 | + " action = choose_action(state, policy)\n", |
| 633 | + " state, reward, done, info = env.step(action)\n", |
| 634 | + " sar.append(action)\n", |
| 635 | + " sar.append(reward)\n", |
| 636 | + " trajectory.append(sar)\n", |
| 637 | + " sar = [state]\n", |
| 638 | + " return trajectory" |
| 639 | + ] |
| 640 | + }, |
| 641 | + { |
| 642 | + "cell_type": "code", |
| 643 | + "execution_count": 75, |
| 644 | + "metadata": {}, |
| 645 | + "outputs": [], |
| 646 | + "source": [ |
| 647 | + "def first_visit_mc(env, policy, gamma, n_trajectories):\n", |
| 648 | + " np.random.seed(0)\n", |
| 649 | + " returns = {}\n", |
| 650 | + " v = {}\n", |
| 651 | + " for i in range(n_trajectories):\n", |
| 652 | + " trajectory = get_trajectory(env, policy)\n", |
| 653 | + " returns = first_visit_return(returns, \n", |
| 654 | + " trajectory, \n", |
| 655 | + " gamma)\n", |
| 656 | + " for s in env.state_space:\n", |
| 657 | + " if s in returns:\n", |
| 658 | + " v[s] = np.round(np.mean(returns[s]), 1)\n", |
| 659 | + " return v" |
| 660 | + ] |
| 661 | + }, |
| 662 | + { |
| 663 | + "cell_type": "code", |
| 664 | + "execution_count": 76, |
| 665 | + "metadata": {}, |
| 666 | + "outputs": [], |
| 667 | + "source": [ |
| 668 | + "foodtruck = FoodTruck()\n", |
| 669 | + "policy = base_policy(foodtruck.state_space)" |
| 670 | + ] |
| 671 | + }, |
| 672 | + { |
| 673 | + "cell_type": "code", |
| 674 | + "execution_count": 77, |
| 675 | + "metadata": {}, |
| 676 | + "outputs": [ |
| 677 | + { |
| 678 | + "data": { |
| 679 | + "text/plain": [ |
| 680 | + "{('Mon', 0): 2515.9,\n", |
| 681 | + " ('Tue', 0): 1959.1,\n", |
| 682 | + " ('Tue', 100): 2362.2,\n", |
| 683 | + " ('Tue', 200): 2765.2,\n", |
| 684 | + " ('Wed', 0): 1411.3,\n", |
| 685 | + " ('Wed', 100): 1804.2,\n", |
| 686 | + " ('Wed', 200): 2198.9,\n", |
| 687 | + " ('Thu', 0): 852.9,\n", |
| 688 | + " ('Thu', 100): 1265.4,\n", |
| 689 | + " ('Thu', 200): 1644.4,\n", |
| 690 | + " ('Fri', 0): 301.1,\n", |
| 691 | + " ('Fri', 100): 696.5,\n", |
| 692 | + " ('Fri', 200): 1097.2}" |
| 693 | + ] |
| 694 | + }, |
| 695 | + "execution_count": 77, |
| 696 | + "metadata": {}, |
| 697 | + "output_type": "execute_result" |
| 698 | + } |
| 699 | + ], |
| 700 | + "source": [ |
| 701 | + "v_est = first_visit_mc(foodtruck, policy, 1, 10000)\n", |
| 702 | + "v_est" |
| 703 | + ] |
| 704 | + }, |
| 705 | + { |
| 706 | + "cell_type": "code", |
| 707 | + "execution_count": 78, |
| 708 | + "metadata": {}, |
| 709 | + "outputs": [ |
| 710 | + { |
| 711 | + "name": "stdout", |
| 712 | + "output_type": "stream", |
| 713 | + "text": [ |
| 714 | + "Converged in 6 iterations.\n" |
| 715 | + ] |
| 716 | + } |
| 717 | + ], |
| 718 | + "source": [ |
| 719 | + "v_true = policy_evaluation(foodtruck, policy)" |
| 720 | + ] |
| 721 | + }, |
| 722 | + { |
| 723 | + "cell_type": "code", |
| 724 | + "execution_count": 63, |
| 725 | + "metadata": {}, |
| 726 | + "outputs": [ |
| 727 | + { |
| 728 | + "data": { |
| 729 | + "text/plain": [ |
| 730 | + "{('Mon', 0): 2515.0,\n", |
| 731 | + " ('Tue', 0): 1960.0,\n", |
| 732 | + " ('Tue', 100): 2360.0,\n", |
| 733 | + " ('Tue', 200): 2760.0,\n", |
| 734 | + " ('Tue', 300): 3205.0,\n", |
| 735 | + " ('Wed', 0): 1405.0,\n", |
| 736 | + " ('Wed', 100): 1805.0,\n", |
| 737 | + " ('Wed', 200): 2205.0,\n", |
| 738 | + " ('Wed', 300): 2650.0,\n", |
| 739 | + " ('Thu', 0): 850.0000000000001,\n", |
| 740 | + " ('Thu', 100): 1250.0,\n", |
| 741 | + " ('Thu', 200): 1650.0,\n", |
| 742 | + " ('Thu', 300): 2095.0,\n", |
| 743 | + " ('Fri', 0): 295.00000000000006,\n", |
| 744 | + " ('Fri', 100): 695.0000000000001,\n", |
| 745 | + " ('Fri', 200): 1095.0,\n", |
| 746 | + " ('Fri', 300): 1400.0,\n", |
| 747 | + " ('Weekend', 0): 0,\n", |
| 748 | + " ('Weekend', 100): 0,\n", |
| 749 | + " ('Weekend', 200): 0,\n", |
| 750 | + " ('Weekend', 300): 0}" |
| 751 | + ] |
| 752 | + }, |
| 753 | + "execution_count": 63, |
| 754 | + "metadata": {}, |
| 755 | + "output_type": "execute_result" |
| 756 | + } |
| 757 | + ], |
| 758 | + "source": [ |
| 759 | + "v_true" |
| 760 | + ] |
| 761 | + }, |
| 762 | + { |
| 763 | + "cell_type": "code", |
| 764 | + "execution_count": null, |
| 765 | + "metadata": {}, |
| 766 | + "outputs": [], |
| 767 | + "source": [ |
| 768 | + "# v_est = first_visit_mc(foodtruck, policy, 1, 5)\n", |
| 769 | + "# {s: v_est[s] for s in sorted(v_est)}" |
| 770 | + ] |
| 771 | + }, |
| 772 | + { |
| 773 | + "cell_type": "code", |
| 774 | + "execution_count": null, |
| 775 | + "metadata": {}, |
| 776 | + "outputs": [], |
| 777 | + "source": [ |
| 778 | + "# v_est = first_visit_mc(foodtruck, policy, 1, 10)\n", |
| 779 | + "# {s: v_est[s] for s in sorted(v_est)}" |
| 780 | + ] |
| 781 | + }, |
| 782 | + { |
| 783 | + "cell_type": "code", |
| 784 | + "execution_count": null, |
| 785 | + "metadata": {}, |
| 786 | + "outputs": [], |
| 787 | + "source": [ |
| 788 | + "# v_est = first_visit_mc(foodtruck, policy, 1, 100)\n", |
| 789 | + "# {s: v_est[s] for s in sorted(v_est)}" |
| 790 | + ] |
| 791 | + }, |
| 792 | + { |
| 793 | + "cell_type": "code", |
| 794 | + "execution_count": null, |
| 795 | + "metadata": {}, |
| 796 | + "outputs": [], |
| 797 | + "source": [ |
| 798 | + "# v_est = first_visit_mc(foodtruck, policy, 1, 1000)\n", |
| 799 | + "# {s: v_est[s] for s in sorted(v_est)}" |
| 800 | + ] |
| 801 | + }, |
| 802 | + { |
| 803 | + "cell_type": "code", |
| 804 | + "execution_count": null, |
| 805 | + "metadata": {}, |
| 806 | + "outputs": [], |
| 807 | + "source": [ |
| 808 | + "# v_est = first_visit_mc(foodtruck, policy, 1, 10000)\n", |
| 809 | + "# {s: v_est[s] for s in sorted(v_est)}" |
| 810 | + ] |
| 811 | + }, |
| 812 | + { |
| 813 | + "cell_type": "markdown", |
| 814 | + "metadata": {}, |
| 815 | + "source": [ |
| 816 | + "## On-policy Monte Carlo Control" |
| 817 | + ] |
| 818 | + }, |
| 819 | + { |
| 820 | + "cell_type": "code", |
| 821 | + "execution_count": 85, |
| 822 | + "metadata": {}, |
| 823 | + "outputs": [], |
| 824 | + "source": [ |
| 825 | + "import operator" |
| 826 | + ] |
| 827 | + }, |
| 828 | + { |
| 829 | + "cell_type": "code", |
| 830 | + "execution_count": 91, |
| 831 | + "metadata": {}, |
| 832 | + "outputs": [], |
| 833 | + "source": [ |
| 834 | + "def get_eps_greedy(actions, eps, a_best):\n", |
| 835 | + " prob_a = {}\n", |
| 836 | + " n_a = len(actions)\n", |
| 837 | + " for a in actions:\n", |
| 838 | + " if a == a_best:\n", |
| 839 | + " prob_a[a] = 1 - eps + eps/n_a\n", |
| 840 | + " else:\n", |
| 841 | + " prob_a[a] = eps/n_a\n", |
| 842 | + " return prob_a" |
| 843 | + ] |
| 844 | + }, |
| 845 | + { |
| 846 | + "cell_type": "code", |
| 847 | + "execution_count": null, |
| 848 | + "metadata": {}, |
| 849 | + "outputs": [], |
| 850 | + "source": [] |
| 851 | + }, |
| 852 | + { |
| 853 | + "cell_type": "code", |
| 854 | + "execution_count": 92, |
| 855 | + "metadata": {}, |
| 856 | + "outputs": [], |
| 857 | + "source": [ |
| 858 | + "def get_random_policy(states, actions):\n", |
| 859 | + " policy = {}\n", |
| 860 | + " n_a = len(actions)\n", |
| 861 | + " for s in states:\n", |
| 862 | + " policy[s] = {a: 1/n_a for a in actions}\n", |
| 863 | + " return policy" |
| 864 | + ] |
| 865 | + }, |
| 866 | + { |
| 867 | + "cell_type": "code", |
| 868 | + "execution_count": 93, |
| 869 | + "metadata": {}, |
| 870 | + "outputs": [], |
| 871 | + "source": [ |
| 872 | + "def on_policy_first_visit_mc(env, n_iter, eps, gamma):\n", |
| 873 | + " np.random.seed(0)\n", |
| 874 | + " states = env.state_space\n", |
| 875 | + " actions = env.action_space\n", |
| 876 | + " policy = get_random_policy(states, actions)\n", |
| 877 | + " Q = {s: {a: 0 for a in actions} for s in states}\n", |
| 878 | + " Q_n = {s: {a: 0 for a in actions} for s in states}\n", |
| 879 | + " for i in range(n_iter):\n", |
| 880 | + " if i % 10000 == 0:\n", |
| 881 | + " print(\"Iteration:\", i)\n", |
| 882 | + " trajectory = get_trajectory(env, policy)\n", |
| 883 | + " G = 0\n", |
| 884 | + " T = len(trajectory) - 1\n", |
| 885 | + " for t, sar in enumerate(reversed(trajectory)):\n", |
| 886 | + " s, a, r = sar\n", |
| 887 | + " G = r + gamma * G\n", |
| 888 | + " first_visit = True\n", |
| 889 | + " for j in range(T - t):\n", |
| 890 | + " s_j = trajectory[j][0]\n", |
| 891 | + " a_j = trajectory[j][1]\n", |
| 892 | + " if (s, a) == (s_j, a_j):\n", |
| 893 | + " first_visit = False\n", |
| 894 | + " if first_visit:\n", |
| 895 | + " Q[s][a] = Q_n[s][a] * Q[s][a] + G\n", |
| 896 | + " Q_n[s][a] += 1\n", |
| 897 | + " Q[s][a] /= Q_n[s][a]\n", |
| 898 | + " a_best = max(Q[s].items(), \n", |
| 899 | + " key=operator.itemgetter(1))[0]\n", |
| 900 | + " policy[s] = get_eps_greedy(actions, \n", |
| 901 | + " eps, \n", |
| 902 | + " a_best)\n", |
| 903 | + " return policy, Q, Q_n" |
| 904 | + ] |
| 905 | + }, |
| 906 | + { |
| 907 | + "cell_type": "code", |
| 908 | + "execution_count": 94, |
| 909 | + "metadata": {}, |
| 910 | + "outputs": [ |
| 911 | + { |
| 912 | + "name": "stdout", |
| 913 | + "output_type": "stream", |
| 914 | + "text": [ |
| 915 | + "Iteration: 0\n", |
| 916 | + "Iteration: 10000\n", |
| 917 | + "Iteration: 20000\n", |
| 918 | + "Iteration: 30000\n", |
| 919 | + "Iteration: 40000\n", |
| 920 | + "Iteration: 50000\n", |
| 921 | + "Iteration: 60000\n", |
| 922 | + "Iteration: 70000\n", |
| 923 | + "Iteration: 80000\n", |
| 924 | + "Iteration: 90000\n", |
| 925 | + "Iteration: 100000\n", |
| 926 | + "Iteration: 110000\n", |
| 927 | + "Iteration: 120000\n", |
| 928 | + "Iteration: 130000\n", |
| 929 | + "Iteration: 140000\n", |
| 930 | + "Iteration: 150000\n", |
| 931 | + "Iteration: 160000\n", |
| 932 | + "Iteration: 170000\n", |
| 933 | + "Iteration: 180000\n", |
| 934 | + "Iteration: 190000\n", |
| 935 | + "Iteration: 200000\n", |
| 936 | + "Iteration: 210000\n", |
| 937 | + "Iteration: 220000\n", |
| 938 | + "Iteration: 230000\n", |
| 939 | + "Iteration: 240000\n", |
| 940 | + "Iteration: 250000\n", |
| 941 | + "Iteration: 260000\n", |
| 942 | + "Iteration: 270000\n", |
| 943 | + "Iteration: 280000\n", |
| 944 | + "Iteration: 290000\n" |
| 945 | + ] |
| 946 | + } |
| 947 | + ], |
| 948 | + "source": [ |
| 949 | + "policy, Q, Q_n = on_policy_first_visit_mc(foodtruck, \n", |
| 950 | + " 300000, \n", |
| 951 | + " 0.05, \n", |
| 952 | + " 1)" |
| 953 | + ] |
| 954 | + }, |
| 955 | + { |
| 956 | + "cell_type": "code", |
| 957 | + "execution_count": 90, |
| 958 | + "metadata": {}, |
| 959 | + "outputs": [ |
| 960 | + { |
| 961 | + "data": { |
| 962 | + "text/plain": [ |
| 963 | + "{('Mon', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.96},\n", |
| 964 | + " ('Tue', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.96},\n", |
| 965 | + " ('Tue', 100): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.96, 400: 0.01},\n", |
| 966 | + " ('Tue', 200): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n", |
| 967 | + " ('Tue', 300): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 968 | + " ('Wed', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.96},\n", |
| 969 | + " ('Wed', 100): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.96, 400: 0.01},\n", |
| 970 | + " ('Wed', 200): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n", |
| 971 | + " ('Wed', 300): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 972 | + " ('Thu', 0): {0: 0.01, 100: 0.01, 200: 0.01, 300: 0.96, 400: 0.01},\n", |
| 973 | + " ('Thu', 100): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n", |
| 974 | + " ('Thu', 200): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 975 | + " ('Thu', 300): {0: 0.96, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 976 | + " ('Fri', 0): {0: 0.01, 100: 0.01, 200: 0.96, 300: 0.01, 400: 0.01},\n", |
| 977 | + " ('Fri', 100): {0: 0.01, 100: 0.96, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 978 | + " ('Fri', 200): {0: 0.96, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 979 | + " ('Fri', 300): {0: 0.96, 100: 0.01, 200: 0.01, 300: 0.01, 400: 0.01},\n", |
| 980 | + " ('Weekend', 0): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2},\n", |
| 981 | + " ('Weekend', 100): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2},\n", |
| 982 | + " ('Weekend', 200): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2},\n", |
| 983 | + " ('Weekend', 300): {0: 0.2, 100: 0.2, 200: 0.2, 300: 0.2, 400: 0.2}}" |
| 984 | + ] |
| 985 | + }, |
| 986 | + "execution_count": 90, |
| 987 | + "metadata": {}, |
| 988 | + "output_type": "execute_result" |
| 989 | + } |
| 990 | + ], |
| 991 | + "source": [ |
| 992 | + "policy" |
| 993 | + ] |
| 994 | + }, |
| 995 | + { |
| 996 | + "cell_type": "code", |
| 997 | + "execution_count": 95, |
| 998 | + "metadata": {}, |
| 999 | + "outputs": [ |
| 1000 | + { |
| 1001 | + "data": { |
| 1002 | + "text/plain": [ |
| 1003 | + "{('Mon', 0): {0: 2162.733333333329,\n", |
| 1004 | + " 100: 2468.4210526315796,\n", |
| 1005 | + " 200: 2668.7695190505888,\n", |
| 1006 | + " 300: 2739.300098231826,\n", |
| 1007 | + " 400: 2809.1632287569414},\n", |
| 1008 | + " ('Tue', 0): {0: 1539.1011235955057,\n", |
| 1009 | + " 100: 1857.630979498861,\n", |
| 1010 | + " 200: 2018.3222958057395,\n", |
| 1011 | + " 300: 2101.97486535009,\n", |
| 1012 | + " 400: 2181.249139237035},\n", |
| 1013 | + " ('Tue', 100): {0: 2243.7967115097176,\n", |
| 1014 | + " 100: 2410.7182940516295,\n", |
| 1015 | + " 200: 2537.853107344635,\n", |
| 1016 | + " 300: 2587.222441722628,\n", |
| 1017 | + " 400: 2170.4049844236765},\n", |
| 1018 | + " ('Tue', 200): {0: 2828.295819935689,\n", |
| 1019 | + " 100: 2953.6330631123433,\n", |
| 1020 | + " 200: 2996.437255166801,\n", |
| 1021 | + " 300: 2623.82297551789,\n", |
| 1022 | + " 400: 2224.710080285464},\n", |
| 1023 | + " ('Tue', 300): {0: 3383.880037488284,\n", |
| 1024 | + " 100: 3395.720002238628,\n", |
| 1025 | + " 200: 2939.4218134034168,\n", |
| 1026 | + " 300: 2572.2506393861877,\n", |
| 1027 | + " 400: 2162.3395149786},\n", |
| 1028 | + " ('Wed', 0): {0: 935.7142857142857,\n", |
| 1029 | + " 100: 1256.8720379146928,\n", |
| 1030 | + " 200: 1400.5025125628129,\n", |
| 1031 | + " 300: 1547.1040492055338,\n", |
| 1032 | + " 400: 1579.8683874265244},\n", |
| 1033 | + " ('Wed', 100): {0: 1639.7689768976904,\n", |
| 1034 | + " 100: 1868.1431005110733,\n", |
| 1035 | + " 200: 1908.107074569789,\n", |
| 1036 | + " 300: 1989.5285532259934,\n", |
| 1037 | + " 400: 1605.021520803444},\n", |
| 1038 | + " ('Wed', 200): {0: 2250.352733686064,\n", |
| 1039 | + " 100: 2341.068532900906,\n", |
| 1040 | + " 200: 2383.0059803588124,\n", |
| 1041 | + " 300: 1962.005277044855,\n", |
| 1042 | + " 400: 1573.4144222415298},\n", |
| 1043 | + " ('Wed', 300): {0: 2758.00389203214,\n", |
| 1044 | + " 100: 2778.022627490717,\n", |
| 1045 | + " 200: 2393.5081148564277,\n", |
| 1046 | + " 300: 1985.8374384236454,\n", |
| 1047 | + " 400: 1614.6220570012397},\n", |
| 1048 | + " ('Thu', 0): {0: 369.36619718309856,\n", |
| 1049 | + " 100: 684.2803030303028,\n", |
| 1050 | + " 200: 903.1539888682744,\n", |
| 1051 | + " 300: 972.1787871266652,\n", |
| 1052 | + " 400: 930.1247771836006},\n", |
| 1053 | + " ('Thu', 100): {0: 1084.478371501272,\n", |
| 1054 | + " 100: 1289.5073754522657,\n", |
| 1055 | + " 200: 1372.1298508969842,\n", |
| 1056 | + " 300: 1332.386447699365,\n", |
| 1057 | + " 400: 953.6523929471032},\n", |
| 1058 | + " ('Thu', 200): {0: 1677.668161434978,\n", |
| 1059 | + " 100: 1769.2753842946279,\n", |
| 1060 | + " 200: 1733.8299737072743,\n", |
| 1061 | + " 300: 1325.3393665158371,\n", |
| 1062 | + " 400: 919.6219621962197},\n", |
| 1063 | + " ('Thu', 300): {0: 2169.691663233083,\n", |
| 1064 | + " 100: 2166.585956416466,\n", |
| 1065 | + " 200: 1757.9545454545455,\n", |
| 1066 | + " 300: 1333.6569579288014,\n", |
| 1067 | + " 400: 953.3227848101266},\n", |
| 1068 | + " ('Fri', 0): {0: 0.0,\n", |
| 1069 | + " 100: 300.0,\n", |
| 1070 | + " 200: 388.81505831705283,\n", |
| 1071 | + " 300: 186.4516129032258,\n", |
| 1072 | + " 400: -142.74809160305333},\n", |
| 1073 | + " ('Fri', 100): {0: 700.0,\n", |
| 1074 | + " 100: 790.5049146968516,\n", |
| 1075 | + " 200: 607.3234524847425,\n", |
| 1076 | + " 300: 267.3796791443842,\n", |
| 1077 | + " 400: -110.91954022988506},\n", |
| 1078 | + " ('Fri', 200): {0: 1190.3311990960892,\n", |
| 1079 | + " 100: 988.7775551102206,\n", |
| 1080 | + " 200: 640.5092592592597,\n", |
| 1081 | + " 300: 267.4418604651163,\n", |
| 1082 | + " 400: -112.39263803680979},\n", |
| 1083 | + " ('Fri', 300): {0: 1399.4254760341432,\n", |
| 1084 | + " 100: 1152.7272727272725,\n", |
| 1085 | + " 200: 742.1875,\n", |
| 1086 | + " 300: 284.4827586206896,\n", |
| 1087 | + " 400: -120.0},\n", |
| 1088 | + " ('Weekend', 0): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 1089 | + " ('Weekend', 100): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 1090 | + " ('Weekend', 200): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 1091 | + " ('Weekend', 300): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0}}" |
| 1092 | + ] |
| 1093 | + }, |
| 1094 | + "execution_count": 95, |
| 1095 | + "metadata": {}, |
| 1096 | + "output_type": "execute_result" |
| 1097 | + } |
| 1098 | + ], |
| 1099 | + "source": [ |
| 1100 | + "Q" |
| 1101 | + ] |
| 1102 | + }, |
| 1103 | + { |
| 1104 | + "cell_type": "markdown", |
| 1105 | + "metadata": {}, |
| 1106 | + "source": [ |
| 1107 | + "## Off-policy Monte Carlo Control" |
| 1108 | + ] |
| 1109 | + }, |
| 1110 | + { |
| 1111 | + "cell_type": "code", |
| 1112 | + "execution_count": null, |
| 1113 | + "metadata": {}, |
| 1114 | + "outputs": [], |
| 1115 | + "source": [] |
| 1116 | + }, |
| 1117 | + { |
| 1118 | + "cell_type": "code", |
| 1119 | + "execution_count": 108, |
| 1120 | + "metadata": {}, |
| 1121 | + "outputs": [], |
| 1122 | + "source": [ |
| 1123 | + "def off_policy_mc(env, n_iter, eps, gamma):\n", |
| 1124 | + " np.random.seed(0)\n", |
| 1125 | + " states = env.state_space\n", |
| 1126 | + " actions = env.action_space\n", |
| 1127 | + " Q = {s: {a: 0 for a in actions} for s in states}\n", |
| 1128 | + " C = {s: {a: 0 for a in actions} for s in states}\n", |
| 1129 | + " target_policy = {}\n", |
| 1130 | + " behavior_policy = get_random_policy(states, \n", |
| 1131 | + " actions)\n", |
| 1132 | + " for i in range(n_iter):\n", |
| 1133 | + " if i % 10000 == 0:\n", |
| 1134 | + " print(\"Iteration:\", i)\n", |
| 1135 | + " trajectory = get_trajectory(env, \n", |
| 1136 | + " behavior_policy)\n", |
| 1137 | + " G = 0\n", |
| 1138 | + " W = 1\n", |
| 1139 | + " T = len(trajectory) - 1\n", |
| 1140 | + " for t, sar in enumerate(reversed(trajectory)):\n", |
| 1141 | + " s, a, r = sar\n", |
| 1142 | + " G = r + gamma * G\n", |
| 1143 | + " C[s][a] += W\n", |
| 1144 | + " Q[s][a] += (W/C[s][a]) * (G - Q[s][a])\n", |
| 1145 | + " a_best = max(Q[s].items(), \n", |
| 1146 | + " key=operator.itemgetter(1))[0]\n", |
| 1147 | + " target_policy[s] = a_best\n", |
| 1148 | + " behavior_policy[s] = get_eps_greedy(actions, \n", |
| 1149 | + " eps, \n", |
| 1150 | + " a_best)\n", |
| 1151 | + " if a != target_policy[s]:\n", |
| 1152 | + " break\n", |
| 1153 | + " W = W / behavior_policy[s][a]\n", |
| 1154 | + " target_policy = {s: target_policy[s] for s in states\n", |
| 1155 | + " if s in target_policy}\n", |
| 1156 | + " return target_policy, Q" |
| 1157 | + ] |
| 1158 | + }, |
| 1159 | + { |
| 1160 | + "cell_type": "code", |
| 1161 | + "execution_count": 109, |
| 1162 | + "metadata": {}, |
| 1163 | + "outputs": [ |
| 1164 | + { |
| 1165 | + "name": "stdout", |
| 1166 | + "output_type": "stream", |
| 1167 | + "text": [ |
| 1168 | + "Iteration: 0\n", |
| 1169 | + "Iteration: 10000\n", |
| 1170 | + "Iteration: 20000\n", |
| 1171 | + "Iteration: 30000\n", |
| 1172 | + "Iteration: 40000\n", |
| 1173 | + "Iteration: 50000\n", |
| 1174 | + "Iteration: 60000\n", |
| 1175 | + "Iteration: 70000\n", |
| 1176 | + "Iteration: 80000\n", |
| 1177 | + "Iteration: 90000\n", |
| 1178 | + "Iteration: 100000\n", |
| 1179 | + "Iteration: 110000\n", |
| 1180 | + "Iteration: 120000\n", |
| 1181 | + "Iteration: 130000\n", |
| 1182 | + "Iteration: 140000\n", |
| 1183 | + "Iteration: 150000\n", |
| 1184 | + "Iteration: 160000\n", |
| 1185 | + "Iteration: 170000\n", |
| 1186 | + "Iteration: 180000\n", |
| 1187 | + "Iteration: 190000\n", |
| 1188 | + "Iteration: 200000\n", |
| 1189 | + "Iteration: 210000\n", |
| 1190 | + "Iteration: 220000\n", |
| 1191 | + "Iteration: 230000\n", |
| 1192 | + "Iteration: 240000\n", |
| 1193 | + "Iteration: 250000\n", |
| 1194 | + "Iteration: 260000\n", |
| 1195 | + "Iteration: 270000\n", |
| 1196 | + "Iteration: 280000\n", |
| 1197 | + "Iteration: 290000\n" |
| 1198 | + ] |
| 1199 | + } |
| 1200 | + ], |
| 1201 | + "source": [ |
| 1202 | + "policy, Q = off_policy_mc(foodtruck, 300000, 0.05, 1)" |
| 1203 | + ] |
| 1204 | + }, |
| 1205 | + { |
| 1206 | + "cell_type": "code", |
| 1207 | + "execution_count": 110, |
| 1208 | + "metadata": {}, |
| 1209 | + "outputs": [ |
| 1210 | + { |
| 1211 | + "data": { |
| 1212 | + "text/plain": [ |
| 1213 | + "{('Mon', 0): 400,\n", |
| 1214 | + " ('Tue', 0): 400,\n", |
| 1215 | + " ('Tue', 100): 300,\n", |
| 1216 | + " ('Tue', 200): 200,\n", |
| 1217 | + " ('Tue', 300): 100,\n", |
| 1218 | + " ('Wed', 0): 400,\n", |
| 1219 | + " ('Wed', 100): 300,\n", |
| 1220 | + " ('Wed', 200): 200,\n", |
| 1221 | + " ('Wed', 300): 100,\n", |
| 1222 | + " ('Thu', 0): 300,\n", |
| 1223 | + " ('Thu', 100): 200,\n", |
| 1224 | + " ('Thu', 200): 100,\n", |
| 1225 | + " ('Thu', 300): 0,\n", |
| 1226 | + " ('Fri', 0): 200,\n", |
| 1227 | + " ('Fri', 100): 100,\n", |
| 1228 | + " ('Fri', 200): 0,\n", |
| 1229 | + " ('Fri', 300): 0}" |
| 1230 | + ] |
| 1231 | + }, |
| 1232 | + "execution_count": 110, |
| 1233 | + "metadata": {}, |
| 1234 | + "output_type": "execute_result" |
| 1235 | + } |
| 1236 | + ], |
| 1237 | + "source": [ |
| 1238 | + "policy" |
| 1239 | + ] |
| 1240 | + }, |
| 1241 | + { |
| 1242 | + "cell_type": "code", |
| 1243 | + "execution_count": 111, |
| 1244 | + "metadata": {}, |
| 1245 | + "outputs": [ |
| 1246 | + { |
| 1247 | + "data": { |
| 1248 | + "text/plain": [ |
| 1249 | + "{('Mon', 0): {0: 2232.674050632915,\n", |
| 1250 | + " 100: 2539.364696421396,\n", |
| 1251 | + " 200: 2725.681570338065,\n", |
| 1252 | + " 300: 2822.8136882129284,\n", |
| 1253 | + " 400: 2878.458190025779},\n", |
| 1254 | + " ('Tue', 0): {0: 1594.8051948051952,\n", |
| 1255 | + " 100: 1928.976034858388,\n", |
| 1256 | + " 200: 2067.4576271186465,\n", |
| 1257 | + " 300: 2207.8512396694205,\n", |
| 1258 | + " 400: 2239.8886329583893},\n", |
| 1259 | + " ('Tue', 100): {0: 2318.9435336976317,\n", |
| 1260 | + " 100: 2536.8012422360302,\n", |
| 1261 | + " 200: 2549.486301369862,\n", |
| 1262 | + " 300: 2650.193090274893,\n", |
| 1263 | + " 400: 2256.120527306967},\n", |
| 1264 | + " ('Tue', 200): {0: 2922.175290390706,\n", |
| 1265 | + " 100: 3012.8990770161868,\n", |
| 1266 | + " 200: 3052.4769607403373,\n", |
| 1267 | + " 300: 2689.515219842163,\n", |
| 1268 | + " 400: 2293.305439330548},\n", |
| 1269 | + " ('Tue', 300): {0: 3420.032031538755,\n", |
| 1270 | + " 100: 3453.749726573689,\n", |
| 1271 | + " 200: 3014.1210374639763,\n", |
| 1272 | + " 300: 2635.802469135803,\n", |
| 1273 | + " 400: 2233.3333333333344},\n", |
| 1274 | + " ('Wed', 0): {0: 927.9702970297026,\n", |
| 1275 | + " 100: 1303.1026252983302,\n", |
| 1276 | + " 200: 1428.831168831168,\n", |
| 1277 | + " 300: 1566.1498708010329,\n", |
| 1278 | + " 400: 1616.5133331502423},\n", |
| 1279 | + " ('Wed', 100): {0: 1683.8652482269495,\n", |
| 1280 | + " 100: 1896.0360360360366,\n", |
| 1281 | + " 200: 1976.8450184501858,\n", |
| 1282 | + " 300: 2024.3386976631361,\n", |
| 1283 | + " 400: 1650.87440381558},\n", |
| 1284 | + " ('Wed', 200): {0: 2277.8664007976076,\n", |
| 1285 | + " 100: 2405.7504873294333,\n", |
| 1286 | + " 200: 2419.006699098848,\n", |
| 1287 | + " 300: 2000.3857280617174,\n", |
| 1288 | + " 400: 1608.8068181818178},\n", |
| 1289 | + " ('Wed', 300): {0: 2779.4180573384715,\n", |
| 1290 | + " 100: 2818.4754229486366,\n", |
| 1291 | + " 200: 2422.7878787878817,\n", |
| 1292 | + " 300: 2017.7989130434773,\n", |
| 1293 | + " 400: 1660.6602475928496},\n", |
| 1294 | + " ('Thu', 0): {0: 369.164265129683,\n", |
| 1295 | + " 100: 684.9275362318838,\n", |
| 1296 | + " 200: 912.9056047197645,\n", |
| 1297 | + " 300: 988.2722582352171,\n", |
| 1298 | + " 400: 926.3157894736842},\n", |
| 1299 | + " ('Thu', 100): {0: 1090.7738095238096,\n", |
| 1300 | + " 100: 1329.5566502463064,\n", |
| 1301 | + " 200: 1392.0507055220148,\n", |
| 1302 | + " 300: 1373.6577181208052,\n", |
| 1303 | + " 400: 961.7241379310348},\n", |
| 1304 | + " ('Thu', 200): {0: 1699.1087344028504,\n", |
| 1305 | + " 100: 1789.1752957897363,\n", |
| 1306 | + " 200: 1760.222222222221,\n", |
| 1307 | + " 300: 1342.7149321266952,\n", |
| 1308 | + " 400: 965.2557319223995},\n", |
| 1309 | + " ('Thu', 300): {0: 2190.6271182185546,\n", |
| 1310 | + " 100: 2176.451612903226,\n", |
| 1311 | + " 200: 1780.9290953545235,\n", |
| 1312 | + " 300: 1360.7017543859658,\n", |
| 1313 | + " 400: 964.203233256352},\n", |
| 1314 | + " ('Fri', 0): {0: 0.0,\n", |
| 1315 | + " 100: 300.0,\n", |
| 1316 | + " 200: 388.8413403310466,\n", |
| 1317 | + " 300: 189.6405919661735,\n", |
| 1318 | + " 400: -146.61016949152557},\n", |
| 1319 | + " ('Fri', 100): {0: 700.0,\n", |
| 1320 | + " 100: 790.0458861880747,\n", |
| 1321 | + " 200: 608.920985556499,\n", |
| 1322 | + " 300: 265.5866900175128,\n", |
| 1323 | + " 400: -103.34967320261451},\n", |
| 1324 | + " ('Fri', 200): {0: 1190.388245916431,\n", |
| 1325 | + " 100: 1009.6551724137929,\n", |
| 1326 | + " 200: 651.612903225807,\n", |
| 1327 | + " 300: 266.99669966996686,\n", |
| 1328 | + " 400: -116.64641555285537},\n", |
| 1329 | + " ('Fri', 300): {0: 1404.084014002334,\n", |
| 1330 | + " 100: 1116.6666666666667,\n", |
| 1331 | + " 200: 702.9411764705883,\n", |
| 1332 | + " 300: 282.3529411764706,\n", |
| 1333 | + " 400: -175.86206896551724},\n", |
| 1334 | + " ('Weekend', 0): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 1335 | + " ('Weekend', 100): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 1336 | + " ('Weekend', 200): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 1337 | + " ('Weekend', 300): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0}}" |
| 1338 | + ] |
| 1339 | + }, |
| 1340 | + "execution_count": 111, |
| 1341 | + "metadata": {}, |
| 1342 | + "output_type": "execute_result" |
| 1343 | + } |
| 1344 | + ], |
| 1345 | + "source": [ |
| 1346 | + "Q" |
| 1347 | + ] |
| 1348 | + }, |
| 1349 | + { |
| 1350 | + "cell_type": "code", |
| 1351 | + "execution_count": null, |
| 1352 | + "metadata": {}, |
| 1353 | + "outputs": [], |
| 1354 | + "source": [] |
| 1355 | + }, |
| 1356 | + { |
| 1357 | + "cell_type": "code", |
| 1358 | + "execution_count": null, |
| 1359 | + "metadata": {}, |
| 1360 | + "outputs": [], |
| 1361 | + "source": [] |
| 1362 | + }, |
| 1363 | + { |
| 1364 | + "cell_type": "markdown", |
| 1365 | + "metadata": {}, |
| 1366 | + "source": [ |
| 1367 | + "# TD Learning" |
| 1368 | + ] |
| 1369 | + }, |
| 1370 | + { |
| 1371 | + "cell_type": "markdown", |
| 1372 | + "metadata": {}, |
| 1373 | + "source": [ |
| 1374 | + "## TD Prediction" |
| 1375 | + ] |
| 1376 | + }, |
| 1377 | + { |
| 1378 | + "cell_type": "code", |
| 1379 | + "execution_count": 116, |
| 1380 | + "metadata": {}, |
| 1381 | + "outputs": [], |
| 1382 | + "source": [ |
| 1383 | + "def one_step_td_prediction(env, policy, gamma, alpha, n_iter):\n", |
| 1384 | + " np.random.seed(0)\n", |
| 1385 | + " states = env.state_space\n", |
| 1386 | + " v = {s: 0 for s in states}\n", |
| 1387 | + " s = env.reset()\n", |
| 1388 | + " for i in range(n_iter):\n", |
| 1389 | + " a = choose_action(s, policy)\n", |
| 1390 | + " s_next, reward, done, info = env.step(a)\n", |
| 1391 | + " v[s] += alpha * (reward + gamma * v[s_next] - v[s])\n", |
| 1392 | + " if done:\n", |
| 1393 | + " s = env.reset()\n", |
| 1394 | + " else:\n", |
| 1395 | + " s = s_next\n", |
| 1396 | + " return v" |
| 1397 | + ] |
| 1398 | + }, |
| 1399 | + { |
| 1400 | + "cell_type": "code", |
| 1401 | + "execution_count": null, |
| 1402 | + "metadata": {}, |
| 1403 | + "outputs": [], |
| 1404 | + "source": [] |
| 1405 | + }, |
| 1406 | + { |
| 1407 | + "cell_type": "code", |
| 1408 | + "execution_count": 117, |
| 1409 | + "metadata": {}, |
| 1410 | + "outputs": [ |
| 1411 | + { |
| 1412 | + "data": { |
| 1413 | + "text/plain": [ |
| 1414 | + "{('Mon', 0): 2506.576417395407,\n", |
| 1415 | + " ('Tue', 0): 1956.077876400167,\n", |
| 1416 | + " ('Tue', 100): 2368.7400039407535,\n", |
| 1417 | + " ('Tue', 200): 2767.5069659225423,\n", |
| 1418 | + " ('Tue', 300): 0,\n", |
| 1419 | + " ('Wed', 0): 1413.0055559001296,\n", |
| 1420 | + " ('Wed', 100): 1813.546186490315,\n", |
| 1421 | + " ('Wed', 200): 2200.8873259700867,\n", |
| 1422 | + " ('Wed', 300): 0,\n", |
| 1423 | + " ('Thu', 0): 828.2915189850011,\n", |
| 1424 | + " ('Thu', 100): 1280.424626614422,\n", |
| 1425 | + " ('Thu', 200): 1675.8661846955831,\n", |
| 1426 | + " ('Thu', 300): 0,\n", |
| 1427 | + " ('Fri', 0): 345.52991944823583,\n", |
| 1428 | + " ('Fri', 100): 677.4358179389413,\n", |
| 1429 | + " ('Fri', 200): 1094.8263154150825,\n", |
| 1430 | + " ('Fri', 300): 0,\n", |
| 1431 | + " ('Weekend', 0): 0,\n", |
| 1432 | + " ('Weekend', 100): 0,\n", |
| 1433 | + " ('Weekend', 200): 0,\n", |
| 1434 | + " ('Weekend', 300): 0}" |
| 1435 | + ] |
| 1436 | + }, |
| 1437 | + "execution_count": 117, |
| 1438 | + "metadata": {}, |
| 1439 | + "output_type": "execute_result" |
| 1440 | + } |
| 1441 | + ], |
| 1442 | + "source": [ |
| 1443 | + "policy = base_policy(foodtruck.state_space)\n", |
| 1444 | + "v = one_step_td_prediction(foodtruck, policy, 1, 0.01, 100000)\n", |
| 1445 | + "v" |
| 1446 | + ] |
| 1447 | + }, |
| 1448 | + { |
| 1449 | + "cell_type": "code", |
| 1450 | + "execution_count": null, |
| 1451 | + "metadata": {}, |
| 1452 | + "outputs": [], |
| 1453 | + "source": [ |
| 1454 | + "print({s: np.round(v[s]) for s in v})" |
| 1455 | + ] |
| 1456 | + }, |
| 1457 | + { |
| 1458 | + "cell_type": "markdown", |
| 1459 | + "metadata": {}, |
| 1460 | + "source": [ |
| 1461 | + "True values\n", |
| 1462 | + "{('Mon', 0): 2515.0,\n", |
| 1463 | + " ('Tue', 0): 1960.0,\n", |
| 1464 | + " ('Tue', 100): 2360.0,\n", |
| 1465 | + " ('Tue', 200): 2760.0,\n", |
| 1466 | + " ('Tue', 300): 3205.0,\n", |
| 1467 | + " ('Wed', 0): 1405.0,\n", |
| 1468 | + " ('Wed', 100): 1805.0,\n", |
| 1469 | + " ('Wed', 200): 2205.0,\n", |
| 1470 | + " ('Wed', 300): 2650.0,\n", |
| 1471 | + " ('Thu', 0): 850.0000000000001,\n", |
| 1472 | + " ('Thu', 100): 1250.0,\n", |
| 1473 | + " ('Thu', 200): 1650.0,\n", |
| 1474 | + " ('Thu', 300): 2095.0,\n", |
| 1475 | + " ('Fri', 0): 295.00000000000006,\n", |
| 1476 | + " ('Fri', 100): 695.0000000000001,\n", |
| 1477 | + " ('Fri', 200): 1095.0,\n", |
| 1478 | + " ('Fri', 300): 1400.0,\n", |
| 1479 | + " ('Weekend', 0): 0,\n", |
| 1480 | + " ('Weekend', 100): 0,\n", |
| 1481 | + " ('Weekend', 200): 0,\n", |
| 1482 | + " ('Weekend', 300): 0}" |
| 1483 | + ] |
| 1484 | + }, |
| 1485 | + { |
| 1486 | + "cell_type": "code", |
| 1487 | + "execution_count": 118, |
| 1488 | + "metadata": {}, |
| 1489 | + "outputs": [], |
| 1490 | + "source": [ |
| 1491 | + "def sarsa(env, gamma, eps, alpha, n_iter):\n", |
| 1492 | + " np.random.seed(0)\n", |
| 1493 | + " states = env.state_space\n", |
| 1494 | + " actions = env.action_space\n", |
| 1495 | + " Q = {s: {a: 0 for a in actions} for s in states}\n", |
| 1496 | + " policy = get_random_policy(states, actions)\n", |
| 1497 | + " s = env.reset()\n", |
| 1498 | + " a = choose_action(s, policy)\n", |
| 1499 | + " for i in range(n_iter):\n", |
| 1500 | + " if i % 100000 == 0:\n", |
| 1501 | + " print(\"Iteration:\", i)\n", |
| 1502 | + " s_next, reward, done, info = env.step(a)\n", |
| 1503 | + " a_best = max(Q[s_next].items(), \n", |
| 1504 | + " key=operator.itemgetter(1))[0]\n", |
| 1505 | + " policy[s_next] = get_eps_greedy(actions, eps, a_best)\n", |
| 1506 | + " a_next = choose_action(s_next, policy)\n", |
| 1507 | + " Q[s][a] += alpha * (reward \n", |
| 1508 | + " + gamma * Q[s_next][a_next] \n", |
| 1509 | + " - Q[s][a])\n", |
| 1510 | + " if done:\n", |
| 1511 | + " s = env.reset()\n", |
| 1512 | + " a_best = max(Q[s].items(), \n", |
| 1513 | + " key=operator.itemgetter(1))[0]\n", |
| 1514 | + " policy[s] = get_eps_greedy(actions, eps, a_best)\n", |
| 1515 | + " a = choose_action(s, policy)\n", |
| 1516 | + " else:\n", |
| 1517 | + " s = s_next\n", |
| 1518 | + " a = a_next\n", |
| 1519 | + " return policy, Q" |
| 1520 | + ] |
| 1521 | + }, |
| 1522 | + { |
| 1523 | + "cell_type": "code", |
| 1524 | + "execution_count": 119, |
| 1525 | + "metadata": {}, |
| 1526 | + "outputs": [ |
| 1527 | + { |
| 1528 | + "name": "stdout", |
| 1529 | + "output_type": "stream", |
| 1530 | + "text": [ |
| 1531 | + "Iteration: 0\n", |
| 1532 | + "Iteration: 100000\n", |
| 1533 | + "Iteration: 200000\n", |
| 1534 | + "Iteration: 300000\n", |
| 1535 | + "Iteration: 400000\n", |
| 1536 | + "Iteration: 500000\n", |
| 1537 | + "Iteration: 600000\n", |
| 1538 | + "Iteration: 700000\n", |
| 1539 | + "Iteration: 800000\n", |
| 1540 | + "Iteration: 900000\n" |
| 1541 | + ] |
| 1542 | + } |
| 1543 | + ], |
| 1544 | + "source": [ |
| 1545 | + "policy, Q = sarsa(foodtruck, 1, 0.1, 0.01, 1000000)" |
| 1546 | + ] |
| 1547 | + }, |
| 1548 | + { |
| 1549 | + "cell_type": "code", |
| 1550 | + "execution_count": 120, |
| 1551 | + "metadata": {}, |
| 1552 | + "outputs": [ |
| 1553 | + { |
| 1554 | + "data": { |
| 1555 | + "text/plain": [ |
| 1556 | + "{('Mon', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n", |
| 1557 | + " ('Tue', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n", |
| 1558 | + " ('Tue', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n", |
| 1559 | + " ('Tue', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1560 | + " ('Tue', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1561 | + " ('Wed', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n", |
| 1562 | + " ('Wed', 100): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n", |
| 1563 | + " ('Wed', 200): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n", |
| 1564 | + " ('Wed', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1565 | + " ('Thu', 0): {0: 0.02, 100: 0.02, 200: 0.02, 300: 0.92, 400: 0.02},\n", |
| 1566 | + " ('Thu', 100): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n", |
| 1567 | + " ('Thu', 200): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1568 | + " ('Thu', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1569 | + " ('Fri', 0): {0: 0.02, 100: 0.02, 200: 0.92, 300: 0.02, 400: 0.02},\n", |
| 1570 | + " ('Fri', 100): {0: 0.02, 100: 0.92, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1571 | + " ('Fri', 200): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1572 | + " ('Fri', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1573 | + " ('Weekend', 0): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1574 | + " ('Weekend', 100): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1575 | + " ('Weekend', 200): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02},\n", |
| 1576 | + " ('Weekend', 300): {0: 0.92, 100: 0.02, 200: 0.02, 300: 0.02, 400: 0.02}}" |
| 1577 | + ] |
| 1578 | + }, |
| 1579 | + "execution_count": 120, |
| 1580 | + "metadata": {}, |
| 1581 | + "output_type": "execute_result" |
| 1582 | + } |
| 1583 | + ], |
| 1584 | + "source": [ |
| 1585 | + "policy" |
| 1586 | + ] |
| 1587 | + }, |
| 1588 | + { |
| 1589 | + "cell_type": "code", |
| 1590 | + "execution_count": 121, |
| 1591 | + "metadata": {}, |
| 1592 | + "outputs": [ |
| 1593 | + { |
| 1594 | + "data": { |
| 1595 | + "text/plain": [ |
| 1596 | + "{0: 2099.8661156763687,\n", |
| 1597 | + " 100: 2399.8190742726747,\n", |
| 1598 | + " 200: 2604.6629056622382,\n", |
| 1599 | + " 300: 2670.098987213351,\n", |
| 1600 | + " 400: 2632.8387133517112}" |
| 1601 | + ] |
| 1602 | + }, |
| 1603 | + "execution_count": 121, |
| 1604 | + "metadata": {}, |
| 1605 | + "output_type": "execute_result" |
| 1606 | + } |
| 1607 | + ], |
| 1608 | + "source": [ |
| 1609 | + "Q[('Mon', 0)]" |
| 1610 | + ] |
| 1611 | + }, |
| 1612 | + { |
| 1613 | + "cell_type": "code", |
| 1614 | + "execution_count": null, |
| 1615 | + "metadata": {}, |
| 1616 | + "outputs": [], |
| 1617 | + "source": [] |
| 1618 | + }, |
| 1619 | + { |
| 1620 | + "cell_type": "code", |
| 1621 | + "execution_count": null, |
| 1622 | + "metadata": {}, |
| 1623 | + "outputs": [], |
| 1624 | + "source": [] |
| 1625 | + }, |
| 1626 | + { |
| 1627 | + "cell_type": "code", |
| 1628 | + "execution_count": 122, |
| 1629 | + "metadata": {}, |
| 1630 | + "outputs": [], |
| 1631 | + "source": [ |
| 1632 | + "def q_learning(env, gamma, eps, alpha, n_iter):\n", |
| 1633 | + " np.random.seed(0)\n", |
| 1634 | + " states = env.state_space\n", |
| 1635 | + " actions = env.action_space\n", |
| 1636 | + " Q = {s: {a: 0 for a in actions} for s in states}\n", |
| 1637 | + " policy = get_random_policy(states, actions)\n", |
| 1638 | + " s = env.reset()\n", |
| 1639 | + " for i in range(n_iter):\n", |
| 1640 | + " if i % 100000 == 0:\n", |
| 1641 | + " print(\"Iteration:\", i)\n", |
| 1642 | + " a_best = max(Q[s].items(), \n", |
| 1643 | + " key=operator.itemgetter(1))[0]\n", |
| 1644 | + " policy[s] = get_eps_greedy(actions, eps, a_best)\n", |
| 1645 | + " a = choose_action(s, policy)\n", |
| 1646 | + " s_next, reward, done, info = env.step(a)\n", |
| 1647 | + " Q[s][a] += alpha * (reward \n", |
| 1648 | + " + gamma * max(Q[s_next].values()) \n", |
| 1649 | + " - Q[s][a])\n", |
| 1650 | + " if done:\n", |
| 1651 | + " s = env.reset()\n", |
| 1652 | + " else:\n", |
| 1653 | + " s = s_next\n", |
| 1654 | + " policy = {s: {max(policy[s].items(), \n", |
| 1655 | + " key=operator.itemgetter(1))[0]: 1}\n", |
| 1656 | + " for s in states}\n", |
| 1657 | + " return policy, Q" |
| 1658 | + ] |
| 1659 | + }, |
| 1660 | + { |
| 1661 | + "cell_type": "code", |
| 1662 | + "execution_count": 123, |
| 1663 | + "metadata": {}, |
| 1664 | + "outputs": [ |
| 1665 | + { |
| 1666 | + "name": "stdout", |
| 1667 | + "output_type": "stream", |
| 1668 | + "text": [ |
| 1669 | + "Iteration: 0\n", |
| 1670 | + "Iteration: 100000\n", |
| 1671 | + "Iteration: 200000\n", |
| 1672 | + "Iteration: 300000\n", |
| 1673 | + "Iteration: 400000\n", |
| 1674 | + "Iteration: 500000\n", |
| 1675 | + "Iteration: 600000\n", |
| 1676 | + "Iteration: 700000\n", |
| 1677 | + "Iteration: 800000\n", |
| 1678 | + "Iteration: 900000\n" |
| 1679 | + ] |
| 1680 | + }, |
| 1681 | + { |
| 1682 | + "data": { |
| 1683 | + "text/plain": [ |
| 1684 | + "{('Mon', 0): {400: 1},\n", |
| 1685 | + " ('Tue', 0): {400: 1},\n", |
| 1686 | + " ('Tue', 100): {300: 1},\n", |
| 1687 | + " ('Tue', 200): {200: 1},\n", |
| 1688 | + " ('Tue', 300): {100: 1},\n", |
| 1689 | + " ('Wed', 0): {400: 1},\n", |
| 1690 | + " ('Wed', 100): {300: 1},\n", |
| 1691 | + " ('Wed', 200): {200: 1},\n", |
| 1692 | + " ('Wed', 300): {100: 1},\n", |
| 1693 | + " ('Thu', 0): {300: 1},\n", |
| 1694 | + " ('Thu', 100): {200: 1},\n", |
| 1695 | + " ('Thu', 200): {100: 1},\n", |
| 1696 | + " ('Thu', 300): {0: 1},\n", |
| 1697 | + " ('Fri', 0): {200: 1},\n", |
| 1698 | + " ('Fri', 100): {100: 1},\n", |
| 1699 | + " ('Fri', 200): {0: 1},\n", |
| 1700 | + " ('Fri', 300): {0: 1},\n", |
| 1701 | + " ('Weekend', 0): {0: 1},\n", |
| 1702 | + " ('Weekend', 100): {0: 1},\n", |
| 1703 | + " ('Weekend', 200): {0: 1},\n", |
| 1704 | + " ('Weekend', 300): {0: 1}}" |
| 1705 | + ] |
| 1706 | + }, |
| 1707 | + "execution_count": 123, |
| 1708 | + "metadata": {}, |
| 1709 | + "output_type": "execute_result" |
| 1710 | + } |
| 1711 | + ], |
| 1712 | + "source": [ |
| 1713 | + "policy, Q = q_learning(foodtruck, 1, 0.1, 0.01, 1000000)\n", |
| 1714 | + "policy" |
| 1715 | + ] |
| 1716 | + }, |
| 1717 | + { |
| 1718 | + "cell_type": "code", |
| 1719 | + "execution_count": 124, |
| 1720 | + "metadata": {}, |
| 1721 | + "outputs": [ |
| 1722 | + { |
| 1723 | + "name": "stdout", |
| 1724 | + "output_type": "stream", |
| 1725 | + "text": [ |
| 1726 | + "Iteration: 0\n", |
| 1727 | + "Iteration: 100000\n", |
| 1728 | + "Iteration: 200000\n", |
| 1729 | + "Iteration: 300000\n", |
| 1730 | + "Iteration: 400000\n", |
| 1731 | + "Iteration: 500000\n", |
| 1732 | + "Iteration: 600000\n", |
| 1733 | + "Iteration: 700000\n", |
| 1734 | + "Iteration: 800000\n", |
| 1735 | + "Iteration: 900000\n", |
| 1736 | + "Iteration: 1000000\n", |
| 1737 | + "Iteration: 1100000\n", |
| 1738 | + "Iteration: 1200000\n", |
| 1739 | + "Iteration: 1300000\n", |
| 1740 | + "Iteration: 1400000\n", |
| 1741 | + "Iteration: 1500000\n", |
| 1742 | + "Iteration: 1600000\n", |
| 1743 | + "Iteration: 1700000\n", |
| 1744 | + "Iteration: 1800000\n", |
| 1745 | + "Iteration: 1900000\n", |
| 1746 | + "Iteration: 2000000\n", |
| 1747 | + "Iteration: 2100000\n", |
| 1748 | + "Iteration: 2200000\n", |
| 1749 | + "Iteration: 2300000\n", |
| 1750 | + "Iteration: 2400000\n", |
| 1751 | + "Iteration: 2500000\n", |
| 1752 | + "Iteration: 2600000\n", |
| 1753 | + "Iteration: 2700000\n", |
| 1754 | + "Iteration: 2800000\n", |
| 1755 | + "Iteration: 2900000\n", |
| 1756 | + "Iteration: 3000000\n", |
| 1757 | + "Iteration: 3100000\n", |
| 1758 | + "Iteration: 3200000\n", |
| 1759 | + "Iteration: 3300000\n", |
| 1760 | + "Iteration: 3400000\n", |
| 1761 | + "Iteration: 3500000\n", |
| 1762 | + "Iteration: 3600000\n", |
| 1763 | + "Iteration: 3700000\n", |
| 1764 | + "Iteration: 3800000\n", |
| 1765 | + "Iteration: 3900000\n", |
| 1766 | + "Iteration: 4000000\n", |
| 1767 | + "Iteration: 4100000\n", |
| 1768 | + "Iteration: 4200000\n", |
| 1769 | + "Iteration: 4300000\n", |
| 1770 | + "Iteration: 4400000\n", |
| 1771 | + "Iteration: 4500000\n", |
| 1772 | + "Iteration: 4600000\n", |
| 1773 | + "Iteration: 4700000\n", |
| 1774 | + "Iteration: 4800000\n", |
| 1775 | + "Iteration: 4900000\n", |
| 1776 | + "Iteration: 5000000\n", |
| 1777 | + "Iteration: 5100000\n", |
| 1778 | + "Iteration: 5200000\n", |
| 1779 | + "Iteration: 5300000\n", |
| 1780 | + "Iteration: 5400000\n", |
| 1781 | + "Iteration: 5500000\n", |
| 1782 | + "Iteration: 5600000\n", |
| 1783 | + "Iteration: 5700000\n", |
| 1784 | + "Iteration: 5800000\n", |
| 1785 | + "Iteration: 5900000\n", |
| 1786 | + "Iteration: 6000000\n", |
| 1787 | + "Iteration: 6100000\n", |
| 1788 | + "Iteration: 6200000\n", |
| 1789 | + "Iteration: 6300000\n", |
| 1790 | + "Iteration: 6400000\n", |
| 1791 | + "Iteration: 6500000\n", |
| 1792 | + "Iteration: 6600000\n", |
| 1793 | + "Iteration: 6700000\n", |
| 1794 | + "Iteration: 6800000\n", |
| 1795 | + "Iteration: 6900000\n", |
| 1796 | + "Iteration: 7000000\n", |
| 1797 | + "Iteration: 7100000\n", |
| 1798 | + "Iteration: 7200000\n", |
| 1799 | + "Iteration: 7300000\n", |
| 1800 | + "Iteration: 7400000\n", |
| 1801 | + "Iteration: 7500000\n", |
| 1802 | + "Iteration: 7600000\n", |
| 1803 | + "Iteration: 7700000\n", |
| 1804 | + "Iteration: 7800000\n", |
| 1805 | + "Iteration: 7900000\n", |
| 1806 | + "Iteration: 8000000\n", |
| 1807 | + "Iteration: 8100000\n", |
| 1808 | + "Iteration: 8200000\n", |
| 1809 | + "Iteration: 8300000\n", |
| 1810 | + "Iteration: 8400000\n", |
| 1811 | + "Iteration: 8500000\n", |
| 1812 | + "Iteration: 8600000\n", |
| 1813 | + "Iteration: 8700000\n", |
| 1814 | + "Iteration: 8800000\n", |
| 1815 | + "Iteration: 8900000\n", |
| 1816 | + "Iteration: 9000000\n", |
| 1817 | + "Iteration: 9100000\n", |
| 1818 | + "Iteration: 9200000\n", |
| 1819 | + "Iteration: 9300000\n", |
| 1820 | + "Iteration: 9400000\n", |
| 1821 | + "Iteration: 9500000\n", |
| 1822 | + "Iteration: 9600000\n", |
| 1823 | + "Iteration: 9700000\n", |
| 1824 | + "Iteration: 9800000\n", |
| 1825 | + "Iteration: 9900000\n", |
| 1826 | + "Iteration: 10000000\n", |
| 1827 | + "Iteration: 10100000\n", |
| 1828 | + "Iteration: 10200000\n", |
| 1829 | + "Iteration: 10300000\n", |
| 1830 | + "Iteration: 10400000\n", |
| 1831 | + "Iteration: 10500000\n", |
| 1832 | + "Iteration: 10600000\n", |
| 1833 | + "Iteration: 10700000\n", |
| 1834 | + "Iteration: 10800000\n", |
| 1835 | + "Iteration: 10900000\n", |
| 1836 | + "Iteration: 11000000\n", |
| 1837 | + "Iteration: 11100000\n", |
| 1838 | + "Iteration: 11200000\n", |
| 1839 | + "Iteration: 11300000\n", |
| 1840 | + "Iteration: 11400000\n", |
| 1841 | + "Iteration: 11500000\n", |
| 1842 | + "Iteration: 11600000\n", |
| 1843 | + "Iteration: 11700000\n", |
| 1844 | + "Iteration: 11800000\n", |
| 1845 | + "Iteration: 11900000\n", |
| 1846 | + "Iteration: 12000000\n", |
| 1847 | + "Iteration: 12100000\n", |
| 1848 | + "Iteration: 12200000\n", |
| 1849 | + "Iteration: 12300000\n", |
| 1850 | + "Iteration: 12400000\n", |
| 1851 | + "Iteration: 12500000\n", |
| 1852 | + "Iteration: 12600000\n", |
| 1853 | + "Iteration: 12700000\n", |
| 1854 | + "Iteration: 12800000\n", |
| 1855 | + "Iteration: 12900000\n", |
| 1856 | + "Iteration: 13000000\n", |
| 1857 | + "Iteration: 13100000\n", |
| 1858 | + "Iteration: 13200000\n", |
| 1859 | + "Iteration: 13300000\n", |
| 1860 | + "Iteration: 13400000\n", |
| 1861 | + "Iteration: 13500000\n", |
| 1862 | + "Iteration: 13600000\n", |
| 1863 | + "Iteration: 13700000\n", |
| 1864 | + "Iteration: 13800000\n", |
| 1865 | + "Iteration: 13900000\n", |
| 1866 | + "Iteration: 14000000\n", |
| 1867 | + "Iteration: 14100000\n", |
| 1868 | + "Iteration: 14200000\n", |
| 1869 | + "Iteration: 14300000\n", |
| 1870 | + "Iteration: 14400000\n", |
| 1871 | + "Iteration: 14500000\n", |
| 1872 | + "Iteration: 14600000\n", |
| 1873 | + "Iteration: 14700000\n", |
| 1874 | + "Iteration: 14800000\n", |
| 1875 | + "Iteration: 14900000\n", |
| 1876 | + "Iteration: 15000000\n", |
| 1877 | + "Iteration: 15100000\n", |
| 1878 | + "Iteration: 15200000\n", |
| 1879 | + "Iteration: 15300000\n", |
| 1880 | + "Iteration: 15400000\n", |
| 1881 | + "Iteration: 15500000\n", |
| 1882 | + "Iteration: 15600000\n", |
| 1883 | + "Iteration: 15700000\n", |
| 1884 | + "Iteration: 15800000\n", |
| 1885 | + "Iteration: 15900000\n", |
| 1886 | + "Iteration: 16000000\n", |
| 1887 | + "Iteration: 16100000\n", |
| 1888 | + "Iteration: 16200000\n", |
| 1889 | + "Iteration: 16300000\n", |
| 1890 | + "Iteration: 16400000\n", |
| 1891 | + "Iteration: 16500000\n", |
| 1892 | + "Iteration: 16600000\n", |
| 1893 | + "Iteration: 16700000\n", |
| 1894 | + "Iteration: 16800000\n", |
| 1895 | + "Iteration: 16900000\n", |
| 1896 | + "Iteration: 17000000\n", |
| 1897 | + "Iteration: 17100000\n", |
| 1898 | + "Iteration: 17200000\n", |
| 1899 | + "Iteration: 17300000\n", |
| 1900 | + "Iteration: 17400000\n", |
| 1901 | + "Iteration: 17500000\n", |
| 1902 | + "Iteration: 17600000\n", |
| 1903 | + "Iteration: 17700000\n", |
| 1904 | + "Iteration: 17800000\n", |
| 1905 | + "Iteration: 17900000\n", |
| 1906 | + "Iteration: 18000000\n", |
| 1907 | + "Iteration: 18100000\n", |
| 1908 | + "Iteration: 18200000\n", |
| 1909 | + "Iteration: 18300000\n", |
| 1910 | + "Iteration: 18400000\n", |
| 1911 | + "Iteration: 18500000\n", |
| 1912 | + "Iteration: 18600000\n", |
| 1913 | + "Iteration: 18700000\n", |
| 1914 | + "Iteration: 18800000\n", |
| 1915 | + "Iteration: 18900000\n", |
| 1916 | + "Iteration: 19000000\n", |
| 1917 | + "Iteration: 19100000\n", |
| 1918 | + "Iteration: 19200000\n", |
| 1919 | + "Iteration: 19300000\n", |
| 1920 | + "Iteration: 19400000\n", |
| 1921 | + "Iteration: 19500000\n", |
| 1922 | + "Iteration: 19600000\n", |
| 1923 | + "Iteration: 19700000\n", |
| 1924 | + "Iteration: 19800000\n", |
| 1925 | + "Iteration: 19900000\n" |
| 1926 | + ] |
| 1927 | + }, |
| 1928 | + { |
| 1929 | + "data": { |
| 1930 | + "text/plain": [ |
| 1931 | + "({('Mon', 0): {400: 1},\n", |
| 1932 | + " ('Tue', 0): {300: 1},\n", |
| 1933 | + " ('Tue', 100): {300: 1},\n", |
| 1934 | + " ('Tue', 200): {200: 1},\n", |
| 1935 | + " ('Tue', 300): {100: 1},\n", |
| 1936 | + " ('Wed', 0): {400: 1},\n", |
| 1937 | + " ('Wed', 100): {300: 1},\n", |
| 1938 | + " ('Wed', 200): {200: 1},\n", |
| 1939 | + " ('Wed', 300): {100: 1},\n", |
| 1940 | + " ('Thu', 0): {300: 1},\n", |
| 1941 | + " ('Thu', 100): {200: 1},\n", |
| 1942 | + " ('Thu', 200): {100: 1},\n", |
| 1943 | + " ('Thu', 300): {0: 1},\n", |
| 1944 | + " ('Fri', 0): {200: 1},\n", |
| 1945 | + " ('Fri', 100): {100: 1},\n", |
| 1946 | + " ('Fri', 200): {0: 1},\n", |
| 1947 | + " ('Fri', 300): {0: 1},\n", |
| 1948 | + " ('Weekend', 0): {0: 1},\n", |
| 1949 | + " ('Weekend', 100): {0: 1},\n", |
| 1950 | + " ('Weekend', 200): {0: 1},\n", |
| 1951 | + " ('Weekend', 300): {0: 1}},\n", |
| 1952 | + " {('Mon', 0): {0: 2225.749496682385,\n", |
| 1953 | + " 100: 2528.178263359892,\n", |
| 1954 | + " 200: 2752.245336408776,\n", |
| 1955 | + " 300: 2833.598086662411,\n", |
| 1956 | + " 400: 2865.5080973287336},\n", |
| 1957 | + " ('Tue', 0): {0: 1627.1674196319675,\n", |
| 1958 | + " 100: 1926.185189822399,\n", |
| 1959 | + " 200: 2130.63971600556,\n", |
| 1960 | + " 300: 2235.794644930646,\n", |
| 1961 | + " 400: 2202.700921685597},\n", |
| 1962 | + " ('Tue', 100): {0: 2323.642847941712,\n", |
| 1963 | + " 100: 2546.146008882256,\n", |
| 1964 | + " 200: 2622.2014944709003,\n", |
| 1965 | + " 300: 2704.7958165719538,\n", |
| 1966 | + " 400: 2254.9865435917945},\n", |
| 1967 | + " ('Tue', 200): {0: 2938.47212708529,\n", |
| 1968 | + " 100: 2985.6763069672907,\n", |
| 1969 | + " 200: 3045.55602709444,\n", |
| 1970 | + " 300: 2660.793750889116,\n", |
| 1971 | + " 400: 2244.116679273476},\n", |
| 1972 | + " ('Tue', 300): {0: 3397.5493842636856,\n", |
| 1973 | + " 100: 3431.1584693328227,\n", |
| 1974 | + " 200: 3047.963831661575,\n", |
| 1975 | + " 300: 2689.0905262507554,\n", |
| 1976 | + " 400: 2246.0842993310807},\n", |
| 1977 | + " ('Wed', 0): {0: 991.8337704155527,\n", |
| 1978 | + " 100: 1294.3979155570473,\n", |
| 1979 | + " 200: 1499.2384682910836,\n", |
| 1980 | + " 300: 1560.5737610953374,\n", |
| 1981 | + " 400: 1656.9354742311311},\n", |
| 1982 | + " ('Wed', 100): {0: 1693.4281528809047,\n", |
| 1983 | + " 100: 1890.5014859779849,\n", |
| 1984 | + " 200: 1967.5195056845337,\n", |
| 1985 | + " 300: 2030.2875396109434,\n", |
| 1986 | + " 400: 1624.7053132984788},\n", |
| 1987 | + " ('Wed', 200): {0: 2307.350730239611,\n", |
| 1988 | + " 100: 2368.121947542663,\n", |
| 1989 | + " 200: 2439.4451003135055,\n", |
| 1990 | + " 300: 2028.5567501077871,\n", |
| 1991 | + " 400: 1602.4490693607893},\n", |
| 1992 | + " ('Wed', 300): {0: 2766.5347359553866,\n", |
| 1993 | + " 100: 2817.6576462084945,\n", |
| 1994 | + " 200: 2397.5480427541106,\n", |
| 1995 | + " 300: 2028.9023931048234,\n", |
| 1996 | + " 400: 1610.7042344178608},\n", |
| 1997 | + " ('Thu', 0): {0: 388.45762020693445,\n", |
| 1998 | + " 100: 689.9741046142422,\n", |
| 1999 | + " 200: 886.8292737374425,\n", |
| 2000 | + " 300: 1008.8462346115972,\n", |
| 2001 | + " 400: 970.599703355806},\n", |
| 2002 | + " ('Thu', 100): {0: 1086.9408520148895,\n", |
| 2003 | + " 100: 1301.332777514599,\n", |
| 2004 | + " 200: 1405.1825937805977,\n", |
| 2005 | + " 300: 1348.6418726014172,\n", |
| 2006 | + " 400: 992.3726336890564},\n", |
| 2007 | + " ('Thu', 200): {0: 1715.4114166813265,\n", |
| 2008 | + " 100: 1833.5195722683234,\n", |
| 2009 | + " 200: 1741.2757203880324,\n", |
| 2010 | + " 300: 1376.7551643904483,\n", |
| 2011 | + " 400: 957.9607707339657},\n", |
| 2012 | + " ('Thu', 300): {0: 2190.19649451877,\n", |
| 2013 | + " 100: 2125.740810274669,\n", |
| 2014 | + " 200: 1776.8132567876999,\n", |
| 2015 | + " 300: 1408.5495730824664,\n", |
| 2016 | + " 400: 990.4018172404869},\n", |
| 2017 | + " ('Fri', 0): {0: 0.0,\n", |
| 2018 | + " 100: 299.99999999999716,\n", |
| 2019 | + " 200: 406.0476593550646,\n", |
| 2020 | + " 300: 170.46122765548887,\n", |
| 2021 | + " 400: -153.64846976857817},\n", |
| 2022 | + " ('Fri', 100): {0: 699.9999999999943,\n", |
| 2023 | + " 100: 842.0141106267022,\n", |
| 2024 | + " 200: 610.5115569281422,\n", |
| 2025 | + " 300: 292.160669622827,\n", |
| 2026 | + " 400: -113.12406224669776},\n", |
| 2027 | + " ('Fri', 200): {0: 1172.1819744094662,\n", |
| 2028 | + " 100: 1070.907334160906,\n", |
| 2029 | + " 200: 687.7773470555264,\n", |
| 2030 | + " 300: 330.44001007014674,\n", |
| 2031 | + " 400: -75.1010831966216},\n", |
| 2032 | + " ('Fri', 300): {0: 1427.1955441761681,\n", |
| 2033 | + " 100: 1007.1503466766485,\n", |
| 2034 | + " 200: 674.4671172275836,\n", |
| 2035 | + " 300: 278.1467797475504,\n", |
| 2036 | + " 400: -99.78074377598806},\n", |
| 2037 | + " ('Weekend', 0): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 2038 | + " ('Weekend', 100): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 2039 | + " ('Weekend', 200): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0},\n", |
| 2040 | + " ('Weekend', 300): {0: 0, 100: 0, 200: 0, 300: 0, 400: 0}})" |
| 2041 | + ] |
| 2042 | + }, |
| 2043 | + "execution_count": 124, |
| 2044 | + "metadata": {}, |
| 2045 | + "output_type": "execute_result" |
| 2046 | + } |
| 2047 | + ], |
| 2048 | + "source": [ |
| 2049 | + "q_learning(foodtruck, 1, 0.1, 0.01, 20000000)\n" |
| 2050 | + ] |
| 2051 | + }, |
| 2052 | + { |
| 2053 | + "cell_type": "code", |
| 2054 | + "execution_count": null, |
| 2055 | + "metadata": {}, |
| 2056 | + "outputs": [], |
| 2057 | + "source": [] |
| 2058 | + }, |
| 2059 | + { |
| 2060 | + "cell_type": "code", |
| 2061 | + "execution_count": null, |
| 2062 | + "metadata": {}, |
| 2063 | + "outputs": [], |
| 2064 | + "source": [ |
| 2065 | + "Q" |
| 2066 | + ] |
| 2067 | + }, |
| 2068 | + { |
| 2069 | + "cell_type": "markdown", |
| 2070 | + "metadata": {}, |
| 2071 | + "source": [ |
| 2072 | + "{('Mon', 0): 2880.0,\n", |
| 2073 | + " ('Tue', 0): 2250.0,\n", |
| 2074 | + " ('Tue', 100): 2650.0,\n", |
| 2075 | + " ('Tue', 200): 3050.0,\n", |
| 2076 | + " ('Tue', 300): 3450.0,\n", |
| 2077 | + " ('Wed', 0): 1620.0,\n", |
| 2078 | + " ('Wed', 100): 2020.0,\n", |
| 2079 | + " ('Wed', 200): 2420.0,\n", |
| 2080 | + " ('Wed', 300): 2820.0,\n", |
| 2081 | + " ('Thu', 0): 990.0,\n", |
| 2082 | + " ('Thu', 100): 1390.0,\n", |
| 2083 | + " ('Thu', 200): 1790.0,\n", |
| 2084 | + " ('Thu', 300): 2190.0,\n", |
| 2085 | + " ('Fri', 0): 390.00000000000006,\n", |
| 2086 | + " ('Fri', 100): 790.0000000000001,\n", |
| 2087 | + " ('Fri', 200): 1190.0,\n", |
| 2088 | + " ('Fri', 300): 1400.0,\n", |
| 2089 | + " ('Weekend', 0): 0,\n", |
| 2090 | + " ('Weekend', 100): 0,\n", |
| 2091 | + " ('Weekend', 200): 0,\n", |
| 2092 | + " ('Weekend', 300): 0}" |
| 2093 | + ] |
| 2094 | + } |
| 2095 | + ], |
| 2096 | + "metadata": { |
| 2097 | + "kernelspec": { |
| 2098 | + "display_name": "py37ml", |
| 2099 | + "language": "python", |
| 2100 | + "name": "py37ml" |
| 2101 | + }, |
| 2102 | + "language_info": { |
| 2103 | + "codemirror_mode": { |
| 2104 | + "name": "ipython", |
| 2105 | + "version": 3 |
| 2106 | + }, |
| 2107 | + "file_extension": ".py", |
| 2108 | + "mimetype": "text/x-python", |
| 2109 | + "name": "python", |
| 2110 | + "nbconvert_exporter": "python", |
| 2111 | + "pygments_lexer": "ipython3", |
| 2112 | + "version": "3.7.4" |
| 2113 | + } |
| 2114 | + }, |
| 2115 | + "nbformat": 4, |
| 2116 | + "nbformat_minor": 2 |
| 2117 | +} |
0 commit comments