Skip to content

Commit

Permalink
V-prop is done with shift funciton now - works!!
Browse files Browse the repository at this point in the history
this version works with shift function

- not very fast, need more batching probably
  • Loading branch information
adamasb committed Dec 1, 2022
1 parent 325a840 commit 79516ba
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 113 deletions.
Binary file modified src/mazeenv/__pycache__/maze_register.cpython-39.pyc
Binary file not shown.
Binary file modified src/raya3c/__pycache__/my_callback.cpython-39.pyc
Binary file not shown.
159 changes: 136 additions & 23 deletions src/raya3c/example_vin.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,25 @@ def VP_nn(self,s,phi,K=10):

#can also be defined based on phi
h, w = phi[:, :, 0].shape[0], phi[:, :, 0].shape[1] #height and width of map



#consider initializing this as a minimum of something, so that its not just 0s. Force it to calculate some gradients
v = torch.zeros((h,w,K+1)) # wanna pad or roll over this, i think

""" IF I WANT TO TRY TO CREATE A SHIFT FUNCTION
for tt in range(self.K)
vm = []
vm.append(v)
# v_pad = pad v (dont pad with 0, but veeeery negative number)
for ax, shift in [ (1,-1), (1,1), (2,-1), (2,1)]:
v_shift = torch.roll(v_pad, dims=ax-0,shifts=shift)[:,1:-1]
r_shift = torch.roll(r_in_pad, dims=ax-0,shifts=shift)[:,1:-1]
vn.append(p[:,:,:] * v_shift + r_shift -rout[:,:,:])
v,_ = torch.stack(vm).max(dim=0)
"""


for k in range(K): #number of "convolutions", or times the algorithm is applied
for i in range(h):
for j in range(w):
Expand Down Expand Up @@ -177,7 +193,48 @@ def VP_nn(self,s,phi,K=10):
return dim4
# return vp

def VP_new(self,s,phi,K=10):

try:
p,rin,rout = torch.unsqueeze(phi[:,:,0],dim=2), torch.unsqueeze(phi[:,:,1],dim=2), torch.unsqueeze(phi[:,:,2],dim=2)

v = torch.zeros((p.shape[0],p.shape[1],1)) # wanna pad or roll over this, i think

r_in_pad = torch.nn.functional.pad(rin, (1,1) + (1,1), 'constant', 0)


""" Tues code starts here"""
for tt in range(K):
vm = []
if tt > 0:
vm.append(v)

v_pad = torch.nn.functional.pad(v, (1,1) + (1,1), 'constant', 0)
for ax, shift in [ (1, -1), (1, 1), (2, -1), (2, 1)]:
# torch.pad(v, )
v_shift = torch.roll(v_pad, dims=ax-0, shifts=shift)[:,1:-1, 1:-1]
r_shift = torch.roll(r_in_pad, dims=ax-0, shifts=shift)[:,1:-1, 1:-1]
vm.append(p[:,:,:] * v_shift + r_shift - rout[:,:,:] )


self.debug_vin = False
v, _ = torch.stack(vm).max(axis=0)
if self.debug_vin:
# diff = np.abs( V_db[:,:,:,tt+1] - v.detach().numpy() ).sum() #only used for comparing with the "bad VP"
diff = np.abs( self.VP_simple(s) - v.detach().numpy() ).sum() #this isnt actually worth anything yet (fix)
if diff > 1e-4:
print(tt, "large diff", diff)
assert False

# II = oa.nonzero().numpy()
# v_pad = torch.nn.functional.pad(v, (1, 1) + (1, 1), 'constant', 0)
# o_pad = torch.nn.functional.pad(o, (0,0) + (1, 1) + (1, 1), 'constant', 0)
# self.v_raw = v
"""" tues code^^"""
except:
return 1

return v

def get_neighborhood(self, obs,dim4,a_index):

