-
Notifications
You must be signed in to change notification settings - Fork 1
/
runner.py
55 lines (44 loc) · 1.58 KB
/
runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import logging, coloredlogs
from typing import Callable
import random
import torch
from easydict import EasyDict
from learner import Learner
from evals import alphaBlendEval, rotationEval
# Module-wide logging: configure the ROOT logger (getLogger() with no name), so
# records from every module that propagates to root use the coloredlogs format.
# NOTE(review): coloredlogs.install() with its default reconfigure=True is
# documented to replace an existing stdout/stderr StreamHandler on the logger,
# so the explicit StreamHandler added here is likely redundant — confirm.
logger = logging.getLogger()
logger.addHandler(logging.StreamHandler())
coloredlogs.install(
    level=logging.INFO,
    logger=logger,
    datefmt="%Y/%m/%d %H:%M:%S",
    fmt="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
def run(params: EasyDict, model: torch.nn.Module, datasetLoader: Callable):
    """Train ``model`` (unless training is skipped) and run the configured evals.

    Args:
        params: Experiment configuration. Fields read here are ``batchSize``,
            ``modelIsTrained``, ``resume``, ``evals``, ``alpha``, ``angle``,
            ``metricParams`` and the optional ``seed``. The whole object is
            also forwarded to ``Learner``.
        model: Network to train and evaluate; ``model.cuda()`` is called on it.
        datasetLoader: Called as ``datasetLoader(batch_size)``. It must return a
            ``(train_loader, test_loader)`` pair. It is called a second time to
            obtain another test loader for the evaluations.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Use CUDA
    if not torch.cuda.is_available():
        raise RuntimeError("Cuda not found!")
    model.cuda()

    # Seed the Python and torch (CPU + all GPUs) RNGs. Use ``params.seed`` when
    # configured; otherwise draw one at random. Always log it so the run can be
    # reproduced later — previously the random seed was silently discarded.
    seed = getattr(params, "seed", None)
    if seed is None:
        seed = random.randint(1, 10000)
    logger.info("Random seed: %s", seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    train_loader, test_loader1 = datasetLoader(params.batchSize)

    # Equivalent to `not (modelIsTrained and resume)`: training is skipped only
    # when the model is already trained AND resuming was requested.
    if not params.modelIsTrained or not params.resume:
        logger.info("Training:")
        main_learner = Learner(
            params,
            model=model,
            trainloader=train_loader,
            testloader=test_loader1,
        )
        main_learner.learn()
        logger.info("Learning session complete")

    # A second loader call provides another test loader: alpha blending receives
    # both test loaders, while rotation uses only the second one.
    _, test_loader2 = datasetLoader(params.batchSize)

    if "alphaBlending" in params.evals:
        logger.info("Alpha blending evaluation using alpha=%s", params.alpha)
        alphaBlendEval(
            model, test_loader1, test_loader2, params.alpha, params.metricParams
        )

    if "rotation" in params.evals:
        logger.info("Rotation evaluation using angle=%s", params.angle)
        rotationEval(model, test_loader2, params.angle, params.metricParams)