forked from xinge008/Cylinder3D
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcylinder_spconv_3d.py
46 lines (32 loc) · 1.26 KB
/
cylinder_spconv_3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
# -*- coding:utf-8 -*-
# author: Xinge
# @file: cylinder_spconv_3d.py
from torch import nn
# Module-level registry: maps a registration name (``cls.__name__`` unless an
# explicit name is given) to the model class stored by ``register_model``.
REGISTERED_MODELS_CLASSES = dict()
def register_model(cls=None, name=None):
    """Register a model class in ``REGISTERED_MODELS_CLASSES``.

    Supports three call forms:

    * bare decorator:        ``@register_model``
    * decorator factory:     ``@register_model(name="alias")``
    * direct call:           ``register_model(SomeClass, name="alias")``

    Args:
        cls: The class to register. When ``None``, a decorator is returned
            that registers its argument under ``name``.
        name: Registry key. Defaults to ``cls.__name__``.

    Returns:
        ``cls`` unchanged (so this works as a class decorator), or a
        decorator when ``cls`` is ``None``.

    Raises:
        AssertionError: If ``name`` is already present in the registry.
    """
    if cls is None:
        # Decorator-factory form: @register_model(name="...").
        def decorator(target):
            return register_model(target, name=name)

        return decorator

    if name is None:
        name = cls.__name__

    # Raised explicitly instead of via ``assert`` so the duplicate check is
    # not stripped under ``python -O``; AssertionError is kept so existing
    # callers observe the same exception type as before.
    if name in REGISTERED_MODELS_CLASSES:
        raise AssertionError(
            f"exist class: {name!r} is already registered as "
            f"{REGISTERED_MODELS_CLASSES[name]!r}"
        )

    # The registry dict is mutated, not rebound, so no ``global`` is needed.
    REGISTERED_MODELS_CLASSES[name] = cls
    return cls
def get_model_class(name):
    """Return the class registered under ``name`` by ``register_model``.

    Args:
        name: Registry key the class was registered with.

    Returns:
        The registered class object.

    Raises:
        AssertionError: If ``name`` is not present in the registry.
    """
    try:
        return REGISTERED_MODELS_CLASSES[name]
    except KeyError:
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``; AssertionError is kept for existing callers. The
        # message lists registry keys rather than dumping class objects.
        raise AssertionError(
            f"unknown class: {name!r}; "
            f"available class: {list(REGISTERED_MODELS_CLASSES)}"
        ) from None
@register_model
class cylinder_asym(nn.Module):
    """Chain a cylinder feature generator with a segmentation module.

    Args:
        cylin_model: Callable invoked as
            ``cylin_model(train_pt_fea_ten, train_vox_ten)``; its result is
            unpacked into a ``(coords, features_3d)`` pair.
        segmentator_spconv: Callable invoked as
            ``segmentator_spconv(features_3d, coords, batch_size)``
            (presumably a sparse-convolution network, per its name - confirm).
        sparse_shape: Stored as ``self.sparse_shape``. It is not read by
            ``forward``; presumably consumed elsewhere - confirm.
    """

    def __init__(self, cylin_model, segmentator_spconv, sparse_shape):
        super().__init__()
        self.name = "cylinder_asym"
        # Keep this assignment order: when these are nn.Module instances,
        # nn.Module tracks them in assignment order, which fixes the ordering
        # of submodules (and of state_dict keys).
        self.cylinder_3d_generator = cylin_model
        self.cylinder_3d_spconv_seg = segmentator_spconv
        self.sparse_shape = sparse_shape

    def forward(self, train_pt_fea_ten, train_vox_ten, batch_size):
        """Run the generator, then pass its output to the segmentation module.

        Returns:
            Whatever ``self.cylinder_3d_spconv_seg`` returns.
        """
        voxel_coords, voxel_features = self.cylinder_3d_generator(
            train_pt_fea_ten, train_vox_ten
        )
        # Note the argument order expected by the segmentator:
        # features first, then coordinates.
        return self.cylinder_3d_spconv_seg(voxel_features, voxel_coords, batch_size)