Skip to content

Commit

Permalink
Use sequential for actor and critic part
Browse files Browse the repository at this point in the history
  • Loading branch information
lcswillems committed Jun 1, 2018
1 parent 93b42ca commit bcb1d1e
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions rl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,18 @@ def __init__(self, obs_space, action_space, use_instr=False, use_memory=False, a
self.embedding_size += self.instr_embedding_size

# Define actor's model
self.a_fc1 = nn.Linear(self.embedding_size, 64)
self.a_fc2 = nn.Linear(64, action_space.n)
self.actor = nn.Sequential(
nn.Linear(self.embedding_size, 64),
nn.Tanh(),
nn.Linear(64, action_space.n)
)

# Define critic's model
self.c_fc1 = nn.Linear(self.embedding_size, 64)
self.c_fc2 = nn.Linear(64, 1)
self.critic = nn.Sequential(
nn.Linear(self.embedding_size, 64),
nn.Tanh(),
nn.Linear(64, 1)
)

# Initialize parameters correctly
self.apply(initialize_parameters)
Expand Down Expand Up @@ -98,14 +104,10 @@ def forward(self, obs, memory):
if self.use_instr:
embedding = torch.cat((embedding, embed_instr), dim=1)

x = self.a_fc1(embedding)
x = F.tanh(x)
x = self.a_fc2(x)
x = self.actor(embedding)
dist = Categorical(logits=F.log_softmax(x, dim=1))

x = self.c_fc1(embedding)
x = F.tanh(x)
x = self.c_fc2(x)
x = self.critic(embedding)
value = x.squeeze(1)

return dist, value, memory
Expand Down

0 comments on commit bcb1d1e

Please sign in to comment.