-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
401 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"from pathlib import Path\n", | ||
"\n", | ||
"import torch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"class Config:\n", | ||
" @property\n", | ||
" def batch_size(self): return 4\n", | ||
" \n", | ||
" @property\n", | ||
" def bs(self): return self.batch_size\n", | ||
"\n", | ||
" @property\n", | ||
" def cpu(self): return torch.device('cpu')\n", | ||
"\n", | ||
" @property\n", | ||
" def gpu_index(self): return 0\n", | ||
" \n", | ||
" @property\n", | ||
" def gpu(self): return torch.device(f'cuda:{self.gpu_index}')\n", | ||
" \n", | ||
" @property\n", | ||
" def device(self): return self.gpu()\n", | ||
" \n", | ||
" @property\n", | ||
" def datasets(self): return Path.home()/'data'" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"defaults = Config()" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "fastai (cuda 10)", | ||
"language": "python", | ||
"name": "fastai" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.1" | ||
}, | ||
"toc": { | ||
"base_numbering": 1, | ||
"nav_menu": {}, | ||
"number_sections": true, | ||
"sideBar": true, | ||
"skip_h1_title": false, | ||
"title_cell": "Table of Contents", | ||
"title_sidebar": "Contents", | ||
"toc_cell": false, | ||
"toc_position": {}, | ||
"toc_section_display": true, | ||
"toc_window_display": false | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%reload_ext autoreload\n", | ||
"%autoreload 2" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"from collections import OrderedDict\n", | ||
"from typing import Union, Tuple" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"from loop.config import defaults" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 23, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "SyntaxError", | ||
"evalue": "invalid syntax (<ipython-input-23-3ae1ed1dfdf9>, line 61)", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;36m File \u001b[0;32m\"<ipython-input-23-3ae1ed1dfdf9>\"\u001b[0;36m, line \u001b[0;32m61\u001b[0m\n\u001b[0;31m ('valid', DataLoader(val_ds, bs, shuffle=False, num_workers=val), grad=False)]\u001b[0m\n\u001b[0m ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m invalid syntax\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"#export\n", | ||
"class Phase:\n", | ||
" \"\"\"\n", | ||
" Model training loop phase.\n", | ||
"\n", | ||
" Each model's training loop iteration could be separated into (at least) two\n", | ||
" phases: training and validation. The instances of this class track\n", | ||
" metrics and counters, related to the specific phase, and keep the reference\n", | ||
" to subset of data, used during phase.\n", | ||
" \"\"\"\n", | ||
" def __init__(self, name: str, loader: 'DataLoader', grad: bool=True):\n", | ||
" self.name = name\n", | ||
" self.loader = loader\n", | ||
" self.grad = grad\n", | ||
" self.batch_loss = None\n", | ||
" self.batch_index = 0\n", | ||
" self.rolling_loss = 0\n", | ||
" self.losses = []\n", | ||
" self.metrics = OrderedDict()\n", | ||
"\n", | ||
" @property\n", | ||
" def last_loss(self):\n", | ||
" return self.losses[-1] if self.losses else None\n", | ||
"\n", | ||
" @property\n", | ||
" def last_metrics(self):\n", | ||
" metrics = OrderedDict()\n", | ||
" metrics[f'{self.name}_loss'] = self.last_loss\n", | ||
" for name, values in self.metrics.items():\n", | ||
" metrics[f'{self.name}_{name}'] = values[-1]\n", | ||
" return metrics\n", | ||
"\n", | ||
" @property\n", | ||
" def metrics_history(self):\n", | ||
" metrics = OrderedDict()\n", | ||
" for name, values in self.metrics.items():\n", | ||
" metrics[f'{self.name}_{name}'] = values\n", | ||
" return metrics\n", | ||
"\n", | ||
" def update(self, loss):\n", | ||
" self.losses.append(loss)\n", | ||
"\n", | ||
" def update_metric(self, name, value):\n", | ||
" if name not in self.metrics:\n", | ||
" self.metrics[name] = []\n", | ||
" self.metrics[name].append(value)\n", | ||
" \n", | ||
" @staticmethod\n", | ||
" def make_train_valid(trn_ds, val_ds, \n", | ||
" bs: int=defaults.bs,\n", | ||
" num_workers: Union[Tuple, int]=0):\n", | ||
" \"\"\"Creates two loop's phases, train and valid.\n", | ||
" \n", | ||
" The phases are thin wrappers on top of data loaders intended to track\n", | ||
" additional information gathered during model's fitting process, like \n", | ||
" loss, performance metrics, etc.\n", | ||
" \"\"\"\n", | ||
" trn, val = unwrap(num_workers, 2)\n", | ||
" defs = [\n", | ||
" ('train', DataLoader(trn_ds, bs, shuffle=True, num_workers=trn)),\n", | ||
" ('valid', DataLoader(val_ds, bs, shuffle=False, num_workers=val)]\n", | ||
" phs = OrderedDict()\n", | ||
" for name, loader in defs:\n", | ||
" phs[name] = Phase(name, loader)\n", | ||
" return phs" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 24, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"import torch\n", | ||
"from torch.utils.data import DataLoader, TensorDataset" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def mock_loader(): return DataLoader(TensorDataset(torch.randn((1000, 10))))\n", | ||
"train = Phase('train', mock_loader())\n", | ||
"assert train.name == 'train'\n", | ||
"assert train.loader is not None" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"def is_scalar(obj):\n", | ||
" return isinstance(obj, (int, float, str, complex))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#export\n", | ||
"def unwrap(obj, pad=1):\n", | ||
" \"\"\"Convenience function to unwrap collections and broadcast scalars.\"\"\"\n", | ||
" if is_scalar(obj): \n", | ||
" return [obj]*pad\n", | ||
" return obj" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 10, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from torchvision.datasets import MNIST" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"root = defaults.datasets/'mnist'\n", | ||
"trn_ds = MNIST(root, train=True)\n", | ||
"val_ds = MNIST(root, train=False)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"ename": "AssertionError", | ||
"evalue": "", | ||
"output_type": "error", | ||
"traceback": [ | ||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | ||
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", | ||
"\u001b[0;32m<ipython-input-22-d481221797f1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mphases\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloader\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0mphases\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'train'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mphases\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'valid'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", | ||
"\u001b[0;31mAssertionError\u001b[0m: " | ||
] | ||
} | ||
], | ||
"source": [ | ||
"phases = Phase.make_train_valid(trn_ds, val_ds)\n", | ||
"assert len(phases) == 2\n", | ||
"assert phases['train'].loader is not None\n", | ||
"assert phases['valid'].loader is not None\n", | ||
"assert phases['train'].grad \n", | ||
"assert not phases['valid'].grad" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"OrderedDict([('train', <__main__.Phase at 0x7fdf48212e10>),\n", | ||
" ('valid', <__main__.Phase at 0x7fdf48212c18>)])" | ||
] | ||
}, | ||
"execution_count": 20, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "fastai (cuda 10)", | ||
"language": "python", | ||
"name": "fastai" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.7.1" | ||
}, | ||
"toc": { | ||
"base_numbering": 1, | ||
"nav_menu": {}, | ||
"number_sections": true, | ||
"sideBar": true, | ||
"skip_h1_title": false, | ||
"title_cell": "Table of Contents", | ||
"title_sidebar": "Contents", | ||
"toc_cell": false, | ||
"toc_position": {}, | ||
"toc_section_display": true, | ||
"toc_window_display": false | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.