Fix the attention module embedding size (Unity-Technologies#5272)
* Fix the attention module embedding size

* editing the changelog
vincentpierre authored Apr 15, 2021
1 parent 2721989 commit 9ae2c28
Showing 4 changed files with 32 additions and 16 deletions.
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -60,6 +60,7 @@ This results in much less memory being allocated during inference with `CameraSe

#### ml-agents / ml-agents-envs / gym-unity (Python)
- Some console output have been moved from `info` to `debug` and will not be printed by default. If you want all messages to be printed, you can run `mlagents-learn` with the `--debug` option or add the line `debug: true` at the top of the yaml config file. (#5211)
+ - The embedding size of attention layers used when a BufferSensor is in the scene has been changed. It is now fixed to 128 units. It might be impossible to resume training from a checkpoint of a previous version. (#5272)

### Bug Fixes
#### com.unity.ml-agents / com.unity.ml-agents.extensions (C#)
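To make the checkpoint incompatibility mentioned in the changelog entry concrete, here is a minimal sketch in plain PyTorch (not the ML-Agents code itself): attention parameter shapes depend on the embedding size, so a model built with the new fixed width of 128 cannot load a state dict saved while the attention width still followed `hidden_units`.

```python
import torch.nn as nn

def make_attention(embedding_size: int) -> nn.MultiheadAttention:
    # in_proj_weight has shape (3 * embedding_size, embedding_size), so the
    # checkpoint layout changes whenever the embedding size changes.
    return nn.MultiheadAttention(embed_dim=embedding_size, num_heads=4, batch_first=True)

old_attention = make_attention(256)  # e.g. width tied to hidden_units in older versions
new_attention = make_attention(128)  # fixed width after this change

try:
    new_attention.load_state_dict(old_attention.state_dict())
except RuntimeError as err:
    print("cannot resume from the old checkpoint:", err)
```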
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/tests/torch/test_utils.py
@@ -51,7 +51,7 @@ def test_create_inputs(encoder_type, normalize, num_vector, num_visual):
h_size = 128
obs_spec = create_observation_specs_with_shapes(obs_shapes)
encoders, embedding_sizes = ModelUtils.create_input_processors(
- obs_spec, h_size, encoder_type, normalize
+ obs_spec, h_size, encoder_type, h_size, normalize
)
total_output = sum(embedding_sizes)
vec_enc = []
31 changes: 21 additions & 10 deletions ml-agents/mlagents/trainers/torch/networks.py
@@ -32,6 +32,8 @@


class ObservationEncoder(nn.Module):
+ ATTENTION_EMBEDDING_SIZE = 128  # The embedding size of attention is fixed
+
def __init__(
self,
observation_specs: List[ObservationSpec],
@@ -45,13 +47,17 @@ def __init__(
"""
super().__init__()
self.processors, self.embedding_sizes = ModelUtils.create_input_processors(
- observation_specs, h_size, vis_encode_type, normalize=normalize
+ observation_specs,
+ h_size,
+ vis_encode_type,
+ self.ATTENTION_EMBEDDING_SIZE,
+ normalize=normalize,
)
self.rsa, self.x_self_encoder = ModelUtils.create_residual_self_attention(
- self.processors, self.embedding_sizes, h_size
+ self.processors, self.embedding_sizes, self.ATTENTION_EMBEDDING_SIZE
)
if self.rsa is not None:
- total_enc_size = sum(self.embedding_sizes) + h_size
+ total_enc_size = sum(self.embedding_sizes) + self.ATTENTION_EMBEDDING_SIZE
else:
total_enc_size = sum(self.embedding_sizes)
self.normalize = normalize
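The `ObservationEncoder` hunk above also changes the size bookkeeping: when a residual self-attention block is present, its output is concatenated at the fixed `ATTENTION_EMBEDDING_SIZE` rather than at `h_size`. A small worked example of the arithmetic (the individual embedding sizes are made up for illustration):

```python
ATTENTION_EMBEDDING_SIZE = 128

h_size = 256                     # trainer hidden_units; no longer used for the attention path
embedding_sizes = [64, 256, 0]   # e.g. vector obs, visual obs, variable-length obs (reports 0)
rsa_present = True               # a variable-length observation triggers self-attention

if rsa_present:
    total_enc_size = sum(embedding_sizes) + ATTENTION_EMBEDDING_SIZE
else:
    total_enc_size = sum(embedding_sizes)

print(total_enc_size)  # 448: the attention branch contributes 128, independent of h_size
```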
@@ -247,6 +253,8 @@ def forward(


class MultiAgentNetworkBody(torch.nn.Module):
+ ATTENTION_EMBEDDING_SIZE = 128
+
"""
A network body that uses a self attention layer to handle state
and action input from a potentially variable number of agents that
@@ -284,13 +292,18 @@ def __init__(
+ sum(self.action_spec.discrete_branches)
+ self.action_spec.continuous_size
)
- self.obs_encoder = EntityEmbedding(obs_only_ent_size, None, self.h_size)
- self.obs_action_encoder = EntityEmbedding(q_ent_size, None, self.h_size)

- self.self_attn = ResidualSelfAttention(self.h_size)
+ self.obs_encoder = EntityEmbedding(
+ obs_only_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+ )
+ self.obs_action_encoder = EntityEmbedding(
+ q_ent_size, None, self.ATTENTION_EMBEDDING_SIZE
+ )
+
+ self.self_attn = ResidualSelfAttention(self.ATTENTION_EMBEDDING_SIZE)

self.linear_encoder = LinearEncoder(
- self.h_size,
+ self.ATTENTION_EMBEDDING_SIZE,
network_settings.num_layers,
self.h_size,
kernel_gain=(0.125 / self.h_size) ** 0.5,
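In `MultiAgentNetworkBody`, the entity embeddings and the self-attention now run at the fixed width, and `LinearEncoder` maps that width back to `h_size`. A toy stand-in for the data flow in plain PyTorch (the real `EntityEmbedding` / `ResidualSelfAttention` classes are replaced by ordinary layers here; only the tensor widths are meant to mirror the diff):

```python
import torch
import torch.nn as nn

ATTENTION_EMBEDDING_SIZE = 128
h_size = 256        # network_settings.hidden_units
ent_size = 70       # per-entity input width, made up for the example
num_entities = 3    # e.g. number of other agents
batch = 8

entity_embedding = nn.Linear(ent_size, ATTENTION_EMBEDDING_SIZE)  # was h_size before the fix
self_attention = nn.MultiheadAttention(ATTENTION_EMBEDDING_SIZE, num_heads=4, batch_first=True)
linear_encoder = nn.Sequential(nn.Linear(ATTENTION_EMBEDDING_SIZE, h_size), nn.ReLU())

entities = torch.randn(batch, num_entities, ent_size)
x = entity_embedding(entities)            # (8, 3, 128)
x, _ = self_attention(x, x, x)            # (8, 3, 128)
encoding = linear_encoder(x.mean(dim=1))  # pooled over entities -> (8, 256)
print(encoding.shape)
```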
@@ -337,9 +350,7 @@ def _copy_and_remove_nans_from_obs(
no_nan_obs = []
for obs in single_agent_obs:
new_obs = obs.clone()
- new_obs[
- attention_mask.bool()[:, i_agent], ::
- ] = 0.0 # Remoove NaNs fast
+ new_obs[attention_mask.bool()[:, i_agent], ::] = 0.0  # Remove NaNs fast
no_nan_obs.append(new_obs)
obs_with_no_nans.append(no_nan_obs)
return obs_with_no_nans
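The `_copy_and_remove_nans_from_obs` hunk is a formatting and comment fix; the masked assignment itself is unchanged. For reference, a self-contained example of what that line does (zeroing the masked rows so NaNs from padded agents never reach the encoders; the mask values below are made up):

```python
import torch

i_agent = 1
obs = torch.tensor([[1.0, 2.0], [float("nan"), float("nan")], [3.0, 4.0]])
attention_mask = torch.tensor([[0, 0], [0, 1], [0, 0]])  # 1 marks a masked row for that agent

new_obs = obs.clone()
new_obs[attention_mask.bool()[:, i_agent], ::] = 0.0  # same indexing pattern as the diff
print(new_obs)  # the NaN row is replaced by zeros
```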
14 changes: 9 additions & 5 deletions ml-agents/mlagents/trainers/torch/utils.py
@@ -142,13 +142,15 @@ def get_encoder_for_obs(
obs_spec: ObservationSpec,
normalize: bool,
h_size: int,
+ attention_embedding_size: int,
vis_encode_type: EncoderType,
) -> Tuple[nn.Module, int]:
"""
Returns the encoder and the size of the appropriate encoder.
:param shape: Tuples that represent the observation dimension.
:param normalize: Normalize all vector inputs.
- :param h_size: Number of hidden units per layer.
+ :param h_size: Number of hidden units per layer excluding attention layers.
+ :param attention_embedding_size: Number of hidden units per attention layer.
:param vis_encode_type: Type of visual encoder to use.
"""
shape = obs_spec.shape
@@ -167,7 +169,7 @@
EntityEmbedding(
entity_size=shape[1],
entity_num_max_elements=shape[0],
- embedding_size=h_size,
+ embedding_size=attention_embedding_size,
),
0,
)
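`get_encoder_for_obs` now takes the attention width separately from `h_size` and uses it only for variable-length (rank-2) observations, which still report an embedding size of 0. A simplified, plain-Python sketch of that branch (the real function returns ML-Agents encoder modules; the non-attention branch below is just a placeholder):

```python
from typing import Sequence, Tuple

def sketch_get_encoder_for_obs(
    shape: Sequence[int], h_size: int, attention_embedding_size: int
) -> Tuple[str, int]:
    if len(shape) == 2:
        # Variable-length observation (e.g. a BufferSensor): the per-entity
        # embedding uses the fixed attention width and contributes 0 to the
        # summed embedding sizes, as in the (EntityEmbedding(...), 0) return above.
        return f"EntityEmbedding(embedding_size={attention_embedding_size})", 0
    # Placeholder for the unchanged vector/visual branches.
    return f"encoder(h_size={h_size})", h_size

print(sketch_get_encoder_for_obs((20, 6), h_size=256, attention_embedding_size=128))
```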
@@ -179,14 +181,16 @@ def create_input_processors(
observation_specs: List[ObservationSpec],
h_size: int,
vis_encode_type: EncoderType,
+ attention_embedding_size: int,
normalize: bool = False,
) -> Tuple[nn.ModuleList, List[int]]:
"""
Creates visual and vector encoders, along with their normalizers.
:param observation_specs: List of ObservationSpec that represent the observation dimensions.
:param action_size: Number of additional un-normalized inputs to each vector encoder. Used for
conditioning network on other values (e.g. actions for a Q function)
- :param h_size: Number of hidden units per layer.
+ :param h_size: Number of hidden units per layer excluding attention layers.
+ :param attention_embedding_size: Number of hidden units per attention layer.
:param vis_encode_type: Type of visual encoder to use.
:param unnormalized_inputs: Vector inputs that should not be normalized, and added to the vector
obs.
@@ -200,7 +204,7 @@
embedding_sizes: List[int] = []
for obs_spec in observation_specs:
encoder, embedding_size = ModelUtils.get_encoder_for_obs(
- obs_spec, normalize, h_size, vis_encode_type
+ obs_spec, normalize, h_size, attention_embedding_size, vis_encode_type
)
encoders.append(encoder)
embedding_sizes.append(embedding_size)
@@ -209,7 +213,7 @@
if x_self_size > 0:
for enc in encoders:
if isinstance(enc, EntityEmbedding):
- enc.add_self_embedding(h_size)
+ enc.add_self_embedding(attention_embedding_size)
return (nn.ModuleList(encoders), embedding_sizes)

@staticmethod
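Finally, `add_self_embedding` now also receives the fixed attention width. A toy illustration of the shape constraint behind that (plain PyTorch; how ML-Agents actually combines the self encoding with the entities differs, and the element-wise addition below is only an assumption for the example, but both sides have to be built at the same width for the attention input to line up):

```python
import torch
import torch.nn as nn

attention_embedding_size = 128
x_self_size = 40   # width of the agent's own (non-entity) observations, made up here
entity_size = 6
max_entities = 20

entity_embedding = nn.Linear(entity_size, attention_embedding_size)
self_embedding = nn.Linear(x_self_size, attention_embedding_size)  # what add_self_embedding sets up

entities = torch.randn(4, max_entities, entity_size)
x_self = torch.randn(4, x_self_size)

embedded_entities = entity_embedding(entities)       # (4, 20, 128)
embedded_self = self_embedding(x_self).unsqueeze(1)  # (4, 1, 128)
attention_input = embedded_entities + embedded_self  # widths must agree
print(attention_input.shape)
```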
