-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
executable file
·294 lines (244 loc) · 9.77 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import os
import re
import sys
import pickle
import random
import getpass
import argparse
import subprocess
import numpy as np
import torch
from .logger import create_logger
# Spellings accepted by ``bool_flag`` (the input is lower-cased before lookup).
TRUTHY_STRINGS = {'1', 'true', 'on'}
FALSY_STRINGS = {'0', 'false', 'off'}

# Default experiment root used when ``params.dump_path`` is empty.
# Note: evaluated once, at import time.
DUMP_PATH = '/checkpoint/{}/dumped'.format(getpass.getuser())

# Loss coefficients that may be given either as a constant ("3") or as a
# piecewise-linear schedule ("0:1,1000:0"); see parse_lambda_config.
DYNAMIC_COEFF = [
    'lambda_clm',
    'lambda_mlm',
    'lambda_pc',
    'lambda_ae',
    'lambda_mt',
    'lambda_bt',
    'lambda_mass',
    'lambda_span',
    'lambda_rat',
    'lambda_rabt',
    'lambda_xbt',
]
class AttrDict(dict):
    """
    Dictionary whose entries can also be accessed as attributes.

    The instance ``__dict__`` is the dictionary itself, so ``d.key`` and
    ``d['key']`` read and write the same storage.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self
def bool_flag(s):
    """
    Convert a command-line string into a boolean.

    The comparison is case-insensitive: values in FALSY_STRINGS give False,
    values in TRUTHY_STRINGS give True, and anything else raises
    ``argparse.ArgumentTypeError`` (so it can serve as an argparse ``type=``).
    """
    normalized = s.lower()
    if normalized in FALSY_STRINGS:
        return False
    if normalized in TRUTHY_STRINGS:
        return True
    raise argparse.ArgumentTypeError("Invalid value for a boolean flag!")
def initialize_exp(params):
    """
    Initialize the experiment:
        - resolve / create the dump directory (via ``get_dump_path``) and
          pickle ``params`` into ``<dump_path>/params.pkl``
        - store a re-runnable command line in ``params.command``
        - create a logger writing to ``<dump_path>/train.log``

    Returns:
        the logger created by ``create_logger``.
    """
    # dump parameters (get_dump_path updates params.dump_path / params.exp_id first)
    get_dump_path(params)
    # Use a context manager so the pickle is flushed and the handle closed
    # deterministically instead of being left to the garbage collector.
    with open(os.path.join(params.dump_path, 'params.pkl'), 'wb') as f:
        pickle.dump(params, f)

    # get running command; non-trivial positional arguments are single-quoted
    # so the recorded line can be pasted back into a shell
    command = ["python", sys.argv[0]]
    for x in sys.argv[1:]:
        if x.startswith('--'):
            assert '"' not in x and "'" not in x
            command.append(x)
        else:
            assert "'" not in x
            if re.match('^[a-zA-Z0-9_]+$', x):
                command.append(x)
            else:
                command.append("'%s'" % x)
    command = ' '.join(command)
    params.command = command + ' --exp_id "%s"' % params.exp_id

    # check experiment name
    assert len(params.exp_name.strip()) > 0

    # create a logger
    logger = create_logger(os.path.join(params.dump_path, 'train.log'), rank=getattr(params, 'global_rank', 0))
    logger.info("============ Initialized logger ============")
    logger.info("\n".join("%s: %s" % (k, str(v))
                          for k, v in sorted(dict(vars(params)).items())))
    logger.info("The experiment will be stored in %s\n" % params.dump_path)
    logger.info("Running command: %s" % command)
    logger.info("")
    return logger
def get_dump_path(params):
    """
    Create a directory to store the experiment.

    The directory is ``<root>/<exp_name>/<exp_id>``, where ``<root>`` is
    ``params.dump_path`` (or DUMP_PATH when that is empty). If
    ``params.exp_id`` is empty, it is taken from the CHRONOS_JOB_ID or
    SLURM_JOB_ID environment variable (which must be numeric), or else a
    random 10-character id not already used under the sweep directory.

    Side effects:
        sets ``params.exp_id`` (when it was empty) and ``params.dump_path``
        to the final experiment directory, creating it on disk.
    """
    dump_path = DUMP_PATH if params.dump_path == '' else params.dump_path
    assert len(params.exp_name) > 0

    # create the sweep path if it does not exist. os.makedirs replaces the old
    # `mkdir -p %s` shell call, which broke on paths with spaces and executed
    # any shell metacharacters in the path; it also raises instead of silently
    # failing (e.g. on permission errors).
    sweep_path = os.path.join(dump_path, params.exp_name)
    os.makedirs(sweep_path, exist_ok=True)

    # create an ID for the job if it is not given in the parameters.
    # if we run on the cluster, the job ID is the one of Chronos / SLURM.
    # otherwise, it is randomly generated
    if params.exp_id == '':
        chronos_job_id = os.environ.get('CHRONOS_JOB_ID')
        slurm_job_id = os.environ.get('SLURM_JOB_ID')
        assert chronos_job_id is None or slurm_job_id is None
        exp_id = chronos_job_id if chronos_job_id is not None else slurm_job_id
        if exp_id is None:
            chars = 'abcdefghijklmnopqrstuvwxyz0123456789'
            while True:
                exp_id = ''.join(random.choice(chars) for _ in range(10))
                if not os.path.isdir(os.path.join(sweep_path, exp_id)):
                    break
        else:
            assert exp_id.isdigit()
        params.exp_id = exp_id

    # create the dump folder / update parameters
    params.dump_path = os.path.join(sweep_path, params.exp_id)
    os.makedirs(params.dump_path, exist_ok=True)
def to_cuda(*args):
    """
    Move tensors to CUDA.

    Each argument is replaced by ``arg.cuda()``; ``None`` arguments are passed
    through unchanged. Returns a list with one entry per argument, in order.
    """
    moved = []
    for tensor in args:
        moved.append(tensor if tensor is None else tensor.cuda())
    return moved
def restore_segmentation(path):
    """
    Take a file segmented with BPE and restore it to its original segmentation.

    The file is rewritten in place with every ``"@@ "`` removed, as well as a
    trailing ``"@@"`` / ``"@@ "`` at the end of a line.
    """
    assert os.path.isfile(path)
    # Done in Python instead of `sed -i -r` through a shell: the old command
    # interpolated `path` unquoted (breaking on spaces and running any shell
    # metacharacters) and relied on GNU-only sed flags. Working on bytes with
    # MULTILINE `$` mirrors sed's byte-oriented, per-line substitution and
    # leaves the file's encoding untouched.
    with open(path, 'rb') as f:
        data = f.read()
    data = re.sub(rb'(@@ )|(@@ ?$)', b'', data, flags=re.MULTILINE)
    with open(path, 'wb') as f:
        f.write(data)
def parse_lambda_config(params):
    """
    Parse the configuration of lambda coefficient (for scheduling).

    For every name in DYNAMIC_COEFF, the string attribute on ``params`` is
    replaced by its initial float value, and ``<name>_config`` is set to the
    list of ``(iteration, value)`` breakpoints (or None for a constant).

        x = "3"                  # lambda will be a constant equal to x
        x = "0:1,1000:0"         # lambda will start from 1 and linearly decrease to 0 during the first 1000 iterations
        x = "0:0,1000:0,2000:1"  # lambda will be equal to 0 for the first 1000 iterations, then will linearly increase to 1 until iteration 2000
    """
    for name in DYNAMIC_COEFF:
        spec = getattr(params, name)
        pieces = spec.split(',')

        if len(pieces) == 1:
            # plain constant: no schedule
            setattr(params, name, float(spec))
            setattr(params, name + '_config', None)
            continue

        points = [piece.split(':') for piece in pieces]
        # every breakpoint is "iteration:value" with strictly increasing iterations
        assert all(len(point) == 2 for point in points)
        assert all(it.isdigit() for it, _ in points)
        assert all(int(prev[0]) < int(nxt[0]) for prev, nxt in zip(points, points[1:]))

        setattr(params, name, float(points[0][1]))
        setattr(params, name + '_config', [(int(it), float(val)) for it, val in points])
def get_lambda_value(config, n_iter):
    """
    Compute a lambda value according to its schedule configuration.

    ``config`` is a list of ``(iteration, value)`` breakpoints with increasing
    iterations. Between two breakpoints the value is linearly interpolated;
    from the last breakpoint onward it stays at the last value. An
    AssertionError is raised if ``n_iter`` precedes the first breakpoint.
    """
    segments = [
        (start, end)
        for start, end in zip(config, config[1:])
        if start[0] <= n_iter < end[0]
    ]
    if not segments:
        # past the final breakpoint: hold the last value
        assert n_iter >= config[-1][0]
        return config[-1][1]

    assert len(segments) == 1
    (x_a, y_a), (x_b, y_b) = segments[0]
    return y_a + (n_iter - x_a) * float(y_b - y_a) / float(x_b - x_a)
def update_lambdas(params, n_iter):
    """
    Update all lambda coefficients.

    Every coefficient in DYNAMIC_COEFF that has a schedule (a non-None
    ``<name>_config``) is set to its scheduled value at iteration ``n_iter``;
    constant coefficients are left alone.
    """
    for name in DYNAMIC_COEFF:
        schedule = getattr(params, name + '_config')
        if schedule is None:
            continue
        setattr(params, name, get_lambda_value(schedule, n_iter))
def _smoothed_probs(sizes, coeff):
    """Normalize ``sizes`` to a distribution, raise it to ``coeff`` and renormalize."""
    probs = np.array(sizes, dtype=np.float64)
    probs /= probs.sum()
    probs = probs ** coeff
    probs /= probs.sum()
    return probs


def set_sampling_probs(data, params):
    """
    Set the probability of sampling specific languages / language pairs during training.

    Does nothing when ``params.lg_sampling_factor == -1``. Otherwise (the
    factor must be > 0):
        - ``params.mono_list`` / ``params.para_list`` are set to the keys of
          ``data['mono_stream']`` / ``data['para']`` that have a ``'train'`` split;
        - for each non-empty list, ``params.mono_probs`` / ``params.para_probs``
          is set to the distribution of train-set sizes raised to the power
          ``lg_sampling_factor`` and renormalized (a factor < 1 flattens the
          distribution toward smaller datasets).
    """
    coeff = params.lg_sampling_factor
    if coeff == -1:
        return
    assert coeff > 0

    # monolingual data
    params.mono_list = [k for k, v in data['mono_stream'].items() if 'train' in v]
    if len(params.mono_list) > 0:
        params.mono_probs = _smoothed_probs(
            [len(data['mono_stream'][lang]['train']) for lang in params.mono_list], coeff)

    # parallel data
    params.para_list = [k for k, v in data['para'].items() if 'train' in v]
    if len(params.para_list) > 0:
        params.para_probs = _smoothed_probs(
            [len(data['para'][(l1, l2)]['train']) for (l1, l2) in params.para_list], coeff)
def concat_batches(x1, len1, lang1_id, x2, len2, lang2_id, pad_idx, eos_idx, reset_positions):
    """
    Concat batches with different languages.

    Column ``i`` of the output holds sentence ``i`` of ``x1`` followed by
    sentence ``i`` of ``x2``; rows past the combined length are ``pad_idx``.

    Args:
        x1: first batch, copied into the top ``len1.max()`` rows of the output,
            so presumably a ``(len1.max(), bs)`` tensor — NOTE(review): layout
            inferred from the ``copy_`` / ``x2[:len2[i], i]`` indexing, confirm.
        len1: 1-D tensor of per-sentence lengths of ``x1`` (``bs`` entries).
        lang1_id: language id written into ``langs`` for the first segment.
        x2: second batch, read column-wise as ``x2[:len2[i], i]``.
        len2: 1-D tensor of per-sentence lengths of ``x2``.
        lang2_id: language id written into ``langs`` from the start of the
            second segment to the bottom of the column (padding included).
        pad_idx: fill value for unused rows of ``x``.
        eos_idx: token id counted by the final sanity check.
        reset_positions: if True, the second segment starts right after the
            whole first sentence and its positions restart at 0; if False, the
            second segment overwrites the last token of the first sentence and
            positions keep increasing.

    Returns:
        (x, lengths, positions, langs): ``x``, ``positions`` and ``langs`` are
        ``(slen, bs)`` tensors; ``lengths`` is the combined length per column.
    """
    # resetting positions is only allowed when the two languages differ
    assert reset_positions is False or lang1_id != lang2_id
    lengths = len1 + len2
    if not reset_positions:
        # x2 overwrites the last token of x1 (see l1 below), so one token is shared
        lengths -= 1
    slen, bs = lengths.max().item(), lengths.size(0)

    x = x1.new(slen, bs).fill_(pad_idx)
    x[:len1.max().item()].copy_(x1)
    # positions 0..slen-1 in every column, built on CPU then moved to x1's device
    positions = torch.arange(slen)[:, None].repeat(1, bs).to(x1.device)
    langs = x1.new(slen, bs).fill_(lang1_id)

    for i in range(bs):
        # row at which the second sentence starts in column i
        l1 = len1[i] if reset_positions else len1[i] - 1
        x[l1:l1 + len2[i], i].copy_(x2[:len2[i], i])
        if reset_positions:
            # shift so the second segment's first position is 0
            positions[l1:, i] -= len1[i]
        langs[l1:, i] = lang2_id

    # sanity check: 4 eos per column with reset, 3 without — presumably each
    # sentence is eos-delimited at both ends and one eos is overwritten when
    # positions are not reset. NOTE(review): confirm against the data loader.
    assert (x == eos_idx).long().sum().item() == (4 if reset_positions else 3) * bs

    return x, lengths, positions, langs
def truncate(x, lengths, max_len, eos_index):
    """
    Truncate long sentences.

    If no entry of ``lengths`` exceeds ``max_len``, the inputs are returned
    unchanged (same objects). Otherwise copies are returned: ``x`` cut to its
    first ``max_len`` rows, and for every column whose length exceeded
    ``max_len`` the length is clamped and the last kept row set to
    ``eos_index``.
    """
    # fast path: nothing to cut, hand back the original tensors
    if not lengths.max().item() > max_len:
        return x, lengths

    x = x[:max_len].clone()
    lengths = lengths.clone()
    for col in range(len(lengths)):
        if not lengths[col] > max_len:
            continue
        lengths[col] = max_len
        x[max_len - 1, col] = eos_index
    return x, lengths
def shuf_order(langs, params=None, n=5):
    """
    Randomize training order.

    Without ``params``: return a random permutation of ``langs``.

    With ``params``: ``langs`` holds ``(lang1, lang2)`` pairs, where
    ``lang2 is None`` marks a monolingual entry. Up to ``n`` monolingual and up
    to ``n`` parallel entries are drawn with replacement. Sampling is uniform
    when ``params.lg_sampling_factor == -1``; otherwise it is weighted by
    ``params.mono_probs`` / ``params.para_probs``. Monolingual draws come first,
    as ``(lang, None)`` tuples.
    """
    if not len(langs):
        return []

    if params is None:
        order = np.random.permutation(len(langs))
        return [langs[idx] for idx in order]

    # split monolingual and parallel entries so they are sampled separately
    mono, para = [], []
    for src, tgt in langs:
        if tgt is None:
            mono.append(src)
        else:
            para.append((src, tgt))

    if params.lg_sampling_factor == -1:
        # uniform sampling
        p_mono = p_para = None
    else:
        # weighted sampling; parallel pairs are looked up under their sorted key
        p_mono = np.array([params.mono_probs[params.mono_list.index(lang)] for lang in mono])
        p_para = np.array([params.para_probs[params.para_list.index(tuple(sorted(pair)))] for pair in para])
        p_mono = p_mono / p_mono.sum()
        p_para = p_para / p_para.sum()

    def draw(pool, probs):
        # min(n, len(pool)) draws with replacement; an empty pool yields nothing
        if not pool:
            return []
        picks = np.random.choice(len(pool), size=min(n, len(pool)), p=probs, replace=True)
        return [pool[j] for j in picks]

    s_mono = draw(mono, p_mono)
    s_para = draw(para, p_para)

    assert len(s_mono) + len(s_para) > 0
    return [(lang, None) for lang in s_mono] + s_para
def find_modules(module, module_name, module_instance, found):
    """
    Recursively find all instances of a specific module inside a module.

    Matches are appended to ``found`` as ``(qualified_name, module)`` in
    depth-first order. Numeric child names are rendered as ``parent[i]`` and
    other names as ``parent.name``. The search does not descend into a module
    that already matches.
    """
    if not isinstance(module, module_instance):
        for child_name, child in module.named_children():
            template = '%s[%s]' if child_name.isdigit() else '%s.%s'
            find_modules(child, template % (module_name, child_name), module_instance, found)
        return
    found.append((module_name, module))