Skip to content

Commit

Permalink
Init weight
Browse files Browse the repository at this point in the history
  • Loading branch information
daivuongktx13 committed Jun 10, 2023
1 parent 80d11c6 commit 75b18ef
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions pyskl/models/gcns/shiftgcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from torch.autograd import Variable
import numpy as np
import math
from mmcv.runner import load_checkpoint

# from .Temporal_shift.cuda.shift import Shift
from ..builder import BACKBONES
from ...utils import Graph
from ...utils import Graph, cache_checkpoint
from .utils import unit_tcn


Expand Down Expand Up @@ -115,7 +116,7 @@ def __init__(self,
graph_cfg,
in_channels=3,
base_channels=64,
data_bn_type='VC',
data_bn_type='MVC',
ch_ratio=2,
num_person=2, # * Only used when data_bn_type == 'MVC'
num_stages=10,
Expand All @@ -125,6 +126,8 @@ def __init__(self,
**kwargs):
super().__init__()

self.pretrained = pretrained

self.graph = Graph(**graph_cfg)
A = torch.tensor(self.graph.A, dtype=torch.float32, requires_grad=False)

Expand All @@ -149,6 +152,11 @@ def __init__(self,
self.l10 = TCN_GCN_unit(256, 256, A)

bn_init(self.data_bn, 1)

def init_weights(self):
if isinstance(self.pretrained, str):
self.pretrained = cache_checkpoint(self.pretrained)
load_checkpoint(self, self.pretrained, strict=False)

def forward(self, x):
N, C, T, V, M = x.size()
Expand Down

0 comments on commit 75b18ef

Please sign in to comment.