Skip to content

Commit

Permalink
[RLlib] Fix bug in ModelCatalog when using custom action distribution (
Browse files Browse the repository at this point in the history
…ray-project#12846)

* return tuple returned from _get_multi_action_distribution when using custom action dict

* Always return dst_class and required_model_output_shape in _get_multi_action_distribution

* pass model config to _get_multi_action_distribution
  • Loading branch information
janblumenkamp authored Jan 25, 2021
1 parent 9423930 commit 964689b
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions rllib/models/catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,8 @@ def get_action_dist(
"Using custom action distribution {}".format(action_dist_name))
dist_cls = _global_registry.get(RLLIB_ACTION_DIST,
action_dist_name)
dist_cls = ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, {}, framework)
return ModelCatalog._get_multi_action_distribution(
dist_cls, action_space, config, framework)

# Dist_type is given directly as a class.
elif type(dist_type) is type and \
Expand Down Expand Up @@ -740,7 +740,8 @@ def _get_multi_action_distribution(dist_class, action_space, config,
action_space=action_space,
child_distributions=child_dists,
input_lens=input_lens), int(sum(input_lens))
return dist_class
return dist_class, dist_class.required_model_output_shape(
action_space, config)

@staticmethod
def _validate_config(config: ModelConfigDict, framework: str) -> None:
Expand Down

0 comments on commit 964689b

Please sign in to comment.