forked from GuyTevet/motion-diffusion-model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtensors.py
88 lines (69 loc) · 2.9 KB
/
tensors.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
def lengths_to_mask(lengths, max_len):
    """Build a boolean validity mask from per-sample lengths.

    Entry ``[i, t]`` of the result is ``True`` exactly when
    ``t < lengths[i]``.

    Args:
        lengths: tensor of sequence lengths; expected to be 1-D
            (assumption -- not enforced here).
        max_len: number of positions (columns) in the mask. It is supplied
            by the caller rather than derived as ``max(lengths)``.

    Returns:
        Boolean tensor on ``lengths.device``; for 1-D ``lengths`` its shape
        is ``(len(lengths), max_len)``.
    """
    num_rows = len(lengths)
    # One row of position indices 0..max_len-1, viewed once per sample.
    positions = torch.arange(max_len, device=lengths.device)
    position_grid = positions.expand(num_rows, max_len)
    # Compare each row of positions against that sample's length.
    return position_grid < lengths.unsqueeze(1)
def collate_tensors(batch):
    """Stack tensors of differing sizes into one zero-padded tensor.

    Every tensor is written into the leading corner (index 0 along each
    axis) of its own slot; the remainder of the slot stays zero.

    Args:
        batch: non-empty sequence of tensors. The number of dimensions is
            taken from ``batch[0]``; the other tensors are assumed to have
            the same rank (not checked here).

    Returns:
        Tensor of shape ``(len(batch), *per_axis_max_sizes)`` created with
        ``batch[0].new_zeros``.
    """
    reference = batch[0]
    rank = reference.dim()

    # Largest extent along each axis across the whole batch.
    padded_shape = tuple(
        max(item.size(axis) for item in batch) for axis in range(rank)
    )
    output = reference.new_zeros(size=(len(batch),) + padded_shape)

    for slot, item in enumerate(batch):
        # Basic slicing yields a view, so the in-place add writes straight
        # into ``output``.
        corner = tuple(slice(0, item.size(axis)) for axis in range(rank))
        output[(slot,) + corner].add_(item)

    return output
def collate(batch):
    """Collate sample dicts into a padded motion tensor and a condition dict.

    ``None`` entries in ``batch`` are dropped. Which optional keys are
    collected is decided by looking at the first remaining sample only.

    Args:
        batch: sequence of dicts (or ``None``). Each dict must hold the
            tensor ``'inp'`` and may hold ``'lengths'``, ``'text'``,
            ``'tokens'``, ``'neg_text'``, ``'action'`` and
            ``'action_text'``.

    Returns:
        ``(motion, cond)`` where ``motion`` is the zero-padded stack of the
        ``'inp'`` tensors and ``cond`` is ``{'y': {...}}`` holding
        ``'mask'``, ``'lengths'`` and any optional keys that were present.
    """
    samples = [item for item in batch if item is not None]
    inputs = [item['inp'] for item in samples]
    first = samples[0]

    if 'lengths' in first:
        lengths = [item['lengths'] for item in samples]
    else:
        # No explicit length: fall back to the size of the innermost axis
        # reached by inp[0][0].
        lengths = [len(item['inp'][0][0]) for item in samples]

    motion = collate_tensors(inputs)
    lengths_tensor = torch.as_tensor(lengths)

    # Two singleton axes are inserted after the batch axis so the mask can
    # broadcast against the motion tensor (unsqueeze for broadcasting).
    mask = lengths_to_mask(lengths_tensor, motion.shape[-1])
    mask = mask.unsqueeze(1).unsqueeze(1)

    y = {'mask': mask, 'lengths': lengths_tensor}

    # Fields that are passed through as plain Python lists.
    for key in ('text', 'tokens', 'neg_text'):
        if key in first:
            y[key] = [item[key] for item in samples]

    if 'action' in first:
        actions = [item['action'] for item in samples]
        y['action'] = torch.as_tensor(actions).unsqueeze(1)

    # Collate action textual names.
    if 'action_text' in first:
        y['action_text'] = [item['action_text'] for item in samples]

    return motion, {'y': y}
# An adapter to our collate func. The commented-out code below is the superseded version; the active t2m_collate follows it.
# def t2m_collate(batch):
# # batch.sort(key=lambda x: x[3], reverse=True)
# adapted_batch = [{
# 'inp': torch.tensor(b[4].T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen]
# 'text': b[2], #b[0]['caption']
# 'tokens': b[6],
# 'lengths': b[5],
# } for b in batch]
# return collate(adapted_batch)
def t2m_collate(batch):
    """Adapt tuple-style samples to the dict format used by ``collate``.

    Args:
        batch: sequence of indexable samples. Index 2 is used as the
            caption, 4 as the motion array, 5 as the length and 6 as the
            tokens; index 7, when present, is used as the negative text.

    Returns:
        Whatever ``collate`` returns for the adapted samples.
    """

    def _to_sample_dict(sample):
        # Convert one tuple-style sample into collate's dict format.
        # [seqlen, J] -> [J, 1, seqlen]; layout assumed from the indexing
        # convention of the upstream dataset -- TODO confirm.
        motion = torch.tensor(sample[4].T).float().unsqueeze(1)
        entry = {
            'inp': motion,
            'text': sample[2],  # caption
            'tokens': sample[6],
            'lengths': sample[5],
        }
        # An eighth element, when present, carries the negative text.
        if len(sample) > 7:
            entry['neg_text'] = sample[7]
        return entry

    return collate([_to_sample_dict(sample) for sample in batch])