Skip to content

Commit

Permalink
torchlight
Browse files Browse the repository at this point in the history
  • Loading branch information
yysijie committed Jun 6, 2018
1 parent a3ec0a4 commit 6242477
Show file tree
Hide file tree
Showing 9 changed files with 261 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ model/*
*.pt
*.caffemodel
resource/media
cache/

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ The downloaded models will be stored under ```./models```.
<!-- If you get an error message after running above command, you can also obtain models from [GoogleDrive](https://drive.google.com/open?id=1koTe3ces06NCntVdfLxF7O2Z4l_5csnX) or [BaiduYun](https://pan.baidu.com/s/1dwKG2TLvG-R1qeIiE4MjeA#list/path=%2FShare%2FAAAI18%2Fst-gcn&parentPath=%2FShare), and manually put them into ```./models```. -->

## Demo
Our graph convolutional networks represent human skeleton sequences by
**spatial temporal graph**, which maintain the spatial structure in the network propagation. To visualize how ST-GCN exploit local correlation and pattern, we compute the feature vector magnitude of each node in the final spatial temporal graph, and overlay them on the original video. **Openpose** should be ready for extracting human skeletons from videos as the input of our model.
To visualize how ST-GCN exploit local correlation and local pattern, we compute the feature vector magnitude of each node in the final spatial temporal graph, and overlay them on the original video. **Openpose** should be ready for extracting human skeletons from videos as the input of our model.

Run the demo by this command:
```
Expand Down
Empty file added demo.sh
Empty file.
6 changes: 3 additions & 3 deletions processor/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def start(self):
render_pose=0)
command_line = openpose + ' '
command_line += ' '.join([f'--{k} {v}' for k, v in openpose_args.items()])
# shutil.rmtree(output_snippets_dir, ignore_errors=True)
# os.makedirs(output_snippets_dir)
# os.system(command_line)
shutil.rmtree(output_snippets_dir, ignore_errors=True)
os.makedirs(output_snippets_dir)
os.system(command_line)

# pack openpose ouputs
video = utils.video.get_video_frames(self.arg.video)
Expand Down
5 changes: 4 additions & 1 deletion tools/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def stgcn_visualize(pose, edge, feature, video, label=None, label_sequence=None)
text = frame * 0
for m in range(M):
score = pose[2, t, :, m].mean()
if score <0.1:
if score <0.3:
continue

for i, j in edge:
Expand Down Expand Up @@ -59,6 +59,9 @@ def stgcn_visualize(pose, edge, feature, video, label=None, label_sequence=None)
feature = np.abs(feature)
feature = feature/feature.mean()
for m in range(M):
score = pose[2, t, :, m].mean()
if score <0.3:
continue

f = feature[t//4, :, m] ** 5
if f.mean() != 0:
Expand Down
8 changes: 8 additions & 0 deletions torchlight/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from setuptools import find_packages, setup

setup(
name='torchlight',
version='1.0',
description='A mini framework for pytorch',
packages=find_packages(),
install_requires=[])
8 changes: 8 additions & 0 deletions torchlight/torchlight/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .io import IO
from .io import str2bool
from .io import str2dict
from .io import DictAction
from .io import import_class
from .gpu import visible_gpu
from .gpu import occupy_gpu
from .gpu import ngpu
35 changes: 35 additions & 0 deletions torchlight/torchlight/gpu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import torch


def visible_gpu(gpus):
"""
set visible gpu.
can be a single id, or a list
return a list of new gpus ids
"""
gpus = [gpus] if isinstance(gpus, int) else list(gpus)
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(list(map(str, gpus)))
return list(range(len(gpus)))


def ngpu(gpus):
"""
count how many gpus used.
"""
gpus = [gpus] if isinstance(gpus, int) else list(gpus)
return len(gpus)


def occupy_gpu(gpus=None):
"""
make program appear on nvidia-smi.
"""
if gpus is None:
torch.zeros(1).cuda()
else:
gpus = [gpus] if isinstance(gpus, int) else list(gpus)
for g in gpus:
torch.zeros(1).cuda(g)
201 changes: 201 additions & 0 deletions torchlight/torchlight/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
#!/usr/bin/env python
import argparse
import os
import sys
import traceback
import time
import warnings
import pickle
from collections import OrderedDict
import yaml
import numpy as np
# torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

from torchpack.runner.hooks import PaviLogger

with warnings.catch_warnings():
warnings.filterwarnings("ignore",category=FutureWarning)
import h5py

class IO():
def __init__(self, work_dir, save_log=True, print_log=True):
self.work_dir = work_dir
self.save_log = save_log
self.print_to_screen = print_log
self.cur_time = time.time()
self.split_timer = {}
self.pavi_logger = None
self.session_file = None
self.model_text = ''

def log(self, *args, **kwargs):
try:
if self.pavi_logger is None:
url = 'http://pavi.parrotsdnn.org/log'
with open(self.session_file, 'r') as f:
info = dict(
session_file=self.session_file,
session_text=f.read(),
model_text=self.model_text)
self.pavi_logger = PaviLogger(url)
self.pavi_logger.connect(self.work_dir, info=info)
self.pavi_logger.log(*args, **kwargs)
except: #pylint: disable=W0702
pass

def load_model(self, model, **model_args):
Model = import_class(model)
model = Model(**model_args)
self.model_text += '\n\n' + str(model)
return model

def load_weights(self, model, weights_path, ignore_weights=None):
if ignore_weights is None:
ignore_weights = []
if isinstance(ignore_weights, str):
ignore_weights = [ignore_weights]

self.print_log(f'Load weights from {weights_path}.')
weights = torch.load(weights_path)
weights = OrderedDict([[k.split('module.')[-1],
v.cpu()] for k, v in weights.items()])

# filter weights
for i in ignore_weights:
ignore_name = list()
for w in weights:
if w.find(i) == 0:
ignore_name.append(w)
for n in ignore_name:
weights.pop(n)
self.print_log(f'Filter [{i}] remove weights [{n}].')

for w in weights:
self.print_log(f'Load weights [{w}].')

try:
model.load_state_dict(weights)
except (KeyError, RuntimeError):
state = model.state_dict()
diff = list(set(state.keys()).difference(set(weights.keys())))
for d in diff:
self.print_log(f'Can not find weights [{d}].')
state.update(weights)
model.load_state_dict(state)
return model

def save_pkl(self, result, filename):
with open(f'{self.work_dir}/{filename}', 'wb') as f:
pickle.dump(result, f)

def save_h5(self, result, filename):
with h5py.File(f'{self.work_dir}/{filename}', 'w') as f:
for k in result.keys():
f[k] = result[k]

def save_model(self, model, name):
model_path = f'{self.work_dir}/{name}'
state_dict = model.state_dict()
weights = OrderedDict([[''.join(k.split('module.')),
v.cpu()] for k, v in state_dict.items()])
torch.save(weights, model_path)
self.print_log(f'The model has been saved as {model_path}.')

def save_arg(self, arg):

self.session_file = f'{self.work_dir}/config.yaml'

# save arg
arg_dict = vars(arg)
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
with open(self.session_file, 'w') as f:
f.write(f"# command line: {' '.join(sys.argv)}\n\n")
yaml.dump(arg_dict, f, default_flow_style=False, indent=4)

def print_log(self, str, print_time=True):
if print_time:
# localtime = time.asctime(time.localtime(time.time()))
str = time.strftime("[%m.%d.%y|%X] ", time.localtime()) + str

if self.print_to_screen:
print(str)
if self.save_log:
with open(f'{self.work_dir}/log.txt', 'a') as f:
print(str, file=f)

def init_timer(self, *name):
self.record_time()
self.split_timer = {k: 0.0000001 for k in name}

def check_time(self, name):
self.split_timer[name] += self.split_time()

def record_time(self):
self.cur_time = time.time()
return self.cur_time

def split_time(self):
split_time = time.time() - self.cur_time
self.record_time()
return split_time

def print_timer(self):
proportion = {
k: f'{int(round(v * 100 / sum(self.split_timer.values()))):02d}%'
for k, v in self.split_timer.items()
}
self.print_log(f'Time consumption:')
for k in proportion:
self.print_log(
f'\t[{k}][{proportion[k]}]: {self.split_timer[k]:.4f}')


def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')


def str2dict(v):
return eval(f'dict({v})') #pylint: disable=W0123


def _import_class_0(name):
components = name.split('.')
mod = __import__(components[0])
for comp in components[1:]:
mod = getattr(mod, comp)
return mod


def import_class(import_str):
mod_str, _sep, class_str = import_str.rpartition('.')
__import__(mod_str)
try:
return getattr(sys.modules[mod_str], class_str)
except AttributeError:
raise ImportError('Class %s cannot be found (%s)' %
(class_str,
traceback.format_exception(*sys.exc_info())))


class DictAction(argparse.Action):
def __init__(self, option_strings, dest, nargs=None, **kwargs):
if nargs is not None:
raise ValueError("nargs not allowed")
super(DictAction, self).__init__(option_strings, dest, **kwargs)

def __call__(self, parser, namespace, values, option_string=None):
input_dict = eval(f'dict({values})') #pylint: disable=W0123
output_dict = getattr(namespace, self.dest)
for k in input_dict:
output_dict[k] = input_dict[k]
setattr(namespace, self.dest, output_dict)

0 comments on commit 6242477

Please sign in to comment.