diff --git a/dev/01a_callbacks.ipynb b/dev/01a_callbacks.ipynb
index 9b8e7c4..f2dc0a3 100644
--- a/dev/01a_callbacks.ipynb
+++ b/dev/01a_callbacks.ipynb
@@ -34,7 +34,7 @@
     "    Loss = 10\n",
     "    Metric = 100\n",
     "    Schedule = 200\n",
-    "    History = 300\n",
+    "    Tracker = 300\n",
     "    Logging = 1000\n",
     "    \n",
     "    def __call__(self, index=0):\n",
@@ -116,7 +116,7 @@
     "class History(Callback):\n",
     "    \"\"\"A callback that collects model's metrics during its training.\"\"\"\n",
     "    \n",
-    "    order = Order.History()\n",
+    "    order = Order.Tracker()\n",
     "    \n",
     "    def training_started(self, **kwargs):\n",
     "        self.recorded = None\n",
@@ -153,9 +153,9 @@
     "    \"\"\"\n",
     "    order = Order.Metric()\n",
     "    \n",
-    "    def __init__(self, metric_fn: 'callable', name: str=None):\n",
+    "    def __init__(self, metric_fn: 'callable', alias: str=None):\n",
     "        self.metric_fn = metric_fn\n",
-    "        self.name = name or self.metric_fn.__name__\n",
+    "        self.name = alias or self.metric_fn.__name__\n",
     "    \n",
     "    def epoch_started(self, **kwargs):\n",
     "        self.values = defaultdict(int)\n",
@@ -197,9 +197,14 @@
     "        metrics = merge_dicts([p.last_metrics for p in phases])\n",
     "        values = [f'{k}={autoformat(v)}' for k, v in metrics.items()]\n",
     "        values_string = ', '.join(values)\n",
-    "        string = f'Epoch: {epoch:4d} | {values_string}\\n'\n",
+    "        self.write(f'Epoch: {epoch:4d} | {values_string}\\n')\n",
+    "    \n",
+    "    def interrupted(self, exc, **kwargs):\n",
+    "        self.write(f'{exc}\\n')\n",
+    "    \n",
+    "    def write(self, msg):\n",
     "        for stream in self.streams:\n",
-    "            stream.write(string)\n",
+    "            stream.write(msg)\n",
     "            stream.flush()"
    ]
   },
@@ -216,9 +221,9 @@
     "    Each observer has a backward reference to its group via 'group' attribute. The group\n",
     "    keeps a reference to the model which can be used by the \n",
     "    \"\"\"\n",
-    "    def __init__(self, cbs):\n",
+    "    def __init__(self, cbs, model=None):\n",
     "        self._init(cbs)\n",
-    "        self._model = None\n",
+    "        self.model = model\n",
     "    \n",
     "    def _init(self, cbs):\n",
     "        if not cbs:\n",
@@ -234,9 +239,6 @@
     "    def add(self, cb, *cbs):\n",
     "        cbs = [cb] + list(cbs)\n",
     "        self._init(cbs)\n",
-    "    \n",
-    "    def set_model(self, model):\n",
-    "        self._model = model\n",
     "\n",
     "    def training_started(self, **kwargs): self('training_started', **kwargs) \n",
     "    def training_ended(self, **kwargs): self('training_ended', **kwargs)\n",
diff --git a/dev/02b_phase.ipynb b/dev/02b_phase.ipynb
index 99ba76c..a227fb9 100644
--- a/dev/02b_phase.ipynb
+++ b/dev/02b_phase.ipynb
@@ -105,6 +105,9 @@
     "        for name, values in self.metrics.items():\n",
     "            metrics[f'{self.name}_{name}'] = values\n",
     "        return metrics\n",
+    "    \n",
+    "    def get_last_value(self, metric):\n",
+    "        return self.last_metrics[f'{self.name}_{metric}']\n",
     "\n",
     "    def update(self, loss: float):\n",
     "        self.losses.append(loss)\n",
diff --git a/dev/02c_training.ipynb b/dev/02c_training.ipynb
index bad0a04..5780521 100644
--- a/dev/02c_training.ipynb
+++ b/dev/02c_training.ipynb
@@ -74,7 +74,7 @@
     "        model.to(device)\n",
     "        opt = opt_fn(model, **(opt_params or {}))\n",
     "        cb = create_callbacks(cbs, default_cb)\n",
-    "        cb.set_model(model)\n",
+    "        cb.model = model\n",
     "        \n",
     "        self.model = model\n",
     "        self.opt = opt\n",
@@ -94,7 +94,7 @@
     "                self.train_one_epoch(phases, epoch)\n",
     "            self.cb.training_ended(phases=phases)\n",
     "        except TrainingInterrupted as e:\n",
-    "            self.cb.interrupted(reason=e)\n",
+    "            self.cb.interrupted(exc=e)\n",
     "    \n",
    "    def train_one_epoch(self, phases: list, curr_epoch: int=1):\n",
     "        cb, model, opt = self.cb, self.model, self.opt\n",
@@ -141,7 +141,9 @@
    "#export\n",
    "class 
TrainingInterrupted(Exception):\n", " def __init__(self, context=None):\n", - " self.context = context" + " self.context = context\n", + " def __str__(self):\n", + " return str(self.context)" ] }, { diff --git a/dev/03a_schedule.ipynb b/dev/03a_schedule.ipynb index 0cd19fb..212286d 100644 --- a/dev/03a_schedule.ipynb +++ b/dev/03a_schedule.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 99, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -13,7 +13,7 @@ }, { "cell_type": "code", - "execution_count": 100, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -23,13 +23,13 @@ "import numpy as np\n", "import pandas as pd\n", "\n", - "from loop.callbacks import Callback\n", + "from loop.callbacks import Callback, Order\n", "from loop.utils import calculate_layout" ] }, { "cell_type": "code", - "execution_count": 101, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 102, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ @@ -48,7 +48,7 @@ }, { "cell_type": "code", - "execution_count": 103, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -62,7 +62,7 @@ }, { "cell_type": "code", - "execution_count": 104, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 105, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -111,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -144,7 +144,7 @@ }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -159,7 +159,7 @@ }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -169,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 109, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -191,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 110, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -213,7 +213,7 @@ }, { "cell_type": "code", - "execution_count": 116, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -221,6 +221,8 @@ "class Scheduler(Callback):\n", " \"\"\"Updates optimizer's learning rates using provided scheduling function.\"\"\"\n", " \n", + " order = Order.Schedule()\n", + " \n", " def __init__(self, opt, schedule: 'callable', params: list=None, mode: str='batch'):\n", " assert mode in {'batch', 'epoch'}\n", " params = _make_sched_params(params)\n", @@ -229,14 +231,14 @@ " self.params = params\n", " self.n_steps = 0\n", " \n", - " def training_started(self, **params):\n", + " def training_started(self, **kwargs):\n", " self.history = defaultdict(list)\n", " \n", - " def epoch_started(self, epoch, **params):\n", + " def epoch_started(self, epoch, **kwargs):\n", " if self.mode == 'epoch': \n", " self.step(epoch)\n", " \n", - " def batch_started(self, phase, **params):\n", + " def batch_started(self, phase, **kwargs):\n", " if self.mode == 'batch': \n", " if phase.grad:\n", " self.step(phase.batch_index)\n", @@ -276,7 +278,7 @@ }, { "cell_type": "code", - "execution_count": 117, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -297,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 118, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -312,7 +314,7 @@ }, { 
"cell_type": "code", - "execution_count": 119, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -347,39 +349,6 @@ " \n", "sched.plot(['lr', 'weight_decay']);" ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exported: /home/ck/code/loop/dev/00a_annotations.ipynb -> loop/annotations.py\r\n", - "Exported: /home/ck/code/loop/dev/00b_config.ipynb -> loop/config.py\r\n", - "Exported: /home/ck/code/loop/dev/00c_utils.ipynb -> loop/utils.py\r\n", - "Exported: /home/ck/code/loop/dev/01a_callbacks.ipynb -> loop/callbacks.py\r\n", - "Exported: /home/ck/code/loop/dev/01b_modules.ipynb -> loop/modules.py\r\n", - "Exported: /home/ck/code/loop/dev/02a_metrics.ipynb -> loop/metrics.py\r\n", - "Exported: /home/ck/code/loop/dev/02b_phase.ipynb -> loop/phase.py\r\n", - "Exported: /home/ck/code/loop/dev/02c_training.ipynb -> loop/training.py\r\n", - "Exported: /home/ck/code/loop/dev/03a_schedule.ipynb -> loop/schedule.py\r\n", - "9 notebook(s) exported into folder: loop\r\n" - ] - } - ], - "source": [ - "!python export.py -o loop" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/dev/03b_early_stopping.ipynb b/dev/03b_early_stopping.ipynb new file mode 100644 index 0000000..258f371 --- /dev/null +++ b/dev/03b_early_stopping.ipynb @@ -0,0 +1,288 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "%reload_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "from pathlib import Path\n", + "\n", + "import torch\n", + "\n", + "from loop.callbacks import Callback, Order\n", + "from loop.utils import autoformat" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "class BestMetric(Callback):\n", + " \"\"\"A callback that memorizes the best value of metric.\n", + " \n", + " The class is intended to be a base class for other types of metric trackers that\n", + " perform some action when metric stops to improve.\n", + " \"\"\"\n", + "\n", + " def __init__(self, phase: str='valid', metric: str='loss', better: 'callable'=min):\n", + " self.phase = phase\n", + " self.metric = metric\n", + " self.better = better\n", + " \n", + " @property\n", + " def formatted_best(self):\n", + " return f'{self.phase}_{self.metric}={autoformat(self.best_value)}'\n", + " \n", + " def training_started(self, **kwargs):\n", + " self.best_value = None\n", + " \n", + " def epoch_started(self, **kwargs):\n", + " self.updated = False\n", + " \n", + " def phase_ended(self, phase, **kwargs):\n", + " ignore = phase.name != self.phase\n", + " if not ignore:\n", + " self.update_best(phase, **kwargs)\n", + " return ignore\n", + " \n", + " def update_best(self, phase, **kwargs):\n", + " new_value = phase.get_last_value(self.metric)\n", + " if self.best_value is None:\n", + " self.best_value = new_value\n", + " else:\n", + " self.best_value = self.better(self.best_value, new_value)\n", + " self.updated = self.best_value == new_value\n", + " return self.updated" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "class EarlyStopping(BestMetric):\n", + " \n", + " order = Order.Tracker(1)\n", + 
" \n", + " def __init__(self, patience: int=1, **kwargs):\n", + " super().__init__(**kwargs)\n", + " self.patience = patience\n", + " \n", + " def training_started(self, **kwargs):\n", + " super().training_started(**kwargs)\n", + " self.trials = 0\n", + " self.running = True\n", + " \n", + " def phase_ended(self, phase, **kwargs):\n", + " ignore = super().phase_ended(phase=phase, **kwargs)\n", + " if ignore: return\n", + " if self.updated: \n", + " self.trials = 0\n", + " else:\n", + " self.trials += 1\n", + " if self.trials >= self.patience:\n", + " breakpoint()\n", + " self.running = False\n", + " \n", + " def epoch_ended(self, phases, epoch, **kwargs):\n", + " super().epoch_ended(phases=phases, epoch=epoch, **kwargs)\n", + " if self.running: return\n", + " from loop.training import TrainingInterrupted\n", + " msg = f'Early stopping at epoch {epoch} with {self.formatted_best}'\n", + " raise TrainingInterrupted(msg)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "#export\n", + "class ModelSaver(BestMetric):\n", + " \n", + " order = Order.Tracker(2)\n", + " \n", + " def __init__(self, mode: str='every', root: Path=Path.cwd(), **kwargs):\n", + " super().__init__(**kwargs)\n", + " assert mode in {'every', 'best'}\n", + " self.root = Path(root)\n", + " self.mode = mode\n", + " \n", + " def training_started(self, **kwargs):\n", + " super().training_started(**kwargs)\n", + " if not self.root.exists():\n", + " self.root.mkdir(parents=True)\n", + " self.last_saved = None\n", + " \n", + " def epoch_ended(self, phases, epoch, **kwargs):\n", + " super().epoch_ended(phases=phases, epoch=epoch, **kwargs)\n", + " fname = f'model__{self.formatted_best}__epoch={epoch}.pth'\n", + " if self.mode == 'every' or self.updated:\n", + " path = self.root/fname\n", + " torch.save(self.group.model, path)\n", + " self.last_saved = path\n", + " \n", + " def load_last_saved_state(self, model=None):\n", + " if self.last_saved is None:\n", + " raise ValueError('nothing was saved during training')\n", + " model = model or self.group.model\n", + " if model is None:\n", + " raise ValueError('no model provided to restore the saved state')\n", + " model.load_state_dict(torch.load(self.last_saved))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Exception ignored in: \n", + "Traceback (most recent call last):\n", + " File \"/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py\", line 717, in __del__\n", + " self._shutdown_workers()\n", + " File \"/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py\", line 713, in _shutdown_workers\n", + " w.join()\n", + " File \"/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py\", line 138, in join\n", + " assert self._parent_pid == os.getpid(), 'can only join a child process'\n", + "AssertionError: can only join a child process\n", + "Exception ignored in: \n", + "Traceback (most recent call last):\n", + " File \"/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py\", line 717, in __del__\n", + " self._shutdown_workers()\n", + " File \"/home/ck/anaconda3/envs/fastai_10/lib/python3.7/site-packages/torch/utils/data/dataloader.py\", line 713, in _shutdown_workers\n", + " w.join()\n", + " File \"/home/ck/anaconda3/envs/fastai_10/lib/python3.7/multiprocessing/process.py\", line 
138, in join\n", + " assert self._parent_pid == os.getpid(), 'can only join a child process'\n", + "AssertionError: can only join a child process\n" + ] + }, + { + "ename": "KeyError", + "evalue": "'valid_acc'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 6\u001b[0m ModelSaver(mode='best', root=Path.home()/'models')]\n\u001b[1;32m 7\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m \u001b[0mtrain_classifier_with_callbacks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mTinyNet\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcbs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcbs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m10000\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mloop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'history'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mplot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/testing.py\u001b[0m in \u001b[0;36mtrain_classifier_with_callbacks\u001b[0;34m(model, cbs, n, bs)\u001b[0m\n\u001b[1;32m 38\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain_classifier_with_callbacks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcbs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1024\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0mloop\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mLoop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcbs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcbs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcross_entropy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m \u001b[0mloop\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit_datasets\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mget_mnist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mn\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 41\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloop\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/training.py\u001b[0m in \u001b[0;36mfit_datasets\u001b[0;34m(self, trn_ds, val_ds, epochs, batch_size)\u001b[0m\n\u001b[1;32m 50\u001b[0m phases = Phase.make_train_valid(\n\u001b[1;32m 51\u001b[0m trn_ds, val_ds, bs=batch_size, num_workers=defaults.n_jobs)\n\u001b[0;32m---> 52\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 53\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphases\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/training.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, phases, epochs)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_started\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepochs\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_one_epoch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtraining_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mTrainingInterrupted\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/training.py\u001b[0m in \u001b[0;36mtrain_one_epoch\u001b[0;34m(self, phases, curr_epoch)\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 93\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 94\u001b[0;31m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mphase_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 95\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 96\u001b[0m \u001b[0mcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mepoch_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphases\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mepoch\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcurr_epoch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/callbacks.py\u001b[0m in \u001b[0;36mphase_ended\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 193\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mepoch_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'epoch_ended'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 194\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mphase_started\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'phase_started'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 195\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0mphase_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'phase_ended'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 196\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbatch_started\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'batch_started'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 197\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbatch_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'batch_ended'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/callbacks.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, name, **kwargs)\u001b[0m\n\u001b[1;32m 222\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmethod\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 223\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 224\u001b[0;31m \u001b[0mmethod\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m\u001b[0m in \u001b[0;36mphase_ended\u001b[0;34m(self, phase, **kwargs)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 15\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mphase_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mphase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m \u001b[0mignore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mphase_ended\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 17\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mignore\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 18\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdated\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mphase_ended\u001b[0;34m(self, phase, **kwargs)\u001b[0m\n\u001b[1;32m 25\u001b[0m \u001b[0mignore\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 26\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mignore\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mupdate_best\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mphase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 28\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mignore\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m\u001b[0m in \u001b[0;36mupdate_best\u001b[0;34m(self, phase, **kwargs)\u001b[0m\n\u001b[1;32m 29\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 30\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mupdate_best\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 31\u001b[0;31m \u001b[0mnew_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mphase\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_last_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 32\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_value\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbest_value\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnew_value\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/code/loop/dev/loop/phase.py\u001b[0m in \u001b[0;36mget_last_value\u001b[0;34m(self, metric)\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0mget_last_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 82\u001b[0;31m        \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlast_metrics\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34mf'{self.name}_{metric}'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     83\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     84\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mupdate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mfloat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+      "\u001b[0;31mKeyError\u001b[0m: 'valid_acc'"
+     ]
+    }
+   ],
+   "source": [
+    "from loop import callbacks as C\n",
+    "from loop.metrics import accuracy\n",
+    "from loop.modules import TinyNet\n",
+    "from loop.testing import train_classifier_with_callbacks\n",
+    "\n",
+    "cbs = [C.Average(accuracy, alias='acc'), \n",
+    "       EarlyStopping(metric='acc', better=max, patience=1),\n",
+    "       ModelSaver(mode='best', metric='acc', better=max, root=Path.home()/'models')]\n",
+    "\n",
+    "loop = train_classifier_with_callbacks(TinyNet(1), cbs=cbs, n=10000, bs=100)\n",
+    "\n",
+    "loop.cb['history'].plot()\n",
+    "cbs[-1].load_last_saved_state()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Exported: /home/ck/code/loop/dev/00a_annotations.ipynb -> loop/annotations.py\r\n",
+      "Exported: /home/ck/code/loop/dev/00b_config.ipynb -> loop/config.py\r\n",
+      "Exported: /home/ck/code/loop/dev/00c_utils.ipynb -> loop/utils.py\r\n",
+      "Exported: /home/ck/code/loop/dev/01a_callbacks.ipynb -> loop/callbacks.py\r\n",
+      "Exported: /home/ck/code/loop/dev/01b_modules.ipynb -> loop/modules.py\r\n",
+      "Exported: /home/ck/code/loop/dev/02a_metrics.ipynb -> loop/metrics.py\r\n",
+      "Exported: /home/ck/code/loop/dev/02b_phase.ipynb -> loop/phase.py\r\n",
+      "Exported: /home/ck/code/loop/dev/02c_training.ipynb -> loop/training.py\r\n",
+      "Exported: /home/ck/code/loop/dev/03a_schedule.ipynb -> loop/schedule.py\r\n",
+      "Exported: /home/ck/code/loop/dev/03b_early_stopping.ipynb -> loop/early_stopping.py\r\n",
+      "Exported: /home/ck/code/loop/dev/99_testing.ipynb -> loop/testing.py\r\n",
+      "11 notebook(s) exported into folder: loop\r\n"
+     ]
+    }
+   ],
+   "source": [
+    "!python export.py -o loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "fastai (cuda 10)",
+   "language": "python",
+   "name": "fastai"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.1"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dev/99_testing.ipynb b/dev/99_testing.ipynb
new file mode 100644
index 0000000..5c0ed17
--- /dev/null
+++ b/dev/99_testing.ipynb
@@ -0,0 +1,97 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#export\n",
+    "from torch.nn import functional as F\n",
+    "\n",
+    "from loop import callbacks as C\n",
+    "from loop.config import defaults\n",
+    "from loop.metrics import accuracy\n",
+    "from loop.modules import TinyNet\n",
+    "from loop.training import Loop"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#export\n",
+    "def get_mnist():\n",
+    "    from torchvision.datasets import MNIST\n",
+    "    from torchvision import transforms as T\n",
+    "    \n",
+    "    root = defaults.datasets/'mnist'\n",
+    "\n",
+    "    mnist_stats = ([0.15]*1, [0.15]*1)\n",
+    "\n",
+    "    trn_ds = MNIST(root, train=True, transform=T.Compose([\n",
+    "        T.Resize(32),\n",
+    "        T.RandomAffine(5, translate=(0.05, 0.05), scale=(0.9, 1.1)),\n",
+    "        T.ToTensor(),\n",
+    "        T.Normalize(*mnist_stats)\n",
+    "    ]))\n",
+    "    val_ds = MNIST(root, train=False, transform=T.Compose([\n",
+    "        T.Resize(32),\n",
+    "        T.ToTensor(),\n",
+    "        T.Normalize(*mnist_stats)\n",
+    "    ]))\n",
+    "    \n",
+    "    return trn_ds, val_ds"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#export\n",
+    "def train_classifier_with_callbacks(model, cbs, n, bs=1024):\n",
+    "    loop = Loop(model, cbs=cbs, loss_fn=F.cross_entropy)\n",
+    "    loop.fit_datasets(*get_mnist(), epochs=n, batch_size=bs)\n",
+    "    return loop"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "fastai (cuda 10)",
+   "language": "python",
+   "name": "fastai"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.1"
+  },
+  "toc": {
+   "base_numbering": 1,
+   "nav_menu": {},
+   "number_sections": true,
+   "sideBar": true,
+   "skip_h1_title": false,
+   "title_cell": "Table of Contents",
+   "title_sidebar": "Contents",
+   "toc_cell": false,
+   "toc_position": {},
+   "toc_section_display": true,
+   "toc_window_display": false
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/dev/loop/callbacks.py b/dev/loop/callbacks.py
index 964dc2f..d5c59cc 100644
--- a/dev/loop/callbacks.py
+++ b/dev/loop/callbacks.py
@@ -23,7 +23,7 @@ class Order(IntFlag):
     Loss = 10
     Metric = 100
     Schedule = 200
-    History = 300
+    Tracker = 300
     Logging = 1000
 
     def __call__(self, index=0):
@@ -84,7 +84,7 @@ def epoch_ended(self, phases, **kwargs):
 class History(Callback):
     """A callback that collects model's metrics during its training."""
 
-    order = Order.History()
+    order = Order.Tracker()
 
     def training_started(self, **kwargs):
         self.recorded = None
@@ -114,9 +114,9 @@ class Average(Callback):
     """
     order = Order.Metric()
 
-    def __init__(self, metric_fn: 'callable', name: str=None):
+    def __init__(self, metric_fn: 'callable', alias: str=None):
         self.metric_fn = metric_fn
-        self.name = name or self.metric_fn.__name__
+        self.name = alias or self.metric_fn.__name__
 
     def epoch_started(self, **kwargs):
         self.values = defaultdict(int)
@@ -151,9 +151,14 @@ def epoch_ended(self, phases, epoch, **kwargs):
         metrics = merge_dicts([p.last_metrics for p in phases])
         values = [f'{k}={autoformat(v)}' for k, v in metrics.items()]
         values_string = ', '.join(values)
-        string = f'Epoch: {epoch:4d} | {values_string}\n'
+        self.write(f'Epoch: {epoch:4d} | {values_string}\n')
+
+    def interrupted(self, exc, **kwargs):
+        self.write(f'{exc}\n')
+
+    def write(self, msg):
         for stream in self.streams:
-            stream.write(string)
+            stream.write(msg)
             stream.flush()
 
@@ -163,9 +168,9 @@ class Group(Callback):
     Each observer has a backward reference to its group via 'group' attribute. The group
     keeps a reference to the model which can be used by the
     """
-    def __init__(self, cbs):
+    def __init__(self, cbs, model=None):
         self._init(cbs)
-        self._model = None
+        self.model = model
 
     def _init(self, cbs):
         if not cbs:
@@ -182,9 +187,6 @@ def add(self, cb, *cbs):
         cbs = [cb] + list(cbs)
         self._init(cbs)
 
-    def set_model(self, model):
-        self._model = model
-
     def training_started(self, **kwargs): self('training_started', **kwargs)
     def training_ended(self, **kwargs): self('training_ended', **kwargs)
     def epoch_started(self, **kwargs): self('epoch_started', **kwargs)
diff --git a/dev/loop/early_stopping.py b/dev/loop/early_stopping.py
new file mode 100644
index 0000000..297b3aa
--- /dev/null
+++ b/dev/loop/early_stopping.py
@@ -0,0 +1,113 @@
+# -----------------------------------------
+# THIS FILE WAS AUTOGENERATED! DO NOT EDIT!
+# -----------------------------------------
+# file to edit: 03b_early_stopping.ipynb
+
+from pathlib import Path
+
+import torch
+
+from loop.callbacks import Callback, Order
+from loop.utils import autoformat
+
+
+class BestMetric(Callback):
+    """A callback that memorizes the best value of a metric.
+
+    The class is intended to be a base class for other types of metric trackers that
+    perform some action when the metric stops improving.
+    """
+
+    def __init__(self, phase: str='valid', metric: str='loss', better: 'callable'=min):
+        self.phase = phase
+        self.metric = metric
+        self.better = better
+
+    @property
+    def formatted_best(self):
+        return f'{self.phase}_{self.metric}={autoformat(self.best_value)}'
+
+    def training_started(self, **kwargs):
+        self.best_value = None
+
+    def epoch_started(self, **kwargs):
+        self.updated = False
+
+    def phase_ended(self, phase, **kwargs):
+        ignore = phase.name != self.phase
+        if not ignore:
+            self.update_best(phase, **kwargs)
+        return ignore
+
+    def update_best(self, phase, **kwargs):
+        new_value = phase.get_last_value(self.metric)
+        if self.best_value is None:
+            self.best_value = new_value
+        else:
+            self.best_value = self.better(self.best_value, new_value)
+        self.updated = self.best_value == new_value
+        return self.updated
+
+
+class EarlyStopping(BestMetric):
+
+    order = Order.Tracker(1)
+
+    def __init__(self, patience: int=1, **kwargs):
+        super().__init__(**kwargs)
+        self.patience = patience
+
+    def training_started(self, **kwargs):
+        super().training_started(**kwargs)
+        self.trials = 0
+        self.running = True
+
+    def phase_ended(self, phase, **kwargs):
+        ignore = super().phase_ended(phase=phase, **kwargs)
+        if ignore: return
+        if self.updated:
+            self.trials = 0
+        else:
+            self.trials += 1
+            if self.trials >= self.patience:
+                self.running = False
+
+    def epoch_ended(self, phases, epoch, **kwargs):
+        super().epoch_ended(phases=phases, epoch=epoch, **kwargs)
+        if self.running: return
+        from loop.training import TrainingInterrupted
+        msg = f'Early stopping at epoch {epoch} with {self.formatted_best}'
+        raise TrainingInterrupted(msg)
+
+
+class ModelSaver(BestMetric):
+
+    order = Order.Tracker(2)
+
+    def __init__(self, mode: str='every', root: Path=Path.cwd(), **kwargs):
+        super().__init__(**kwargs)
+        assert mode in {'every', 'best'}
+        self.root = Path(root)
+        self.mode = mode
+
+    def training_started(self, **kwargs):
+        super().training_started(**kwargs)
+        if not self.root.exists():
+            self.root.mkdir(parents=True)
+        self.last_saved = None
+
+    def epoch_ended(self, phases, epoch, **kwargs):
+        super().epoch_ended(phases=phases, epoch=epoch, **kwargs)
+        fname = f'model__{self.formatted_best}__epoch={epoch}.pth'
+        if self.mode == 'every' or self.updated:
+            path = self.root/fname
+            torch.save(self.group.model.state_dict(), path)
+            self.last_saved = path
+
+    def load_last_saved_state(self, model=None):
+        if self.last_saved is None:
+            raise ValueError('nothing was saved during training')
+        model = model or self.group.model
+        if model is None:
+            raise ValueError('no model provided to restore the saved state')
+        model.load_state_dict(torch.load(self.last_saved))
diff --git a/dev/loop/phase.py b/dev/loop/phase.py
index 71b3855..e2afab9 100644
--- a/dev/loop/phase.py
+++ b/dev/loop/phase.py
@@ -78,6 +78,9 @@ def metrics_history(self):
             metrics[f'{self.name}_{name}'] = values
         return metrics
 
+    def get_last_value(self, metric):
+        return self.last_metrics[f'{self.name}_{metric}']
+
     def update(self, loss: float):
         self.losses.append(loss)
diff --git a/dev/loop/schedule.py b/dev/loop/schedule.py
index 7839599..42bc560 100644
--- a/dev/loop/schedule.py
+++ b/dev/loop/schedule.py
@@ -8,7 +8,7 @@
 import numpy as np
 import pandas as pd
 
-from loop.callbacks import Callback
+from loop.callbacks import Callback, Order
 from loop.utils import calculate_layout
 
 
@@ -72,6 +72,8 @@ def __call__(self, t):
 class Scheduler(Callback):
     """Updates optimizer's learning rates using provided scheduling function."""
 
+    order = Order.Schedule()
+
     def __init__(self, opt, schedule: 'callable', params: list=None, mode: str='batch'):
         assert mode in {'batch', 'epoch'}
         params = _make_sched_params(params)
@@ -80,14 +82,14 @@ def __init__(self, opt, schedule: 'callable', params: list=None, mode: str='batc
         self.params = params
         self.n_steps = 0
 
-    def training_started(self, **params):
+    def training_started(self, **kwargs):
         self.history = defaultdict(list)
 
-    def epoch_started(self, epoch, **params):
+    def epoch_started(self, epoch, **kwargs):
         if self.mode == 'epoch': 
             self.step(epoch)
 
-    def batch_started(self, phase, **params):
+    def batch_started(self, phase, **kwargs):
         if self.mode == 'batch': 
             if phase.grad:
                 self.step(phase.batch_index)
@@ -121,7 +123,8 @@ def plot(self, params=None, axes=None):
         f, axes = plt.subplots(*calculate_layout(n))
     for i, param in enumerate(params):
         df.plot(x='iteration', y=param, ax=axes.flat[i])
-    return f
+    f.tight_layout()
+    return axes
 
 
 def _make_sched_params(params):
diff --git a/dev/loop/testing.py b/dev/loop/testing.py
new file mode 100644
index 0000000..24e5025
--- /dev/null
+++ b/dev/loop/testing.py
@@ -0,0 +1,41 @@
+# -----------------------------------------
+# THIS FILE WAS AUTOGENERATED! DO NOT EDIT!
+# ----------------------------------------- +# file to edit: 99_testing.ipynb + +from torch.nn import functional as F + +from loop import callbacks as C +from loop.config import defaults +from loop.metrics import accuracy +from loop.modules import TinyNet +from loop.training import Loop + + +def get_mnist(): + from torchvision.datasets import MNIST + from torchvision import transforms as T + + root = defaults.datasets/'mnist' + + mnist_stats = ([0.15]*1, [0.15]*1) + + trn_ds = MNIST(root, train=True, transform=T.Compose([ + T.Resize(32), + T.RandomAffine(5, translate=(0.05, 0.05), scale=(0.9, 1.1)), + T.ToTensor(), + T.Normalize(*mnist_stats) + ])) + val_ds = MNIST(root, train=False, transform=T.Compose([ + T.Resize(32), + T.ToTensor(), + T.Normalize(*mnist_stats) + ])) + + return trn_ds, val_ds + + +def train_classifier_with_callbacks(model, cbs, n, bs=1024): + loop = Loop(model, cbs=cbs, loss_fn=F.cross_entropy) + loop.fit_datasets(*get_mnist(), epochs=n, batch_size=bs) + return loop diff --git a/dev/loop/training.py b/dev/loop/training.py index 2cdff9f..f79de62 100644 --- a/dev/loop/training.py +++ b/dev/loop/training.py @@ -38,7 +38,7 @@ def __init__(self, model: nn.Module, cbs: list=None, model.to(device) opt = opt_fn(model, **(opt_params or {})) cb = create_callbacks(cbs, default_cb) - cb.set_model(model) + cb.model = model self.model = model self.opt = opt @@ -58,7 +58,7 @@ def train(self, phases: list, epochs: int=1): self.train_one_epoch(phases, epoch) self.cb.training_ended(phases=phases) except TrainingInterrupted as e: - self.cb.interrupted(reason=e) + self.cb.interrupted(exc=e) def train_one_epoch(self, phases: list, curr_epoch: int=1): cb, model, opt = self.cb, self.model, self.opt @@ -99,6 +99,8 @@ def train_one_epoch(self, phases: list, curr_epoch: int=1): class TrainingInterrupted(Exception): def __init__(self, context=None): self.context = context + def __str__(self): + return str(self.context) def place_and_unwrap(batch, device):
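
The History slot in Order is now a general Tracker band, and Order.__call__ takes an index so several trackers can interleave within that band. The body of __call__ sits outside this diff's context, so the sketch below assumes it offsets the enum value by index, which is what Order.Tracker(1) on EarlyStopping and Order.Tracker(2) on ModelSaver suggest:

from loop.callbacks import Order

# Assumed semantics: Order.X(index) == int(Order.X) + index, so metrics (100)
# run before schedulers (200) and trackers (300), and logging (1000) runs last.
assert Order.Metric() < Order.Schedule() < Order.Tracker() < Order.Logging()
# Within the Tracker band, History (300) records the epoch first, then
# EarlyStopping (301) decides whether to stop, then ModelSaver (302) saves.
assert Order.Tracker() < Order.Tracker(1) < Order.Tracker(2)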
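BestMetric treats a tie with the current best as an update (updated is best == new), so a plateau epoch that exactly matches the best does not count as a failed trial for EarlyStopping. A minimal sketch of the bookkeeping; FakePhase is a hypothetical stand-in exposing only the two members BestMetric touches (name and get_last_value):

from loop.early_stopping import BestMetric

class FakePhase:
    # Hypothetical stand-in for loop.phase.Phase.
    def __init__(self, name, value):
        self.name = name
        self.value = value
    def get_last_value(self, metric):
        return self.value

tracker = BestMetric(phase='valid', metric='loss', better=min)
tracker.training_started()
for loss in [0.9, 0.7, 0.8]:
    tracker.phase_ended(phase=FakePhase('valid', loss))
    print(tracker.best_value, tracker.updated)
# Prints 0.9 True, then 0.7 True, then 0.7 False: the last epoch did not improve.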
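TrainingInterrupted.__str__ is what lets Logging.interrupted print a readable reason: Loop.train catches the exception and hands it to the callback group, and the logger only needs str(exc). A quick round trip:

from loop.training import TrainingInterrupted

exc = TrainingInterrupted('Early stopping at epoch 7 with valid_loss=0.41')
assert str(exc) == 'Early stopping at epoch 7 with valid_loss=0.41'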
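Since ModelSaver writes state_dict() snapshots, load_last_saved_state can restore the best weights after an early stop. An end-to-end sketch built on the helpers in loop.testing; the patience, epoch budget, and save directory are illustrative, not part of the diff:

from pathlib import Path

from loop import callbacks as C
from loop.early_stopping import EarlyStopping, ModelSaver
from loop.metrics import accuracy
from loop.modules import TinyNet
from loop.testing import train_classifier_with_callbacks

saver = ModelSaver(mode='best', metric='acc', better=max, root=Path.home()/'models')
cbs = [C.Average(accuracy, alias='acc'),
       EarlyStopping(metric='acc', better=max, patience=3),
       saver]

# Trains until valid accuracy stops improving for 3 epochs; the interrupt is
# caught by Loop.train and forwarded to Logging via cb.interrupted(exc=...).
loop = train_classifier_with_callbacks(TinyNet(1), cbs=cbs, n=50, bs=100)
saver.load_last_saved_state()  # put the best snapshot back into the model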