Skip to content

Commit

Permalink
model downloader
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Jan 15, 2018
1 parent 5e0132c commit bb6230c
Show file tree
Hide file tree
Showing 13 changed files with 259 additions and 165 deletions.
104 changes: 103 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,106 @@
data/*
#custom
work_dir/*
data/*
config_v0/*
.vscode
model/*

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
46 changes: 26 additions & 20 deletions train.py → main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def __init__(self, arg):
self.load_optimizer()
self.save_arg()

self.print_log('Parameters:\n{}\n'.format(str(vars(arg))))

def load_data(self):
Feeder = import_class(self.arg.feeder)
self.data_loader = dict()
Expand All @@ -54,20 +56,22 @@ def load_data(self):

def load_model(self):
Model = import_class(self.arg.model)
self.model = Model(**self.arg.model_args).cuda(self.arg.device)
self.loss = nn.CrossEntropyLoss().cuda(self.arg.device)
self.model = Model(**self.arg.model_args).cuda(self.arg.device[0])
self.loss = nn.CrossEntropyLoss().cuda(self.arg.device[0])

if self.arg.parallel_device:
if len(self.arg.parallel_device) > 1:
self.model = nn.DataParallel(
self.model,
device_ids=self.arg.parallel_device,
output_device=self.arg.device)
if len(self.arg.device) > 1:
self.model = nn.DataParallel(
self.model,
device_ids=self.arg.device,
output_device=self.arg.device[0])

if self.arg.weights:
print('Load weights from {}.'.format(self.arg.weights))
with open(self.arg.weights, 'r') as f:
weights = pickle.load(f)
if '.pkl' in self.arg.weights:
with open(self.arg.weights, 'r') as f:
weights = pickle.load(f)
else:
weights = torch.load(self.arg.weights)

for w in self.arg.ignore_weights:
if weights.pop(w, None) is not None:
Expand All @@ -86,6 +90,7 @@ def load_model(self):
state.update(weights)
self.model.load_state_dict(state)

print('Done.')
def load_optimizer(self):
if self.arg.optimizer == 'SGD':
self.optimizer = optim.SGD(
Expand Down Expand Up @@ -156,9 +161,9 @@ def train(self, epoch, save_model=False):

# get data
data = Variable(
data.float().cuda(self.arg.device), requires_grad=False)
data.float().cuda(self.arg.device[0]), requires_grad=False)
label = Variable(
label.long().cuda(self.arg.device), requires_grad=False)
label.long().cuda(self.arg.device[0]), requires_grad=False)
timer['dataloader'] += self.split_time()

# forward
Expand All @@ -185,10 +190,11 @@ def train(self, epoch, save_model=False):
self.print_log('\tTime consumption: [Data]{dataloader}, [Network]{model}'.format(**proportion))

if save_model:
model_path = '{}/epoch{}_model.pkl'.format(self.arg.work_dir,
model_path = '{}/epoch{}_model.pt'.format(self.arg.work_dir,
epoch + 1)
with open(model_path, 'w') as f:
pickle.dump(self.model.state_dict(), f)
#with open(model_path, 'w') as f:
# pickle.dump(self.model.state_dict(), f)
torch.save(self.model.state_dict(), model_path)

self.print_log('The model was saved in {}'.format(model_path))

Expand All @@ -200,11 +206,11 @@ def eval(self, epoch, save_score=False, loader_name=['test']):
score_frag = []
for batch_idx, (data, label) in enumerate(self.data_loader[ln]):
data = Variable(
data.float().cuda(self.arg.device),
data.float().cuda(self.arg.device[0]),
requires_grad=False,
volatile=True)
label = Variable(
label.long().cuda(self.arg.device),
label.long().cuda(self.arg.device[0]),
requires_grad=False,
volatile=True)
output = self.model(data)
Expand Down Expand Up @@ -282,8 +288,7 @@ def start(self):

# optim
parser.add_argument('--step', type=int, default=[20, 40, 60], nargs='+')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--parallel-device', type=int, default=None, nargs='+')
parser.add_argument('--device', type=int, default=0, nargs='+')
parser.add_argument('--optimizer', default='SGD')
parser.add_argument('--nesterov', type=str2bool, default=False)
parser.add_argument('--batch-size', type=int, default=256)
Expand All @@ -306,7 +311,8 @@ def start(self):
parser.set_defaults(**default_arg)

arg = parser.parse_args()
print(vars(arg))



processor = Processor(arg)
processor.start()
99 changes: 0 additions & 99 deletions st_gcn/.gitignore
Original file line number Diff line number Diff line change
@@ -1,102 +1,3 @@
# custom
.vscode
*.swp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
15 changes: 3 additions & 12 deletions st_gcn/feeder/feeder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ def __init__(self,
window_size=-1,
temporal_downsample_step=1,
mean_subtraction=0,
num_sample=-1,
sample_name=None,
debug=False):
self.debug = debug
self.mode = mode
Expand All @@ -42,21 +40,14 @@ def __init__(self,
self.window_size = window_size
self.mean_subtraction = mean_subtraction
self.temporal_downsample_step = temporal_downsample_step
self.num_sample = num_sample
self.sample_name = sample_name

self.load_data()

def load_data(self):
# data: N C V T M
self.label = np.load(self.label_path)[0:self.num_sample]
self.data = np.load(self.data_path)[0:self.num_sample]

if self.sample_name != None:
with open(self.sample_name, 'r') as f:
self.sample_name = pickle.load(f)
else:
self.sample_name = [str(i) for i in range(len(self.label))]
with open(self.label_path, 'r') as f:
self.sample_name, self.label = pickle.load(f)
self.data = np.load(self.data_path)

if self.debug:
self.label = self.label[0:100]
Expand Down
8 changes: 6 additions & 2 deletions st_gcn/feeder/feeder_kinetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(self,
random_choose=False,
window_size=-1,
random_shift=False,
random_move=False,
pose_matching=False,
num_person=1,
num_match_trace=2,
Expand All @@ -42,6 +43,7 @@ def __init__(self,
self.label_path = label_path
self.random_choose = random_choose
self.random_shift = random_shift
self.random_move = random_move
self.window_size = window_size
self.temporal_downsample_step = temporal_downsample_step
self.num_sample = num_sample
Expand Down Expand Up @@ -73,7 +75,7 @@ def load_data(self):

# ignore the samples which does not has skeleton sequence
if self.ignore_empty_sample:
self.sample_name = self.sample_name[has_skeleton]
self.sample_name = [s for h, s in zip(has_skeleton, self.sample_name) if h]
self.label = self.label[has_skeleton]

# output data shape (N, C, T, V, M)
Expand Down Expand Up @@ -128,13 +130,15 @@ def __getitem__(self, index):
else:
data_numpy = tools.temporal_slice(data_numpy,
self.temporal_downsample_step)

# data augmentation
if self.random_shift:
data_numpy = tools.random_shift(data_numpy)
if self.random_choose:
data_numpy = tools.random_choose(data_numpy, self.window_size)
elif self.window_size>0:
data_numpy = tools.auto_pading(data_numpy, self.window_size)
if self.random_move:
data_numpy = tools.random_move(data_numpy)

# match poses between 2 frames
if self.pose_matching:
Expand Down
Loading

0 comments on commit bb6230c

Please sign in to comment.