forked from ZrrSkywalker/Point-M2AE
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathModelNetDatasetFewShot.py
71 lines (56 loc) · 1.98 KB
/
ModelNetDatasetFewShot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
'''
@author: Xu Yan
@file: ModelNet.py
@time: 2021/3/19 15:51
'''
import os
import numpy as np
import warnings
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset
from .build import DATASETS
from utils.logger import *
import torch
import random
# Installs an 'ignore' filter with no category or module restriction, so every
# warning raised anywhere in the process is suppressed once this module is
# imported -- not just warnings originating from this file.
# NOTE(review): consider narrowing this (category=/module=) so unrelated
# warnings elsewhere in training are not hidden.
warnings.filterwarnings('ignore')
def pc_normalize(pc):
    """Centre a point cloud on its mean and scale it into the unit sphere.

    Args:
        pc: 2-D array-like of shape (N, D); each row is one point.

    Returns:
        A new array: ``pc`` minus its per-column mean, divided by the largest
        Euclidean distance of any centred point from the origin. When that
        distance is 0 (every point coincides, e.g. a single point), the
        centred array is returned unscaled rather than divided by zero.
        The input is never modified.

    Raises:
        ValueError: if ``pc`` contains no points.
    """
    pc = np.asarray(pc)
    if pc.shape[0] == 0:
        raise ValueError('pc_normalize: cannot normalize an empty point cloud')
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    scale = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    # A degenerate cloud has zero radius; dividing by it would turn every
    # coordinate into NaN and poison downstream training.
    if scale > 0:
        pc = pc / scale
    return pc
@DATASETS.register_module()
class ModelNetFewShot(Dataset):
    """ModelNet few-shot split read from a pre-built pickle file.

    The file ``<DATA_PATH>/<way>way_<shot>shot/<fold>.pkl`` is loaded and
    indexed by ``config.subset``; each resulting entry is unpacked as
    ``(points, label, _)``, where ``points`` is a 2-D array whose first three
    columns get normalised (presumably xyz, with normals in any remaining
    columns -- TODO confirm against the data generator).
    """

    def __init__(self, config):
        """Load the split selected by ``config``.

        Args:
            config: object providing ``DATA_PATH``, ``N_POINTS``,
                ``USE_NORMALS``, ``NUM_CATEGORY``, ``subset``, ``way``,
                ``shot`` and ``fold`` attributes.

        Raises:
            RuntimeError: if ``way``, ``shot`` or ``fold`` equals -1.
        """
        self.root = config.DATA_PATH
        self.npoints = config.N_POINTS  # not read anywhere in this class
        self.use_normals = config.USE_NORMALS
        self.num_category = config.NUM_CATEGORY
        self.process_data = True  # not read anywhere in this class
        self.uniform = True  # not read anywhere in this class
        self.subset = config.subset
        self.way = config.way
        self.shot = config.shot
        self.fold = config.fold
        # -1 is rejected: the pickle path below cannot be built without
        # concrete way/shot/fold values.
        if self.way == -1 or self.shot == -1 or self.fold == -1:
            raise RuntimeError(
                'ModelNetFewShot requires way, shot and fold to be set '
                f'(got way={self.way}, shot={self.shot}, fold={self.fold})'
            )
        self.pickle_path = os.path.join(
            self.root, f'{self.way}way_{self.shot}shot', f'{self.fold}.pkl'
        )
        print_log('Load processed data from %s...' % self.pickle_path, logger='ModelNetFewShot')
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # DATA_PATH at trusted, locally generated files.
        with open(self.pickle_path, 'rb') as f:
            self.dataset = pickle.load(f)[self.subset]
        print_log('The size of %s data is %d' % (self.subset, len(self.dataset)), logger='ModelNetFewShot')

    def __len__(self):
        """Return the number of samples in the loaded subset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return sample ``index`` as ``('ModelNet', 'sample', (points, label))``.

        ``points`` is a float32 tensor whose first three columns are centred
        and scaled by ``pc_normalize``; the remaining columns are kept only
        when ``USE_NORMALS`` is set. The point order is shuffled only for the
        ``'train'`` subset.
        """
        points, label, _ = self.dataset[index]
        # Copy first: normalising in place would overwrite the cached sample
        # inside self.dataset on every access (and fails on read-only arrays).
        points = points.copy()
        points[:, 0:3] = pc_normalize(points[:, 0:3])
        if not self.use_normals:
            points = points[:, 0:3]
        pt_idxs = np.arange(0, points.shape[0])  # original comment: 2048 points
        if self.subset == 'train':
            np.random.shuffle(pt_idxs)
        # Integer-array indexing already returns a fresh contiguous array.
        current_points = torch.from_numpy(points[pt_idxs]).float()
        return 'ModelNet', 'sample', (current_points, label)