-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
WandBUtils.py
122 lines (100 loc) · 3.13 KB
/
WandBUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import wandb
import tempfile
import yaml
from functools import lru_cache
class CWBRun:
    """Convenience wrapper around one run of the W&B public API.

    Exposes the run's config (with W&B's ``desc``/``value`` wrappers removed),
    its logged ``.h5`` artifacts and the lowest ``val_loss`` in its history.
    The config download, the history scan and the best loss are cached per
    instance.
    """

    def __init__(self, runId, api=None, tmpFolder=None):
        """Fetch the run identified by ``runId``.

        :param runId: W&B run path passed to ``api.run`` (for example
            ``"<entity>/<project>/<run id>"``).
        :param api: ``wandb.Api`` instance to reuse; a new one is created when
            omitted.
        :param tmpFolder: directory used for downloaded files; defaults to
            ``tempfile.gettempdir()``.
        """
        self._runId = runId
        self._api = api or wandb.Api()
        self._run = self._api.run(runId)
        self._tmpFolder = tmpFolder or tempfile.gettempdir()

    def _memo(self, key, compute):
        # Per-instance memoisation. A class-level lru_cache(maxsize=1) on a
        # method only remembers the most recently used instance, so looping
        # over several runs re-downloads / re-scans on every access, and it
        # keeps that instance alive. Exceptions raised by ``compute`` are not
        # cached, so a failed call is retried next time.
        cache = self.__dict__.setdefault('_memoCache', {})
        if key not in cache:
            cache[key] = compute()
        return cache[key]

    @property
    def config(self):
        """The run's ``config.yaml`` as a plain dict, downloaded once per instance.

        Entries wrapped by W&B as ``{'desc': ..., 'value': ...}`` are replaced by
        their ``value``, recursively.
        """
        return self._memo('config', self._loadConfig)

    def _loadConfig(self):
        # Download "config.yaml" from the run's files and parse it. Every run
        # downloads to the same path inside tmpFolder, hence replace=True.
        remote = self._run.file('config.yaml')
        with remote.download(self._tmpFolder, replace=True, exist_ok=True) as f:
            raw = yaml.safe_load(f)
        return CWBRun._unwrapConfig(raw)

    @staticmethod
    def _unwrapConfig(value):
        # Replace {'desc': ..., 'value': ...} wrappers by their 'value' and recurse
        # into plain dicts (updated in place); non-dict values are returned as-is.
        if not isinstance(value, dict):
            return value
        if 'desc' in value and 'value' in value:
            return CWBRun._unwrapConfig(value['value'])
        for key, val in value.items():
            value[key] = CWBRun._unwrapConfig(val)
        return value

    def models(self):
        """Return the run's logged artifacts whose name ends in ``.h5`` (case-insensitive)."""
        wrapped = (CWBFileArtifact(raw, self._tmpFolder, self._run)
                   for raw in self._run.logged_artifacts())
        return [artifact for artifact in wrapped if artifact.name.lower().endswith('.h5')]

    @property
    def bestModel(self):
        """The artifact named ``model-latest.h5`` among ``models()``.

        Note: selection is by that fixed name, not by any loss value.

        :raises StopIteration: if the run has no such artifact.
        """
        return next(m for m in self.models() if m.name == 'model-latest.h5')

    @property
    def bestLoss(self):
        """Lowest ``val_loss`` in the run's history, computed once per instance.

        Rows without a usable ``val_loss`` (key missing, ``None``, non-numeric or
        NaN) are skipped, so metrics logged at other steps do not mask it.
        ``float('inf')`` is returned when no usable value exists or the history
        cannot be read (best effort), so such runs compare as worst.
        """
        return self._memo('bestLoss', self._computeBestLoss)

    def _computeBestLoss(self):
        import math
        try:
            values = [row.get('val_loss') for row in self.history()]
        except Exception:  # deliberate best effort: unreadable history == no history
            return float('inf')
        usable = [v for v in values
                  if isinstance(v, (int, float)) and not math.isnan(v)]
        return min(usable, default=float('inf'))

    def history(self):
        """The run's ``scan_history()`` result, created once per instance."""
        return self._memo('history', self._run.scan_history)

    @property
    def name(self):
        """Name of the underlying W&B run."""
        return self._run.name

    @property
    def id(self):
        """Id of the underlying W&B run."""
        return self._run.id

    @property
    def fullId(self):
        """The ``runId`` this object was constructed with."""
        return self._runId
