-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathflags.py
83 lines (63 loc) · 2.38 KB
/
flags.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import json
import sys
def read_flags(fn):
    """Load and return the JSON document stored at path `fn`.

    The file is decoded as UTF-8 explicitly, because JSON is conventionally
    UTF-8. This keeps the result independent of the platform's locale
    encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(fn, encoding='utf-8') as f:
        # json.load streams from the file object; no need to read() first.
        return json.load(f)
def override_with_flags(options, flags, flags_to_use=None):
    """Copy values from the `flags` mapping onto `options` as attributes.

    If `flags_to_use` is None, every key in `flags` is applied.

    Otherwise only the names listed in `flags_to_use` are set. Each one takes
    its value from `flags` when present. If absent, it falls back to the
    instance attribute already stored in `options.__dict__`, or None when
    there is none. A listed name missing from both places therefore ends up
    as an attribute set to None.

    Returns the same `options` object, mutated in place.
    """
    if flags_to_use is not None:
        for name in flags_to_use:
            fallback = options.__dict__.get(name, None)
            setattr(options, name, flags.get(name, fallback))
        return options

    for name, value in flags.items():
        setattr(options, name, value)
    return options
def init_with_flags_file(options, flags_file, flags_to_use=None):
    """Read the JSON flags file at `flags_file` and apply it to `options`.

    Loads the file with `read_flags`, then delegates to `override_with_flags`
    with the same `flags_to_use` selection (None means apply every flag).
    Returns the updated `options`.
    """
    return override_with_flags(options, read_flags(flags_file), flags_to_use)
def init_boolean_flags(options, other_args):
    """Deprecated: always raises.

    This function used to do two things:
    - add a 'no'-prefixed twin attribute for every boolean attribute on
      `options`;
    - apply '--<flag>' / '--no<flag>' switches found in `other_args`.

    All of that code sat after the unconditional raise, so it could never
    execute. It has been removed; only the raise remains.

    Raises:
        Exception: always, with the message 'Deprecated.'.
    """
    raise Exception('Deprecated.')
def stringify_flags(options):
    """Return all instance attributes of `options` as a pretty JSON string.

    Keys are sorted and indented by 4 spaces, so the same options always
    produce the same text.

    No attributes are filtered out: any 'no'-prefixed boolean flags are
    serialized like every other value.

    Raises:
        TypeError: if an attribute value is not JSON-serializable.
    """
    # json.dumps only reads the mapping, so no defensive copy is needed.
    return json.dumps(options.__dict__, indent=4, sort_keys=True)
def save_flags(options, experiment_path):
    """Write the JSON dump of `options` (from `stringify_flags`) to disk.

    The file goes into the directory `experiment_path`. It is named
    'eval_flags.json' when `options.eval_only_mode` is truthy and
    'flags.json' otherwise, so the two kinds of run write separate files.

    Raises:
        OSError: if the target file cannot be written.
    """
    flags = stringify_flags(options)
    # Options objects without an `eval_only_mode` attribute are treated as
    # "not eval-only" instead of raising AttributeError after serializing.
    eval_only = getattr(options, 'eval_only_mode', False)
    filename = 'eval_flags.json' if eval_only else 'flags.json'
    target_file = os.path.join(experiment_path, filename)
    with open(target_file, 'w', encoding='utf-8') as f:
        f.write(flags)