Skip to content

Commit

Permalink
Early stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
devforfu committed Apr 24, 2019
1 parent 54d145a commit 6a9b530
Show file tree
Hide file tree
Showing 12 changed files with 611 additions and 85 deletions.
24 changes: 13 additions & 11 deletions dev/01a_callbacks.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
" Loss = 10\n",
" Metric = 100\n",
" Schedule = 200\n",
" History = 300\n",
" Tracker = 300\n",
" Logging = 1000\n",
" \n",
" def __call__(self, index=0):\n",
Expand Down Expand Up @@ -116,7 +116,7 @@
"class History(Callback):\n",
" \"\"\"A callback that collects model's metrics during its training.\"\"\"\n",
" \n",
" order = Order.History()\n",
" order = Order.Tracker()\n",
" \n",
" def training_started(self, **kwargs):\n",
" self.recorded = None\n",
Expand Down Expand Up @@ -153,9 +153,9 @@
" \"\"\"\n",
" order = Order.Metric()\n",
" \n",
" def __init__(self, metric_fn: 'callable', name: str=None):\n",
" def __init__(self, metric_fn: 'callable', alias: str=None):\n",
" self.metric_fn = metric_fn\n",
" self.name = name or self.metric_fn.__name__\n",
" self.name = alias or self.metric_fn.__name__\n",
" \n",
" def epoch_started(self, **kwargs):\n",
" self.values = defaultdict(int)\n",
Expand Down Expand Up @@ -197,9 +197,14 @@
" metrics = merge_dicts([p.last_metrics for p in phases])\n",
" values = [f'{k}={autoformat(v)}' for k, v in metrics.items()]\n",
" values_string = ', '.join(values)\n",
" string = f'Epoch: {epoch:4d} | {values_string}\\n'\n",
" self.write(f'Epoch: {epoch:4d} | {values_string}\\n')\n",
" \n",
" def interrupted(self, exc, **kwargs):\n",
" self.write(exc)\n",
" \n",
" def write(self, msg):\n",
" for stream in self.streams:\n",
" stream.write(string)\n",
" stream.write(msg)\n",
" stream.flush()"
]
},
Expand All @@ -216,9 +221,9 @@
" Each observer has a backward reference to its group via 'group' attribute. The group\n",
" keeps a reference to the model which can be used by the \n",
" \"\"\"\n",
" def __init__(self, cbs):\n",
" def __init__(self, cbs, model=None):\n",
" self._init(cbs)\n",
" self._model = None\n",
" self.model = model\n",
" \n",
" def _init(self, cbs):\n",
" if not cbs:\n",
Expand All @@ -234,9 +239,6 @@
" def add(self, cb, *cbs):\n",
" cbs = [cb] + list(cbs)\n",
" self._init(cbs)\n",
" \n",
" def set_model(self, model):\n",
" self._model = model\n",
"\n",
" def training_started(self, **kwargs): self('training_started', **kwargs) \n",
" def training_ended(self, **kwargs): self('training_ended', **kwargs)\n",
Expand Down
3 changes: 3 additions & 0 deletions dev/02b_phase.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@
" for name, values in self.metrics.items():\n",
" metrics[f'{self.name}_{name}'] = values\n",
" return metrics\n",
" \n",
" def get_last_value(self, metric):\n",
" return self.last_metrics[f'{self.name}_{metric}']\n",
"\n",
" def update(self, loss: float):\n",
" self.losses.append(loss)\n",
Expand Down
8 changes: 5 additions & 3 deletions dev/02c_training.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
" model.to(device)\n",
" opt = opt_fn(model, **(opt_params or {}))\n",
" cb = create_callbacks(cbs, default_cb)\n",
" cb.set_model(model)\n",
" cb.model = model\n",
" \n",
" self.model = model\n",
" self.opt = opt\n",
Expand All @@ -94,7 +94,7 @@
" self.train_one_epoch(phases, epoch)\n",
" self.cb.training_ended(phases=phases)\n",
" except TrainingInterrupted as e:\n",
" self.cb.interrupted(reason=e)\n",
" self.cb.interrupted(exc=e)\n",
" \n",
" def train_one_epoch(self, phases: list, curr_epoch: int=1):\n",
" cb, model, opt = self.cb, self.model, self.opt\n",
Expand Down Expand Up @@ -141,7 +141,9 @@
"#export\n",
"class TrainingInterrupted(Exception):\n",
" def __init__(self, context=None):\n",
" self.context = context"
" self.context = context\n",
" def __str__(self):\n",
" return str(self.context)"
]
},
{
Expand Down
75 changes: 22 additions & 53 deletions dev/03a_schedule.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 99,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": 100,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -23,13 +23,13 @@
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"from loop.callbacks import Callback\n",
"from loop.callbacks import Callback, Order\n",
"from loop.utils import calculate_layout"
]
},
{
"cell_type": "code",
"execution_count": 101,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -39,7 +39,7 @@
},
{
"cell_type": "code",
"execution_count": 102,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 103,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -62,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": 104,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -89,7 +89,7 @@
},
{
"cell_type": "code",
"execution_count": 105,
"execution_count": 7,
"metadata": {},
"outputs": [
{
Expand All @@ -111,7 +111,7 @@
},
{
"cell_type": "code",
"execution_count": 106,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -144,7 +144,7 @@
},
{
"cell_type": "code",
"execution_count": 107,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -159,7 +159,7 @@
},
{
"cell_type": "code",
"execution_count": 108,
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -169,7 +169,7 @@
},
{
"cell_type": "code",
"execution_count": 109,
"execution_count": 11,
"metadata": {},
"outputs": [
{
Expand All @@ -191,7 +191,7 @@
},
{
"cell_type": "code",
"execution_count": 110,
"execution_count": 12,
"metadata": {},
"outputs": [
{
Expand All @@ -213,14 +213,16 @@
},
{
"cell_type": "code",
"execution_count": 116,
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"#export\n",
"class Scheduler(Callback):\n",
" \"\"\"Updates optimizer's learning rates using provided scheduling function.\"\"\"\n",
" \n",
" order = Order.Schedule()\n",
" \n",
" def __init__(self, opt, schedule: 'callable', params: list=None, mode: str='batch'):\n",
" assert mode in {'batch', 'epoch'}\n",
" params = _make_sched_params(params)\n",
Expand All @@ -229,14 +231,14 @@
" self.params = params\n",
" self.n_steps = 0\n",
" \n",
" def training_started(self, **params):\n",
" def training_started(self, **kwargs):\n",
" self.history = defaultdict(list)\n",
" \n",
" def epoch_started(self, epoch, **params):\n",
" def epoch_started(self, epoch, **kwargs):\n",
" if self.mode == 'epoch': \n",
" self.step(epoch)\n",
" \n",
" def batch_started(self, phase, **params):\n",
" def batch_started(self, phase, **kwargs):\n",
" if self.mode == 'batch': \n",
" if phase.grad:\n",
" self.step(phase.batch_index)\n",
Expand Down Expand Up @@ -276,7 +278,7 @@
},
{
"cell_type": "code",
"execution_count": 117,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -297,7 +299,7 @@
},
{
"cell_type": "code",
"execution_count": 118,
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -312,7 +314,7 @@
},
{
"cell_type": "code",
"execution_count": 119,
"execution_count": 16,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -347,39 +349,6 @@
" \n",
"sched.plot(['lr', 'weight_decay']);"
]
},
{
"cell_type": "code",
"execution_count": 98,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exported: /home/ck/code/loop/dev/00a_annotations.ipynb -> loop/annotations.py\r\n",
"Exported: /home/ck/code/loop/dev/00b_config.ipynb -> loop/config.py\r\n",
"Exported: /home/ck/code/loop/dev/00c_utils.ipynb -> loop/utils.py\r\n",
"Exported: /home/ck/code/loop/dev/01a_callbacks.ipynb -> loop/callbacks.py\r\n",
"Exported: /home/ck/code/loop/dev/01b_modules.ipynb -> loop/modules.py\r\n",
"Exported: /home/ck/code/loop/dev/02a_metrics.ipynb -> loop/metrics.py\r\n",
"Exported: /home/ck/code/loop/dev/02b_phase.ipynb -> loop/phase.py\r\n",
"Exported: /home/ck/code/loop/dev/02c_training.ipynb -> loop/training.py\r\n",
"Exported: /home/ck/code/loop/dev/03a_schedule.ipynb -> loop/schedule.py\r\n",
"9 notebook(s) exported into folder: loop\r\n"
]
}
],
"source": [
"!python export.py -o loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 6a9b530

Please sign in to comment.