Expand Down Expand Up @@ -210,8 +267,12 @@ def get_neighborhood(self, obs,dim4,a_index):
return torch.stack(neighborhood)


def forward(self, input_dict, state, seq_lens): #dont think this is currently being used
obs = input_dict["obs_flat"]
def forward(self, input_dict,state,seq_lens):
try:
obs = input_dict["obs"]
except:
obs = input_dict #this is for when using the get_env_state_value_function method thing


# Store last batch size for value_function output.
self._last_batch_size = obs.shape[0]
Expand All @@ -222,53 +283,105 @@ def forward(self, input_dict, state, seq_lens): #dont think this is currently be
a_index = []

v_c = []
v_new = []
""" consider throwing tues new scrab (shitft) functions into here"""

"""
for tt in range(self.K):
vm = []
if tt > 0:
vm.append(v)
v_pad = torch.nn.functional.pad(v, (1,1) + (1,1), 'constant', 0)
for ax, shift in [ (1, -1), (1, 1), (2, -1), (2, 1)]:
# torch.pad(v, )
v_shift = torch.roll(v_pad, dims=ax-0, shifts=shift)[:,1:-1, 1:-1]
r_shift = torch.roll(r_in_pad, dims=ax-0, shifts=shift)[:,1:-1, 1:-1]
vm.append(p[:,:,:] * v_shift + r_shift - rout[:,:,:] )
v, _ = torch.stack(vm).max(axis=0)
if self.debug_vin:
diff = np.abs( V_db[:,:,:,tt+1] - v.detach().numpy() ).sum()
if diff > 1e-4:
print(tt, "large diff", diff)
assert False
II = oa.nonzero().numpy()
v_pad = torch.nn.functional.pad(v, (1, 1) + (1, 1), 'constant', 0)
o_pad = torch.nn.functional.pad(o, (0,0) + (1, 1) + (1, 1), 'constant', 0)
self.v_raw = v
"""


for ii in range(obs.shape[0]):

phi.append(self.Phi(input_dict["obs"][ii].squeeze())) #only use the first obs, as it is the same for all (for now)
phi.append(self.Phi(obs[ii].squeeze())) #only use the first obs, as it is the same for all (for now)
# fixes issue of overriding tensor with gradients
phi_vals = phi[ii].detach().numpy() #convert to np array to remove gradients

width = len(input_dict["obs"][0][:,:,0])
#generalize dimensions
dim4.append(self.VP_nn(obs[ii].reshape((width,width,3)),phi_vals))


# Not used after implementing shift-vprop-function
#probably dont want to detach this
# phi_vals = phi[ii].detach().numpy() #convert to np array to remove gradients
# width = len(obs[0][:,:,0])
# dim4.append(self.VP_nn(obs[ii].reshape((width,width,3)),phi_vals))
#dim4.append(self.VP_nn(obs[ii].reshape((width,width,3)),phi[ii])) #trying without the detach.numpu


v_new.append(self.VP_new(obs[-1],phi[-1]))

if obs[ii].any() !=0:
a_index.append(input_dict["obs"][ii][:,:,1].nonzero().detach().numpy()[0]) #get the index of the agent
a_index.append(obs[ii][:,:,1].nonzero().detach().numpy()[0]) #get the index of the agent
else:
a_index.append([0,0]) #this doesnt really matter

v_c.append(dim4[ii][a_index[ii][0],a_index[ii][1]])
# v_c.append(dim4[ii][a_index[ii][0],a_index[ii][1]])
v_c.append(v_new[ii][a_index[ii][0],a_index[ii][1]])
# V_np = []
# assert( (V_np - V_torch.numpy())< 1e-8 )
pass




# self.value_cache = tensor of size B x 1 corresponding to v[b, I, J], b = [0, 1, 2, 3, ..., B]
self.value_cache = torch.stack(v_c)


