forked from lanpa/tensorboardX
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_hparams.py
29 lines (23 loc) · 928 Bytes
/
test_hparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import itertools
import tempfile
import unittest

import numpy as np
from tensorboardX import SummaryWriter
# Hyperparameter search grid: the smoke test logs every combination of these
# values (3 * 3 * 2 * 2 = 36 trials). Key order is lr, bsize, n_hidden, bn.
hparam = dict(
    lr=[0.1, 0.01, 0.001],
    bsize=[1, 2, 4],
    n_hidden=[100, 200],
    bn=[True, False],
)
# Names of the metrics reported for each trial.
metrics = {'loss', 'accuracy'}
def train(lr, bsize, n_hidden):
x = lr + bsize + n_hidden
return x, x*5
class HparamsTest(unittest.TestCase):
    """Smoke tests for ``SummaryWriter.add_hparams``."""

    def test_smoke(self):
        """Log one hparams record per grid point without raising.

        Every combination of the values in ``hparam`` (lr outermost, bn
        innermost) is run through ``train`` and written with ``add_hparams``
        as a separately named trial: ``trial1``, ``trial2``, ...
        """
        grid = itertools.product(
            hparam['lr'], hparam['bsize'], hparam['n_hidden'], hparam['bn'])
        # Write into a throwaway directory: with no logdir, SummaryWriter
        # would create a ``runs/`` folder in the current working directory
        # on every test run.
        with tempfile.TemporaryDirectory() as logdir:
            with SummaryWriter(logdir) as w:
                for i, (lr, bsize, n_hidden, bn) in enumerate(grid, start=1):
                    accu, loss = train(lr, bsize, n_hidden)
                    w.add_hparams(
                        {'lr': lr, 'bsize': bsize, 'n_hidden': n_hidden, 'bn': bn},
                        {'accuracy': accu, 'loss': loss},
                        name="trial" + str(i),
                    )