Skip to content

Commit

Permalink
kleine Anpassungen
Browse files Browse the repository at this point in the history
  • Loading branch information
aws1313 committed Feb 2, 2024
1 parent 2e927dc commit 5695823
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions dqn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn as nn
import torch.optim as optim


class DQN(nn.Module):
def __init__(self, inp_size: int, hidden_size, out_size: int, lr, gamma, device):
super(DQN, self).__init__()
Expand All @@ -24,20 +25,22 @@ def __init__(self, inp_size: int, hidden_size, out_size: int, lr, gamma, device)
self.flatten = nn.Flatten()
self.run_device = device


def forward(self, x):
if len(x.shape)==1:
x=x.unsqueeze(0)
if len(x.shape) == 1:
x = x.unsqueeze(0)
x = self.flatten(x)
return self.rls(x)

def save(self, name):
torch.save(self.state_dict(), name)

def train_step(self, old_state, new_state, action, reward):
# Da wir uns nicht sicher waren, ob unsere ursprüngliche Implementierung funktioniert hat, haben wir uns
# an diesem Code von dem Spiel Snake orientiert und ihn für uns angepasst:
# https://github.com/patrickloeber/snake-ai-pytorch/blob/main/model.py
old_state = torch.from_numpy(old_state).to(self.run_device)
new_state = torch.from_numpy(new_state).to(self.run_device)
action = torch.tensor(action, dtype = torch.float).to(self.run_device)
action = torch.tensor(action, dtype=torch.float).to(self.run_device)
reward = torch.tensor(reward, dtype=torch.float).to(self.run_device)
if len(old_state.shape) == 1:
old_state = torch.unsqueeze(old_state, 0)
Expand Down

0 comments on commit 5695823

Please sign in to comment.