forked from YeWR/EfficientZero
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtest_neurons.py
142 lines (123 loc) · 7.19 KB
/
test_neurons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
import argparse
import os
import random
from config.atari import game_config
from config.atari.heatmap_model import EfficientZeroNetH
from core.test_write_trace import test_write_trace
from core.game import GameHistory
def one_hot(n_neurons, n, device):
    """Return a 1-D one-hot tensor of length ``n_neurons`` with a 1.0 at index ``n``.

    Args:
        n_neurons: length of the returned vector.
        n: index to set to 1.0 (all other entries are 0.0).
        device: torch device (e.g. 'cpu' or 'cuda') the tensor is created on.

    Returns:
        torch.Tensor of shape (n_neurons,), dtype float32.
    """
    # Allocate directly on the target device instead of zeros().to(device),
    # which would allocate on CPU and then copy.
    vec = torch.zeros(n_neurons, device=device)
    vec[n] = 1.0
    return vec
#test input: python test_neurons.py --env MsPacmanNoFrameskip-v4 --model_path ./model/model.p --sae_paths ./sae/test_sae.p --random_features=5
if __name__ == '__main__':
    # --- Gather arguments related to testing ---
    parser = argparse.ArgumentParser(description='EfficientZero')
    parser.add_argument('--env', required=True, help='Name of the gym environment')
    parser.add_argument('--test_episodes', type=int, default=1, help='Evaluation episode count (default: %(default)s)')
    parser.add_argument('--model_path', type=str, default='./results/model.p', help='load model path')
    parser.add_argument('--sae_paths', nargs='+', type=str, default=['./results/sae.p'], help='load some autoencoder paths')
    parser.add_argument('--sae_layers', nargs='+', type=str, default=['p6'], help='layers for the loaded autoencoders. Currently recommended: p3, p6, q0, q3, q4, v3, v4')
    parser.add_argument('--results_path', type=str, default='./results/test_neurons', help='save clips directory')
    parser.add_argument('--device', type=str, default='cpu', help='cpu or cuda')
    parser.add_argument('--save_neurons', action='store_true', help='Flag: Attribute neurons if no autoencoder at file')
    parser.add_argument('--random_features', type=int, default=0, help='Number of random features to attribute (default: %(default)s)')
    parser.add_argument('--features', nargs='+', type=int, default=[], help='Specific features to attribute if possible')
    parser.add_argument('--feature_source', type=str, default='encoder', help='pick encoder or decoder - only different if weights are untied')
    args = parser.parse_args()

    # Validate inputs explicitly (assert statements are stripped under `python -O`).
    if not os.path.exists(args.model_path):
        raise FileNotFoundError('model not found at {}'.format(args.model_path))
    # One layer name is required per autoencoder path; zip below would silently
    # truncate a mismatch otherwise.
    if len(args.sae_paths) != len(args.sae_layers):
        raise ValueError('--sae_paths and --sae_layers must have the same length '
                         '({} vs {})'.format(len(args.sae_paths), len(args.sae_layers)))
    for sae_path in args.sae_paths:
        if args.save_neurons:
            # With --save_neurons a missing file is allowed: one-hot "neuron"
            # features are constructed instead (see below).
            if not os.path.exists(sae_path):
                print('model not found at {}, saving neurons'.format(sae_path))
        elif not os.path.exists(sae_path):
            raise FileNotFoundError('autoencoder not found at {}'.format(sae_path))
    device = args.device
    print(args)

    # --- Set configs ---
    # Anonymous namespace object mimicking the arg object game_config expects.
    implied_args = type('', (object,), {"env": args.env,
                                        "case": "atari",
                                        "opr": "test",
                                        "amp_type": "torch_amp",
                                        "render": False,
                                        "seed": 0,
                                        "use_priority": False,
                                        "use_max_priority": False,
                                        "debug": False,
                                        "device": device,
                                        "cpu_actor": 4,
                                        "gpu_actor": 4,
                                        "p_mcts_num": 4,
                                        "use_root_value": False,
                                        "use_augmentation": False,
                                        "revisit_policy_search_rate": 0.99,
                                        "result_dir": args.results_path,
                                        "info": "none"
                                        })()
    exp_path = game_config.set_config(implied_args)

    # --- Initialize model ---
    # Custom model variant that exposes internal activations and a heatmap method.
    model = EfficientZeroNetH(
        game_config.obs_shape,
        game_config.action_space_size,
        game_config.blocks,
        game_config.channels,
        game_config.reduced_channels_reward,
        game_config.reduced_channels_value,
        game_config.reduced_channels_policy,
        game_config.resnet_fc_reward_layers,
        game_config.resnet_fc_value_layers,
        game_config.resnet_fc_policy_layers,
        game_config.reward_support.size,
        game_config.value_support.size,
        game_config.downsample,
        game_config.inverse_value_transform,
        game_config.inverse_reward_transform,
        game_config.lstm_hidden_size,
        bn_mt=game_config.bn_mt,
        proj_hid=game_config.proj_hid,
        proj_out=game_config.proj_out,
        pred_hid=game_config.pred_hid,
        pred_out=game_config.pred_out,
        init_zero=game_config.init_zero,
        state_norm=game_config.state_norm).to(device)
    # Load model weights (strict=False: the heatmap variant may add parameters
    # not present in the trained checkpoint).
    model.load_state_dict(torch.load(args.model_path, map_location=torch.device(device)), strict=False)

    # --- Assemble the set of feature numbers to test ---
    # TODO: pick a max number based on actual size of layers rather than a hardcoded constant?
    n_features = 1024
    feature_nums = set()
    # Clamp the sample size so an oversized --random_features does not raise.
    feature_nums.update(random.sample(range(n_features), min(args.random_features, n_features)))
    feature_nums.update(args.features)

    # autoencoders: set of autoencoder names - the layer (e.g. p6) followed by
    # the alphanumeric characters of the file path.
    autoencoders = set()
    # ae_features maps autoencoder name -> {'offset': bias tensor, feature_num: (weight, bias)}.
    ae_features = {}
    for sae_layer, sae_path in zip(args.sae_layers, args.sae_paths):
        ae_name = sae_layer + ''.join(filter(str.isalnum, sae_path))
        autoencoders.add(ae_name)
        ae_features[ae_name] = {}
        if os.path.exists(sae_path):
            # BUGFIX: load the current autoencoder's path, not the stale loop
            # variable `p` left over from the validation loop above (which
            # always pointed at the last entry of --sae_paths).
            sd = torch.load(sae_path, map_location=torch.device(device))
            ae_features[ae_name]['offset'] = sd['_post_decoder_bias._bias_reference']
            # Extract the requested features from the state dict.
            # TODO: test compatibility with convolutional layers
            for m in feature_nums:
                if m < sd['_encoder._weight'].size()[0]:
                    if args.feature_source == 'decoder':
                        # Decoder columns carry no per-feature bias.
                        ae_features[ae_name][m] = (sd['_decoder._weight'][:, m], 0)
                    elif args.feature_source == 'encoder':
                        ae_features[ae_name][m] = (sd['_encoder._weight'][m, :], sd['_encoder._bias'][m])
                    else:
                        raise ValueError('Pick encoder or decoder')
        else:
            # No autoencoder on disk: construct one-hot features at the indices
            # given by the feature numbers (raw-neuron attribution).
            # TODO: check compatibility with convolutional layers
            ae_features[ae_name]['offset'] = torch.zeros(model.features(sae_layer)).to(device)
            for m in feature_nums:
                ae_features[ae_name][m] = (one_hot(model.features(sae_layer), m, device), 0)

    # --- Run the trace/attribution episodes and write results ---
    test_write_trace(game_config, model, args.test_episodes, device, autoencoders, ae_features, feature_nums, use_pb=True)