# End of CWBRun
class CWBFileArtifact:
    """Wrapper for an artifact logged by a W&B run (used here for ``.h5`` model files)."""

    def __init__(self, artifact, tmpFolder, run):
        """
        :param artifact: artifact object as yielded by ``run.logged_artifacts()``.
        :param tmpFolder: directory that ``pathTo`` downloads into.
        :param run: the W&B run that logged ``artifact``.
        """
        self._artifact = artifact
        self._tmpFolder = tmpFolder
        self._run = run

    def pathTo(self):
        """Download the artifact's file into ``tmpFolder`` and return the local path.

        Downloads directly from the artifact. Calling ``run.use_artifact`` first
        would record the artifact as an *input* of the very run that produced
        it (a lineage write that also needs write access to the run), which is
        not needed for the download. Presumably only single-file artifacts are
        supported here (that is how W&B documents ``Artifact.file``) - confirm.
        """
        return self._artifact.file(self._tmpFolder)

    @property
    def name(self):
        """Artifact name without the ``run-{run id}-`` prefix and ``:{version}`` suffix.

        E.g. ``"run-ab12cd-model-latest.h5:v3"`` -> ``"model-latest.h5"``.
        """
        base = self._artifact.name.split(':')[0]  # drop ":{version}"
        runId = getattr(self._run, 'id', None)
        if runId is not None:
            prefix = 'run-' + str(runId) + '-'
            if base.startswith(prefix):
                # Exact prefix strip: still correct when the run id contains '-'.
                return base[len(prefix):]
        # Fallback: drop the first two '-'-separated parts ("run" and the id).
        return '-'.join(base.split('-')[2:])
# End of CWBFileArtifact
class CWBProject:
    """Access to the runs of one W&B project, wrapped as ``CWBRun`` objects."""

    def __init__(self, projectId, api=None, tmpFolder=None):
        """
        :param projectId: project path passed to ``api.runs`` (for example
            ``"<entity>/<project>"``).
        :param api: ``wandb.Api`` instance to reuse; a new one is created when
            omitted.
        :param tmpFolder: directory for downloaded files, handed to every
            ``CWBRun``; defaults to ``tempfile.gettempdir()``.
        """
        self._projectId = projectId
        self._api = api or wandb.Api()
        self._tmpFolder = tmpFolder or tempfile.gettempdir()

    def runs(self, filters=None):
        """Return a ``CWBRun`` for every project run matching ``filters``.

        :param filters: optional W&B run filter dict, forwarded to ``api.runs``.
        """
        found = self._api.runs(self._projectId, filters=filters)
        # NOTE: each CWBRun re-fetches its run via api.run(...) in its constructor.
        return [CWBRun(self._projectId + '/' + run.id, self._api, self._tmpFolder)
                for run in found]

    def groups(self, onlyBest=False, filters=None):
        """Group the project's runs by run name.

        :param onlyBest: if true, keep only the run with the lowest ``bestLoss``
            for each name (the first one wins on ties).
        :param filters: optional W&B run filter dict, forwarded to ``runs``.
        :return: ``{name: [run, ...]}``, or ``{name: run}`` when ``onlyBest``.
        """
        grouped = {}
        for run in self.runs(filters=filters):
            grouped.setdefault(run.name, []).append(run)
        if onlyBest:
            return {name: min(members, key=lambda r: r.bestLoss)
                    for name, members in grouped.items()}
        return grouped
# End of CWBProject