
Commit

pytorch hook
chunhuizhang committed Jun 23, 2022
1 parent b2432d5 commit 462bd2c
Showing 2 changed files with 62 additions and 0 deletions.
17 changes: 17 additions & 0 deletions cv/pretrained/features.py
@@ -0,0 +1,17 @@

import timm
import torch
from torch import nn


model_name = 'xception41'
# model_name = 'resnet18'
model = timm.create_model(model_name, pretrained=True)

# dummy input batch: (batch_size=2, channels=3, height=299, width=299)
input = torch.randn(2, 3, 299, 299)

# standard forward pass: classification logits
o1 = model(input)
print(o1.shape)

# forward_features: feature map before global pooling and the classifier head
o2 = model.forward_features(input)
print(o2.shape)
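
For comparison, forward_features bypasses timm's global pooling and classifier head, so o2 is a spatial feature map while o1 is the usual logits tensor. A minimal sketch of a related pattern, assuming that creating the model with num_classes=0 strips the head and returns pooled embeddings (feat_extractor and o3 are illustrative names, not part of this commit):

# sketch: pooled embeddings instead of logits (assumes num_classes=0 removes the classifier head)
feat_extractor = timm.create_model(model_name, pretrained=True, num_classes=0)
o3 = feat_extractor(input)
print(o3.shape)
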
45 changes: 45 additions & 0 deletions learn_torch/utils/hook.py
@@ -0,0 +1,45 @@

import timm
import torch
from torch import nn


def print_shape(m, i, o):
    # m: module, i: tuple of inputs, o: output tensor
    # print(m, i[0].shape, o.shape)
    print(i[0].shape, '=>', o.shape)


def get_children(model: nn.Module):
    # recursively flatten a model into its leaf modules
    children = list(model.children())
    flatt_children = []
    if children == []:
        # the model has no children, so it is itself a leaf
        return model
    else:
        # recurse into each child down to the leaves
        for child in children:
            try:
                flatt_children.extend(get_children(child))
            except TypeError:
                # get_children returned a single leaf module, which is not iterable
                flatt_children.append(get_children(child))
    return flatt_children
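# Added sketch, not part of the original commit: PyTorch's built-in Module.modules()
# walks the same tree, so the leaf modules can also be collected without custom
# recursion (leaf_modules is an illustrative name):
# leaf_modules = [m for m in model.modules() if len(list(m.children())) == 0]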


model_name = 'vgg11'
model = timm.create_model(model_name, pretrained=True)

flatt_children = get_children(model)
for layer in flatt_children:
    layer.register_forward_hook(print_shape)

# alternatively, hook only the top-level children:
# for layer in model.children():
#     layer.register_forward_hook(print_shape)

# 4d input in NCHW layout: batch * channel * height * width
batch_input = torch.randn(4, 3, 299, 299)

model(batch_input)
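
One follow-up worth noting: register_forward_hook returns a RemovableHandle, so instead of the bare registration loop above the handles can be kept and the hooks detached once the shapes have been inspected; later forward passes then stay silent. A minimal sketch reusing flatt_children, print_shape, and batch_input from this file (handles and h are illustrative names, not part of the commit):

# register the same hook, but keep the handles for later removal
handles = [layer.register_forward_hook(print_shape) for layer in flatt_children]
model(batch_input)   # prints one "input => output" shape line per leaf module
for h in handles:
    h.remove()       # detach the hooks; subsequent forward passes print nothing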