self._last_flat_in = self.get_neighborhood(input_dict["obs"],dim4,a_index)
self._last_flat_in = self.get_neighborhood(input_dict["obs"],v_new,a_index)
# self._last_flat_in = obs.reshape(obs.shape[0], -1)
self._features = self._hidden_layers(self._last_flat_in)
logits = self._logits(self._features) if self._logits else self._features

#logits = self.nn(vp)
return logits, state #from fcnet, state is []
return logits, []# state #from fcnet, state is []


def value_function(self): #dont think this is currently being used
#consider pass value function through a neural network
#return self.value_cache #slight formatting issue
return self._value_branch(self._features).squeeze(1) #torch.Size([32])
return self.value_cache.squeeze(1)
# return self._value_branch(self._features).squeeze(1) #torch.Size([32])



def value_function_for_env_state(self, env_state):

obs = env_state[np.newaxis, :]
info_dict = torch.tensor(env_state)# {"obs": torch.tensor(env_state)}
#self.forward(info_dict, [],[])

phi_vals = self.Phi(info_dict.float()).detach().numpy() #convert to np array to remove gradients

# v.append(self.VP_nn(obs[0].reshape((info_dict[:,:,0].shape[0],info_dict[:,:,0].shape[1],3)),phi_vals))
v = self.VP_nn(obs[0].reshape((info_dict[:,:,0].shape[0],info_dict[:,:,0].shape[1],3)),phi_vals).detach().numpy()
# = np.zeros((4,4))
p = phi_vals[:,:,0]
rin = phi_vals[:,:,1]
rout = phi_vals[:,:,2]
stats = {"v": v.squeeze(), "phi": phi_vals, "p": p, "rin": rin, "rout": rout}
return stats



#pi from agent.py
def pi(self, s, k=None): #we never enter this (except with the irlc-visualise stuff)
# return self.env.action_space.sample() #!s
return self.obs_space.action_space.sample()


#Tue's "Scrap file" code:
Expand Down Expand Up @@ -324,7 +437,7 @@ def my_pol_fun(a3policy, *args, a3c=None, **kwargs):

#this gives me an error: functools is not defined
import functools #maybe this is it? -> causes all kinds of weird problems, leave it out for now
#worker_set.foreach_policy(functools.partial(my_pol_fun, a3c=a3c) )
worker_set.foreach_policy(functools.partial(my_pol_fun, a3c=a3c) )

return dict(my_eval_metric=123)

Expand Down
15 changes: 13 additions & 2 deletions src/raya3c/my_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def on_algorithm_init(self, *args, algorithm=None):
wandb.config = {
"learning_rate": 0.001,
"epochs": 100,
"batch_size": 128
"batch_size": 128,
"dir": "/tmp" # not sure if this will fix
}
# for loss in range(10):
# wandb.log({"loss": np.sqrt(loss)})
Expand Down Expand Up @@ -154,8 +155,18 @@ def evaluation_call(self, a3policy):
image_array = [i*255 for i in image_array]
import PIL
images = [PIL.Image.fromarray(image.astype(np.uint8)) for image in image_array]

# images = self.wandb.Image(image_array, caption="Top: Output, Bottom: Input")
self.wandb.log({"Layout (Green=agent) | V | V_norm | p | rin | rout ": [self.wandb.Image(image) for image in images]})
# for ii in range(len(images)):
try:
#self.wandb.log({"examples": [self.wandb.Image(image) for image in images]})
self.wandb.log({"layout, v, v_norm, p, rin, rout ": [self.wandb.Image(image) for image in images]})
except Exception as e:
print(e) #hopefully we dont get here too often
for ii, image in enumerate(images):
self.wandb.log({f"image {ii}": self.wandb.Image(image)})

#self.wandb.log({"Layout (Green=agent) | V | V_norm | p | rin | rout ": [self.wandb.Image(image) for image in images]})
# import torchvision
pass

Expand Down
Loading

0 comments on commit 79516ba

Please sign in to comment.