forked from MurpheyLab/MaxDiffRL
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
80 lines (73 loc) · 2.57 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#!/usr/bin/env python3
import numpy as np
import datetime
import time
from termcolor import cprint
import yaml
import torch
import os
import torch
from typing import List
def get_duration(start_time, print_update=True):
    """Return the time elapsed since ``start_time`` as a string.

    ``start_time`` is subtracted from the current ``time.time()`` value and
    the difference (in seconds) is rendered through ``datetime.timedelta``,
    e.g. ``'0:01:23.456789'``.

    Args:
        start_time: reference timestamp, on the same clock as ``time.time()``.
        print_update: when truthy, also print ``'runtime: <duration>'`` in
            magenta via ``cprint``.

    Returns:
        The formatted duration string.
    """
    elapsed_seconds = time.time() - start_time
    runtime = str(datetime.timedelta(seconds=elapsed_seconds))
    if print_update:
        cprint('runtime: {}'.format(runtime), 'magenta')
    return runtime
def _batch_mv(mat: torch.Tensor, vec: torch.Tensor):
    '''Batched matrix-vector product (assumes batch dim=0).

    ``vec`` receives a trailing singleton axis so that it is multiplied by
    ``mat`` as a column; that singleton axis is removed from the result
    before it is returned.
    '''
    column = vec[..., None]
    product = torch.matmul(mat, column)
    return product.squeeze(-1)
def jit_prob_dist(
    prob_dist: torch.distributions.distribution.Distribution,
) -> torch.nn.Module:
    """Wrap ``prob_dist`` in a ``torch.nn.Module`` instance.

    The returned module exposes ``single_sample``, ``sample`` and
    ``log_prob``; each simply forwards to the captured distribution and is
    decorated with ``@torch.jit.ignore``. The module's ``extra_repr`` is the
    distribution's own repr.

    Args:
        prob_dist: distribution object to delegate to.

    Returns:
        A fresh ``JITProbDist`` module bound to ``prob_dist``.
    """
    wrapped = prob_dist

    class JITProbDist(torch.nn.Module):
        """Module facade over the distribution captured by the factory call."""

        def extra_repr(self):
            # Displayed inside the module's repr.
            return repr(wrapped)

        @torch.jit.ignore
        def log_prob(self, x):
            return wrapped.log_prob(x)

        @torch.jit.ignore
        def sample(self, x: List[int]):
            # ``x`` is passed straight through as the argument of sample().
            return wrapped.sample(x)

        @torch.jit.ignore
        def single_sample(self):
            # Draw with the distribution's default sample arguments.
            return wrapped.sample()

    return JITProbDist()
def obs_to_np(time_step, subset=True):
    """Flatten the observations of ``time_step`` into one numpy array.

    Args:
        time_step: object whose ``observation`` attribute is a mapping from
            observation name to array-like value.
        subset: when truthy, only the fixed keys listed below are used, in
            that order; otherwise every value of the mapping is used in the
            mapping's own iteration order.

    Returns:
        The selected observations joined with ``np.hstack``.
    """
    observation = time_step.observation
    if not subset:
        return np.hstack(list(observation.values()))

    # 'egocentric_state' is deliberately last, after the other six entries.
    selected_keys = (
        'origin',
        'torso_velocity',
        'torso_upright',
        'imu',
        'force_torque',
        'rangefinder',
        'egocentric_state',
    )
    pieces = [observation[name] for name in selected_keys]
    return np.hstack(pieces)
def save_config(args, config, env_name):
    """Create the run directory and record the experiment settings in it.

    The run directory is built by string concatenation as
    ``<args.base_dir><args.method>/<env_name><mod><config['name_mod']>/seed_<args.seed>/``
    where ``<mod>`` is ``'_beta'`` followed by ``args.beta`` in ``'{:0.0e}'``
    notation (with ``'+'`` removed and ``'-'`` replaced by ``'_'``) when
    ``args.pointmass`` is truthy, and empty otherwise.

    Files written:
        * ``config.yaml`` in the run directory: ``config`` dumped with
          ``yaml.safe_dump`` (overwritten on every call).
        * ``config.txt`` in the parent of the run directory: a timestamped
          entry with ``args``, every ``config`` item, and torch's parallel
          info is appended.
        * ``rewards.txt`` in the run directory: a tab-separated header row
          is appended.

    Args:
        args: object providing ``seed``, ``pointmass``, ``base_dir``,
            ``method`` and (when ``pointmass`` is truthy) ``beta``.
        config: mapping with at least a ``'name_mod'`` entry.
        env_name: environment name used as part of the directory name.

    Returns:
        The run directory path as a string (it ends with ``'/'``).
    """
    date_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S/")
    dir_name = 'seed_{}/'.format(str(args.seed))
    if args.pointmass:
        mod = '_beta' + '{:0.0e}'.format(args.beta).replace('+', '').replace('-', '_')
    else:
        mod = ''
    path = (args.base_dir + args.method + '/' + env_name + mod
            + config['name_mod'] + '/' + dir_name)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(path, exist_ok=True)

    # save config yaml
    with open(os.path.join(path, 'config.yaml'), 'w') as f:
        yaml.safe_dump(config, f)

    # Show where this run is being stored.
    cprint(path, 'red')

    # Append a human-readable record next to the seed directories (one level
    # above the run directory), so every seed/run adds an entry to the same log.
    with open(os.path.join(path, os.pardir, 'config.txt'), 'a') as f:
        f.write('\n' + date_str)
        f.write('\nArgparse\n')
        f.write('\t' + str(args))
        f.write('\nConfig\n')
        for key, value in config.items():
            f.write('\t' + str(key) + '\t' + str(value) + '\n')
        # Heading was previously written as the literal text '\Torch nConfig'
        # because of a misplaced backslash.
        f.write('\nTorch Config\n')
        f.write(torch.__config__.parallel_info())
        # No explicit close(): the with-block closes the file.

    # Column header for the per-episode reward log.
    with open(os.path.join(path, 'rewards.txt'), 'a') as f:
        f.write('{}\t{}\t{}\t{}\n'.format('ep_num', 'episode_reward', 'step', 'ep_time'))
    return path