Skip to content

Commit

Permalink
fix lint and indentation (pytorch#57)
Browse files Browse the repository at this point in the history
  • Loading branch information
szagoruyko authored Jan 16, 2018
1 parent 4400ed9 commit 7b1dc6c
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 125 deletions.
6 changes: 1 addition & 5 deletions torchnet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1 @@
from . import dataset
from . import meter
from . import engine
from . import transform
from . import logger
__all__ = ['dataset', 'meter', 'engine', 'transform', 'logger']
1 change: 0 additions & 1 deletion torchnet/dataset/listdataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from .dataset import Dataset
import torch


class ListDataset(Dataset):
Expand Down
185 changes: 91 additions & 94 deletions torchnet/logger/meterlogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,121 +4,118 @@
# Tsinghua Univ.
# Modified at Dec 12 2017
#
import numpy as np
import torch
import torchnet as tnt
from torchnet.logger import VisdomPlotLogger, VisdomLogger

class MeterLogger(object):

class MeterLogger(object):

def __init__(self, server="http://localhost", port=8097, nclass=21, title="DNN"):
self.nclass = nclass
self.meter = {}
self.server = server
self.port = port
self.nclass = nclass
self.topk = 5 if nclass > 5 else nclass
self.title = title
self.nclass = nclass
self.meter = {}
self.server = server
self.port = port
self.nclass = nclass
self.topk = 5 if nclass > 5 else nclass
self.title = title
self.logger = {'Train':{}, 'Test':{}}
self.timer = tnt.meter.TimeMeter(None)
self.timer = tnt.meter.TimeMeter(None)

def __ver2Tensor(self, target):
target_mat = torch.zeros(target.shape[0], self.nclass)
target_mat = torch.zeros(target.shape[0], self.nclass)
for i,j in enumerate(target):
target_mat[i][j]=1
return target_mat
target_mat[i][j] = 1
return target_mat

def __toTensor(self, var):
if isinstance(var, torch.autograd.Variable):
var = var.data
if not torch.is_tensor(var):
var = torch.from_numpy(var)
return var


if isinstance(var, torch.autograd.Variable):
var = var.data
if not torch.is_tensor(var):
var = torch.from_numpy(var)
return var

def __addlogger(self, meter, ptype):
if ptype == 'line':
opts={'title': self.title+' Train '+ meter}
self.logger['Train'][meter] = VisdomPlotLogger(ptype,server=self.server, port=self.port, opts=opts)
opts={'title': self.title+' Test '+ meter}
self.logger['Test'][meter] = VisdomPlotLogger(ptype,server=self.server, port=self.port, opts=opts)
elif ptype == 'heatmap':
names = list(range(self.nclass))
opts={'title': self.title+' Train '+ meter, 'columnnames':names, 'rownames': names }
self.logger['Train'][meter] = VisdomLogger('heatmap', server=self.server, port=self.port, opts=opts)
opts={'title': self.title+' Test '+ meter, 'columnnames':names, 'rownames': names }
self.logger['Test'][meter] = VisdomLogger('heatmap', server=self.server, port=self.port, opts=opts)
if ptype == 'line':
opts = {'title': self.title + ' Train ' + meter}
self.logger['Train'][meter] = VisdomPlotLogger(ptype,server=self.server, port=self.port, opts=opts)
opts = {'title': self.title + ' Test ' + meter}
self.logger['Test'][meter] = VisdomPlotLogger(ptype,server=self.server, port=self.port, opts=opts)
elif ptype == 'heatmap':
names = list(range(self.nclass))
opts = {'title': self.title + ' Train ' + meter, 'columnnames':names, 'rownames': names}
self.logger['Train'][meter] = VisdomLogger('heatmap', server=self.server, port=self.port, opts=opts)
opts = {'title': self.title + ' Test ' + meter, 'columnnames':names, 'rownames': names}
self.logger['Test'][meter] = VisdomLogger('heatmap', server=self.server, port=self.port, opts=opts)

def __addloss(self, meter):
self.meter[meter] = tnt.meter.AverageValueMeter()
self.__addlogger(meter, 'line')
self.meter[meter] = tnt.meter.AverageValueMeter()
self.__addlogger(meter, 'line')

def __addmeter(self, meter):
if meter == 'accuracy':
self.meter[meter] = tnt.meter.ClassErrorMeter(topk=(1, self.topk), accuracy=True)
self.__addlogger(meter, 'line')
elif meter == 'map':
self.meter[meter] = tnt.meter.mAPMeter()
self.__addlogger(meter, 'line')
elif meter == 'auc':
self.meter[meter] = tnt.meter.AUCMeter()
self.__addlogger(meter, 'line')
elif meter == 'confusion':
self.meter[meter] = tnt.meter.ConfusionMeter(self.nclass, normalized=True)
self.__addlogger(meter, 'heatmap')
if meter == 'accuracy':
self.meter[meter] = tnt.meter.ClassErrorMeter(topk=(1, self.topk), accuracy=True)
self.__addlogger(meter, 'line')
elif meter == 'map':
self.meter[meter] = tnt.meter.mAPMeter()
self.__addlogger(meter, 'line')
elif meter == 'auc':
self.meter[meter] = tnt.meter.AUCMeter()
self.__addlogger(meter, 'line')
elif meter == 'confusion':
self.meter[meter] = tnt.meter.ConfusionMeter(self.nclass, normalized=True)
self.__addlogger(meter, 'heatmap')

def updateMeter(self, output, target, meters={'accuracy'}):
output = self.__toTensor(output)
target = self.__toTensor(target)
for meter in meters:
if not self.meter.has_key(meter):
self.__addmeter(meter)
if meter in ['ap', 'map', 'confusion']:
target_th = self.__ver2Tensor(target)
self.meter[meter].add(output, target_th)
else:
self.meter[meter].add(output, target)

output = self.__toTensor(output)
target = self.__toTensor(target)
for meter in meters:
if meter not in self.meter.keys():
self.__addmeter(meter)
if meter in ['ap', 'map', 'confusion']:
target_th = self.__ver2Tensor(target)
self.meter[meter].add(output, target_th)
else:
self.meter[meter].add(output, target)

def updateLoss(self, loss, meter='loss'):
loss = self.__toTensor(loss)
if not self.meter.has_key(meter):
loss = self.__toTensor(loss)
if meter not in self.meter.keys():
self.__addloss(meter)
self.meter[meter].add(loss[0])
self.meter[meter].add(loss[0])

def resetMeter(self, iepoch, mode='Train'):
self.timer.reset()
for key in self.meter.keys():
val = self.meter[key].value()
val = val[0] if isinstance(val, (list, tuple)) else val
if key in ['confusion','histogram','image']:
self.logger[mode][key].log(val)
else:
self.logger[mode][key].log(iepoch, val)
self.meter[key].reset()
self.timer.reset()
for key in self.meter.keys():
val = self.meter[key].value()
val = val[0] if isinstance(val, (list, tuple)) else val
if key in ['confusion','histogram','image']:
self.logger[mode][key].log(val)
else:
self.logger[mode][key].log(iepoch, val)
self.meter[key].reset()

def printMeter(self, mode, iepoch, ibatch=1, totalbatch=1, meterlist=None):
pstr = "%s:\t[%d][%d/%d] \t"
tval = []
tval.extend([mode, iepoch, ibatch, totalbatch])
if meterlist==None:
meterlist = self.meter.keys()
for meter in meterlist:
if meter in ['confusion','histogram','image']:
continue
if meter == 'accuracy':
pstr += "Acc@1 %.2f%% \t Acc@"+str(self.topk)+" %.2f%% \t"
tval.extend([self.meter[meter].value()[0], self.meter[meter].value()[1]])
elif meter == 'map':
pstr += "mAP %.3f \t"
tval.extend([self.meter[meter].value()])
elif meter == 'auc':
pstr += "AUC %.3f \t"
tval.extend([self.meter[meter].value()])
else :
pstr += meter+" %.3f (%.3f)\t"
tval.extend([self.meter[meter].val, self.meter[meter].mean])
pstr += " %.2fs/its\t"
tval.extend([self.timer.value()])
print(pstr % tuple(tval))
pstr = "%s:\t[%d][%d/%d] \t"
tval = []
tval.extend([mode, iepoch, ibatch, totalbatch])
if meterlist is None:
meterlist = self.meter.keys()
for meter in meterlist:
if meter in ['confusion','histogram','image']:
continue
if meter == 'accuracy':
pstr += "Acc@1 %.2f%% \t Acc@"+str(self.topk)+" %.2f%% \t"
tval.extend([self.meter[meter].value()[0], self.meter[meter].value()[1]])
elif meter == 'map':
pstr += "mAP %.3f \t"
tval.extend([self.meter[meter].value()])
elif meter == 'auc':
pstr += "AUC %.3f \t"
tval.extend([self.meter[meter].value()])
else:
pstr += meter+" %.3f (%.3f)\t"
tval.extend([self.meter[meter].val, self.meter[meter].mean])
pstr += " %.2fs/its\t"
tval.extend([self.timer.value()])
print(pstr % tuple(tval))
26 changes: 13 additions & 13 deletions torchnet/logger/visdomlogger.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
""" Logging to Visdom server """
from collections import defaultdict
import numpy as np
import visdom

from .logger import Logger


class BaseVisdomLogger(Logger):
'''
The base class for logging output to Visdom.
'''
The base class for logging output to Visdom.
***THIS CLASS IS ABSTRACT AND MUST BE SUBCLASSED***
Note that the Visdom server is designed to also handle a server architecture,
and therefore the Visdom server must be running at all times. The server can
be started with
Note that the Visdom server is designed to also handle a server architecture,
and therefore the Visdom server must be running at all times. The server can
be started with
$ python -m visdom.server
and you probably want to run it from screen or tmux.
and you probably want to run it from screen or tmux.
'''

@property
Expand Down Expand Up @@ -49,7 +48,7 @@ def _viz_logger(*args, **kwargs):
return _viz_logger

def log_state(self, state):
""" Gathers the stats from self.trainer.stats and passes them into
""" Gathers the stats from self.trainer.stats and passes them into
self.log, as a list """
results = []
for field_idx, field in enumerate(self.fields):
Expand All @@ -61,9 +60,9 @@ def log_state(self, state):


class VisdomSaver(object):
''' Serialize the state of the Visdom server to disk.
''' Serialize the state of the Visdom server to disk.
Unless you have a fancy schedule, where different are saved with different frequencies,
you probably only need one of these.
you probably only need one of these.
'''

def __init__(self, envs=None, port=8097, server="localhost"):
Expand Down Expand Up @@ -156,13 +155,14 @@ def log(self, *args, **kwargs):

class VisdomTextLogger(BaseVisdomLogger):
'''
Creates a text window in visdom and logs output to it.
The output can be formatted with fancy HTML, and it new output can
Creates a text window in visdom and logs output to it.
The output can be formatted with fancy HTML, and it new output can
be set to 'append' or 'replace' mode.
'''
valid_update_types = ['REPLACE', 'APPEND']

def __init__(self, fields=None, win=None, env=None, opts={}, update_type=valid_update_types[0], port=8097, server="localhost"):
def __init__(self, fields=None, win=None, env=None, opts={}, update_type=valid_update_types[0],
port=8097, server="localhost"):
'''
Args:
fields: Currently unused
Expand Down
4 changes: 1 addition & 3 deletions torchnet/meter/apmeter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import math
from . import meter
import numpy as np
import torch


Expand Down Expand Up @@ -82,8 +81,7 @@ def add(self, output, target, weight=None):
self.scores.storage().resize_(int(new_size + output.numel()))
self.targets.storage().resize_(int(new_size + output.numel()))
if weight is not None:
self.weights.storage().resize_(int(new_weight_size
+ output.size(0)))
self.weights.storage().resize_(int(new_weight_size + output.size(0)))

# store scores and targets
offset = self.scores.size(0) if self.scores.dim() > 0 else 0
Expand Down
1 change: 0 additions & 1 deletion torchnet/meter/aucmeter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import math
import numbers
from . import meter
import numpy as np
Expand Down
3 changes: 0 additions & 3 deletions torchnet/meter/mapmeter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import math
from . import meter, APMeter
import numpy as np
import torch


class mAPMeter(meter.Meter):
Expand Down
9 changes: 4 additions & 5 deletions torchnet/transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
from six import iteritems
from .utils.table import canmergetensor as canmerge
from .utils.table import mergetensor as mergetensor
Expand All @@ -22,21 +21,21 @@ def mergekeys(tbl):
if isinstance(tbl, dict):
for idx, elem in tbl.items():
for key, value in elem.items():
if not key in mergetbl:
if key not in mergetbl:
mergetbl[key] = {}
mergetbl[key][idx] = value
elif isinstance(tbl, list):
for elem in tbl:
for key, value in elem.items():
if not key in mergetbl:
if key not in mergetbl:
mergetbl[key] = []
mergetbl[key].append(value)
return mergetbl
return mergekeys


def tableapply(f): return lambda d: dict(
map(lambda kv: (kv[0], f(kv[1])), iteritems(d)))
def tableapply(f):
return lambda d: dict(map(lambda kv: (kv[0], f(kv[1])), iteritems(d)))


def makebatch(merge=None):
Expand Down

0 comments on commit 7b1dc6c

Please sign in to comment.