add some types to the reward signals (Unity-Technologies#2215)
* WIP add some types to the reward signals

* fix next_visual_in

* cleanup TODO

* fix bad merge
chriselion authored Jul 12, 2019
1 parent 603ac97 commit 586e50c
Showing 5 changed files with 70 additions and 24 deletions.
@@ -1,3 +1,4 @@
from typing import List, Tuple
import tensorflow as tf
from mlagents.trainers.models import LearningModel

@@ -17,12 +18,13 @@ def __init__(
"""
self.encoding_size = encoding_size
self.policy_model = policy_model
self.next_visual_in: List[tf.Tensor] = []
encoded_state, encoded_next_state = self.create_curiosity_encoders()
self.create_inverse_model(encoded_state, encoded_next_state)
self.create_forward_model(encoded_state, encoded_next_state)
self.create_loss(learning_rate)

def create_curiosity_encoders(self):
def create_curiosity_encoders(self) -> Tuple[tf.Tensor, tf.Tensor]:
"""
Creates state encoders for current and future observations.
Used for implementation of Curiosity-driven Exploration by Self-supervised Prediction
@@ -104,7 +106,9 @@ def create_curiosity_encoders(self):
encoded_next_state = tf.concat(encoded_next_state_list, axis=1)
return encoded_state, encoded_next_state

def create_inverse_model(self, encoded_state, encoded_next_state):
def create_inverse_model(
self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
) -> None:
"""
Creates inverse model TensorFlow ops for Curiosity module.
Predicts action taken given current and future encoded states.
@@ -142,7 +146,9 @@ def create_inverse_model(self, encoded_state, encoded_next_state):
tf.dynamic_partition(cross_entropy, self.policy_model.mask, 2)[1]
)

def create_forward_model(self, encoded_state, encoded_next_state):
def create_forward_model(
self, encoded_state: tf.Tensor, encoded_next_state: tf.Tensor
) -> None:
"""
Creates forward model TensorFlow ops for Curiosity module.
Predicts encoded future state based on encoded current state and given action.
@@ -169,7 +175,7 @@ def create_forward_model(self, encoded_state, encoded_next_state):
tf.dynamic_partition(squared_difference, self.policy_model.mask, 2)[1]
)
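
As a rough illustration of the forward-model idea in the docstring above (predict the next encoded state from the current encoding and the action, then penalize the squared error), here is a NumPy sketch; the shapes, the single linear layer, and the omitted mask are simplifications, not the module's actual TensorFlow graph:

import numpy as np

rng = np.random.default_rng(0)
encoded_state = rng.normal(size=(4, 128))       # current-state encodings (batch of 4)
encoded_next_state = rng.normal(size=(4, 128))  # next-state encodings
action = rng.normal(size=(4, 2))                # actions taken at each step

w = 0.01 * rng.normal(size=(130, 128))          # stand-in for the learned prediction layer
predicted_next = np.concatenate([encoded_state, action], axis=1) @ w
forward_loss = np.mean((predicted_next - encoded_next_state) ** 2)
print(forward_loss)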

def create_loss(self, learning_rate):
def create_loss(self, learning_rate: float) -> None:
"""
Creates the loss node of the model as well as the update_batch optimizer to update the model.
:param learning_rate: The learning rate for the optimizer.
@@ -1,4 +1,8 @@
from typing import Any, Dict, List
import numpy as np
from mlagents.envs.brain import BrainInfo

from mlagents.trainers.buffer import Buffer
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.components.reward_signals.curiosity.model import CuriosityModel
from mlagents.trainers.tf_policy import TFPolicy
@@ -33,7 +37,9 @@ def __init__(
}
self.has_updated = False

def evaluate(self, current_info, next_info):
def evaluate(
self, current_info: BrainInfo, next_info: BrainInfo
) -> RewardSignalResult:
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
@@ -75,23 +81,26 @@ def evaluate(self, current_info, next_info):
return RewardSignalResult(scaled_reward, unscaled_reward)

@classmethod
def check_config(cls, config_dict):
def check_config(
cls, config_dict: Dict[str, Any], param_keys: List[str] = None
) -> None:
"""
Checks the config and throws an exception if a hyperparameter is missing. Curiosity requires strength,
gamma, and encoding size at minimum.
"""
param_keys = ["strength", "gamma", "encoding_size"]
super().check_config(config_dict, param_keys)
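
For illustration, a reward-signal config entry that would pass this check needs at least the three keys listed above; the values here are placeholders, not recommended settings:

curiosity_config = {
    "strength": 0.01,       # scales the intrinsic reward
    "gamma": 0.99,          # discount factor for the curiosity signal
    "encoding_size": 128,   # width of the curiosity state encoders
}
# Passing this dict to the check_config classmethod above raises only if one
# of the required keys is missing.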

def update(self, update_buffer, num_sequences):
def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
"""
Updates Curiosity model using training buffer. Divides training buffer into mini batches and performs
gradient descent.
:param update_buffer: Update buffer from which to pull data.
:param num_sequences: Number of sequences in the update buffer.
:return: Dict of stats that should be reported to Tensorboard.
"""
forward_total, inverse_total = [], []
forward_total: List[float] = []
inverse_total: List[float] = []
for _ in range(self.num_epoch):
update_buffer.shuffle()
buffer = update_buffer
@@ -110,7 +119,9 @@ def update(self, update_buffer, num_sequences):
}
return update_stats
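
A hedged note on the accumulator change above: annotating each empty list separately is the usual way to give a type checker the element type, since a tuple-unpacked assignment like forward_total, inverse_total = [], [] cannot carry per-target annotations. A minimal sketch of the pattern (the stat keys and values are placeholders):

from typing import Dict, List
import numpy as np

forward_total: List[float] = []   # per-batch forward-model losses
inverse_total: List[float] = []   # per-batch inverse-model losses
forward_total.append(0.25)
inverse_total.append(0.10)

update_stats: Dict[str, float] = {
    "forward_loss": float(np.mean(forward_total)),   # placeholder key names
    "inverse_loss": float(np.mean(inverse_total)),
}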

def _update_batch(self, mini_batch, num_sequences):
def _update_batch(
self, mini_batch: Dict[str, np.ndarray], num_sequences: int
) -> Dict[str, float]:
"""
Updates model using buffer.
:param num_sequences: Number of trajectories in batch.
@@ -1,5 +1,8 @@
from typing import Any, Dict, List
import numpy as np
from mlagents.envs.brain import BrainInfo

from mlagents.trainers.buffer import Buffer
from mlagents.trainers.components.reward_signals import RewardSignal, RewardSignalResult
from mlagents.trainers.tf_policy import TFPolicy

@@ -16,15 +19,19 @@ def __init__(self, policy: TFPolicy, strength: float, gamma: float):
super().__init__(policy, strength, gamma)

@classmethod
def check_config(cls, config_dict):
def check_config(
cls, config_dict: Dict[str, Any], param_keys: List[str] = None
) -> None:
"""
Checks the config and throws an exception if a hyperparameter is missing. Extrinsic requires strength and gamma
at minimum.
"""
param_keys = ["strength", "gamma"]
super().check_config(config_dict, param_keys)

def evaluate(self, current_info, next_info):
def evaluate(
self, current_info: BrainInfo, next_info: BrainInfo
) -> RewardSignalResult:
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
@@ -35,7 +42,7 @@ def evaluate(self, current_info, next_info):
scaled_reward = self.strength * unscaled_reward
return RewardSignalResult(scaled_reward, unscaled_reward)

def update(self, update_buffer, num_sequences):
def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
"""
This method does nothing, as there is nothing to update.
"""
@@ -1,12 +1,16 @@
import logging
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.tf_policy import TFPolicy
from typing import Any, Dict, List
from collections import namedtuple
import numpy as np
import abc

import tensorflow as tf

from mlagents.envs.brain import BrainInfo
from mlagents.trainers.trainer import UnityTrainerException
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.buffer import Buffer

logger = logging.getLogger("mlagents.trainers")

RewardSignalResult = namedtuple(
@@ -32,19 +36,21 @@ def __init__(self, policy: TFPolicy, strength: float, gamma: float):
self.policy = policy
self.strength = strength

def evaluate(self, current_info, next_info):
def evaluate(
self, current_info: BrainInfo, next_info: BrainInfo
) -> RewardSignalResult:
"""
Evaluates the reward for the agents present in current_info given the next_info
:param current_info: The current BrainInfo.
:param next_info: The BrainInfo from the next timestep.
:return: a RewardSignalResult of (scaled intrinsic reward, unscaled intrinsic reward) provided by the generator
"""
return (
return RewardSignalResult(
self.strength * np.zeros(len(current_info.agents)),
np.zeros(len(current_info.agents)),
)
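
Assuming the truncated namedtuple definition above has two fields in the (scaled, unscaled) order used by the constructors in this diff, a small sketch of why returning RewardSignalResult instead of a bare tuple keeps existing callers working while adding named access; the field names here are an assumption, not taken from the source:

import numpy as np
from collections import namedtuple

# Hypothetical stand-in for the RewardSignalResult defined above; the real
# field names are cut off in this diff.
RewardSignalResult = namedtuple("RewardSignalResult", ["scaled_reward", "unscaled_reward"])

result = RewardSignalResult(0.02 * np.zeros(3), np.zeros(3))
scaled, unscaled = result        # positional unpacking, as with the old bare tuple
print(result.scaled_reward)      # named access, available with the namedtuple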

def update(self, update_buffer, n_sequences):
def update(self, update_buffer: Buffer, num_sequences: int) -> Dict[str, float]:
"""
If the reward signal has an internal model (e.g. GAIL or Curiosity), update that model.
:param update_buffer: An AgentBuffer that contains the live data from which to update.
@@ -54,7 +60,9 @@ def update(self, update_buffer, n_sequences):
return {}

@classmethod
def check_config(cls, config_dict, param_keys=None):
def check_config(
cls, config_dict: Dict[str, Any], param_keys: List[str] = None
) -> None:
"""
Check the config dict, and throw an error if there are missing hyperparameters.
"""
26 changes: 20 additions & 6 deletions ml-agents/mlagents/trainers/models.py
@@ -1,11 +1,14 @@
import logging
from typing import Any, Callable, Dict

import numpy as np
import tensorflow as tf
import tensorflow.contrib.layers as c_layers

logger = logging.getLogger("mlagents.trainers")

ActivationFunction = Callable[[tf.Tensor], tf.Tensor]


class LearningModel(object):
_version_number_ = 2
@@ -83,12 +86,12 @@ def scaled_init(scale):
return c_layers.variance_scaling_initializer(scale)

@staticmethod
def swish(input_activation):
def swish(input_activation: tf.Tensor) -> tf.Tensor:
"""Swish activation function. For more info: https://arxiv.org/abs/1710.05941"""
return tf.multiply(input_activation, tf.nn.sigmoid(input_activation))
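
As a sanity check of the Swish definition above, a small NumPy sketch (independent of TensorFlow) computing x * sigmoid(x) for a few values:

import numpy as np

def swish_np(x: np.ndarray) -> np.ndarray:
    # Same formula as the TF op above: x * sigmoid(x)
    return x * (1.0 / (1.0 + np.exp(-x)))

print(swish_np(np.array([-1.0, 0.0, 1.0])))  # approx. [-0.269, 0.0, 0.731]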

@staticmethod
def create_visual_input(camera_parameters, name):
def create_visual_input(camera_parameters: Dict[str, Any], name: str) -> tf.Tensor:
"""
Creates image input op.
:param camera_parameters: Parameters for visual observation from BrainInfo.
@@ -179,8 +182,13 @@ def create_normalizer_update(self, vector_input):

@staticmethod
def create_vector_observation_encoder(
observation_input, h_size, activation, num_layers, scope, reuse
):
observation_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,
num_layers: int,
scope: str,
reuse: bool,
) -> tf.Tensor:
"""
Builds a set of hidden state encoders.
:param reuse: Whether to re-use the weights within the same scope.
@@ -205,8 +213,14 @@ def create_visual_observation_encoder(
return hidden

def create_visual_observation_encoder(
self, image_input, h_size, activation, num_layers, scope, reuse
):
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,
num_layers: int,
scope: str,
reuse: bool,
) -> tf.Tensor:
"""
Builds a set of visual (CNN) encoders.
:param reuse: Whether to re-use the weights within the same scope.
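
The ActivationFunction alias added at the top of models.py types the activation arguments of the encoder builders above. A minimal sketch of what the alias accepts (assuming TensorFlow 1.x-style tensors, as used in this repository):

from typing import Callable
import tensorflow as tf

ActivationFunction = Callable[[tf.Tensor], tf.Tensor]

def apply_activation(x: tf.Tensor, activation: ActivationFunction) -> tf.Tensor:
    # Any Tensor -> Tensor callable (e.g. LearningModel.swish or tf.nn.relu)
    # satisfies the alias.
    return activation(x)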
