POCA Attention will use h_size for embedding size and not 128 (Unity-…
vincentpierre authored Apr 19, 2021
1 parent d06488d commit 8c01b76
Showing 1 changed file with 5 additions and 6 deletions.
ml-agents/mlagents/trainers/torch/networks.py (11 changes: 5 additions & 6 deletions)
@@ -253,8 +253,6 @@ def forward(
 
 
 class MultiAgentNetworkBody(torch.nn.Module):
-    ATTENTION_EMBEDDING_SIZE = 128
-
     """
     A network body that uses a self attention layer to handle state
     and action input from a potentially variable number of agents that
@@ -293,17 +291,18 @@ def __init__(
             + self.action_spec.continuous_size
         )
 
+        attention_embeding_size = self.h_size
         self.obs_encoder = EntityEmbedding(
-            obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+            obs_only_ent_size, None, attention_embeding_size
         )
         self.obs_action_encoder = EntityEmbedding(
-            q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+            q_ent_size, None, attention_embeding_size
         )
 
-        self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)
+        self.self_attn = ResidualSelfAttention(attention_embeding_size)
 
         self.linear_encoder = LinearEncoder(
-            self.ATTENTION_EMBEDDING_SIZE,
+            attention_embeding_size,
             network_settings.num_layers,
             self.h_size,
             kernel_gain=(0.125 / self.h_size) ** 0.5,
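Before this change, MultiAgentNetworkBody hardcoded the attention embedding width to 128; after it, the width follows the trainer's h_size, so the entity encoders, the residual self-attention layer, and the downstream linear encoder stay dimensionally consistent for any configured hidden size. The snippet below is a minimal, self-contained sketch of that wiring in plain PyTorch; TinyMultiAgentBody and its layer choices are illustrative stand-ins, not the actual ml-agents EntityEmbedding/ResidualSelfAttention implementations.

import torch
import torch.nn as nn

class TinyMultiAgentBody(nn.Module):
    """Sketch of an attention body whose embedding width tracks h_size."""

    def __init__(self, ent_size: int, h_size: int, num_layers: int = 2):
        super().__init__()
        attention_embedding_size = h_size  # was a hardcoded 128 before the fix
        # Per-entity encoder: maps each agent's observation to the embedding width.
        self.obs_encoder = nn.Linear(ent_size, attention_embedding_size)
        # Self-attention over the (potentially variable-length) set of agents.
        # Assumption of this sketch: num_heads must divide the embedding width.
        self.self_attn = nn.MultiheadAttention(
            attention_embedding_size, num_heads=4, batch_first=True
        )
        # The downstream encoder consumes the attention output, so its input
        # width must match attention_embedding_size.
        layers = []
        in_size = attention_embedding_size
        for _ in range(num_layers):
            layers += [nn.Linear(in_size, h_size), nn.ReLU()]
            in_size = h_size
        self.linear_encoder = nn.Sequential(*layers)

    def forward(self, entities: torch.Tensor) -> torch.Tensor:
        # entities: [batch, n_agents, ent_size]
        emb = self.obs_encoder(entities)
        attended, _ = self.self_attn(emb, emb, emb)
        pooled = attended.mean(dim=1)  # pool over the agent dimension
        return self.linear_encoder(pooled)

body = TinyMultiAgentBody(ent_size=10, h_size=256)
out = body(torch.randn(4, 3, 10))
print(out.shape)  # torch.Size([4, 256]); every width follows h_size, not 128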
