-
Notifications
You must be signed in to change notification settings - Fork 8
/
double_samplers.py
72 lines (61 loc) · 2.79 KB
/
double_samplers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from __future__ import absolute_import
from __future__ import division
from collections import defaultdict
import numpy as np
import copy
import random
import torch
from torch.utils.data.sampler import Sampler
class RandomIdentitySampler(Sampler):
    """Sample indices so each batch holds N identities with K instances each.

    Randomly sample N identities, then for each identity randomly sample
    K instances; the batch size is therefore N * K.

    Args:
        data_source (Dataset): dataset to sample from. Every item must unpack
            into five fields, the third of which is the identity label (pid).
        batch_size (int): number of examples in a batch.
        num_instances (int): number of instances per identity in a batch.

    Raises:
        ValueError: if ``num_instances`` is smaller than 1, or if
            ``batch_size`` is smaller than ``num_instances`` (no identity
            would fit into a batch).
    """

    def __init__(self, data_source, batch_size, num_instances):
        if num_instances < 1:
            raise ValueError(
                "num_instances must be >= 1, got {}".format(num_instances)
            )
        if batch_size < num_instances:
            # Otherwise num_pids_per_batch would be 0 and __iter__ would spin
            # forever: every round would select zero identities.
            raise ValueError(
                "batch_size ({}) must be >= num_instances ({})".format(
                    batch_size, num_instances
                )
            )

        self.data_source = data_source
        self.batch_size = batch_size          # total number of samples per batch
        self.num_instances = num_instances    # samples drawn per identity
        # Number of distinct identities in one batch.
        self.num_pids_per_batch = self.batch_size // self.num_instances

        # Map every identity label to the dataset indices that carry it.
        self.index_dic = defaultdict(list)
        for index, (_, _, pid, _, _) in enumerate(self.data_source):
            self.index_dic[pid].append(index)
        self.pids = list(self.index_dic.keys())

        # Estimated number of indices per epoch: each identity contributes its
        # sample count rounded down to a multiple of num_instances, and
        # identities with too few samples are padded up to num_instances.
        self.length = 0
        for pid in self.pids:
            num = max(len(self.index_dic[pid]), self.num_instances)
            self.length += num - num % self.num_instances

    def __iter__(self):
        """Return an iterator over the sampled indices for one epoch.

        Each identity's indices are shuffled and cut into chunks of
        ``num_instances``. Batches are then assembled by repeatedly drawing
        ``num_pids_per_batch`` identities at random and taking one chunk from
        each. Sampling stops as soon as fewer than ``num_pids_per_batch``
        identities still have unused chunks; the remaining chunks are dropped.
        """
        k = self.num_instances

        chunks_per_pid = defaultdict(list)
        for pid in self.pids:
            idxs = list(self.index_dic[pid])
            if len(idxs) < k:
                # Too few samples for this identity: resample with
                # replacement, then convert back to plain Python ints.
                idxs = np.random.choice(idxs, size=k, replace=True).tolist()
            random.shuffle(idxs)
            # Keep only complete chunks of k indices; the tail is discarded.
            usable = len(idxs) - len(idxs) % k
            for start in range(0, usable, k):
                chunks_per_pid[pid].append(idxs[start:start + k])

        avai_pids = list(self.pids)
        final_idxs = []
        while len(avai_pids) >= self.num_pids_per_batch:
            for pid in random.sample(avai_pids, self.num_pids_per_batch):
                chunks = chunks_per_pid[pid]
                final_idxs.extend(chunks.pop(0))
                if not chunks:
                    # This identity is exhausted for the current epoch.
                    avai_pids.remove(pid)

        return iter(final_idxs)

    def __len__(self):
        """Return the estimated number of indices produced per epoch.

        This is an upper bound: ``__iter__`` may yield fewer indices because
        leftover chunks are dropped once too few identities remain.
        """
        return self.length