# utils.py
# Forked from Seanny123/da-rnn.
import logging
import os
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable
from constants import device
def setup_log(tag='VOC_TOPICS'):
    """Return the logger named ``tag``, configured for DEBUG console output.

    The logger is set to DEBUG, has propagation to ancestor loggers turned
    off, and gets a ``StreamHandler`` whose records are formatted as
    ``time - name - level - message``.

    ``logging.getLogger`` hands back the same object for the same name, so
    the console handler is attached only when the logger does not already
    have one. Calling this function repeatedly with the same tag therefore
    does not make every record print multiple times.

    Args:
        tag: Name of the logger to fetch and configure.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    logger = logging.getLogger(tag)
    logger.propagate = False
    logger.setLevel(logging.DEBUG)

    # FileHandler (and others) subclass StreamHandler, so compare the exact
    # type: only a plain console handler counts as "already set up".
    has_console = any(
        type(handler) is logging.StreamHandler for handler in logger.handlers
    )
    if not has_console:
        console = logging.StreamHandler()
        console.setLevel(logging.DEBUG)
        console.setFormatter(
            logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        )
        logger.addHandler(console)

    return logger
def save_or_show_plot(file_nm: str, save: bool) -> None:
    """Write the current pyplot figure to disk, or display it.

    Args:
        file_nm: File name of the image, resolved relative to the ``plots``
            directory that sits next to this module.
        save: When true the figure is saved via ``plt.savefig``; otherwise
            it is displayed via ``plt.show()``.
    """
    if not save:
        plt.show()
        return

    target = os.path.join(os.path.dirname(__file__), "plots", file_nm)
    # savefig does not create missing directories; make sure the destination
    # exists so a fresh checkout does not fail with FileNotFoundError.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    plt.savefig(target)
def numpy_to_tvar(x) -> torch.Tensor:
    """Convert a NumPy array to a float32 tensor on the configured device.

    The ``torch.autograd.Variable`` wrapper is intentionally not used: it is
    deprecated since PyTorch 0.4 (tensors and variables were merged), and
    this function already relies on ``Tensor.to(device)`` from that release.

    Args:
        x: NumPy array accepted by ``torch.from_numpy``.

    Returns:
        A ``torch.float32`` tensor holding the values of ``x``, placed on
        the ``device`` imported from ``constants``.
    """
    # Cast on the source side first, then move, matching the original order.
    return torch.from_numpy(x).float().to(device)