Skip to content

Commit

Permalink
Make create_df_from_replay_buffer callable in oss by moving related f…
Browse files Browse the repository at this point in the history
…unctions and classes (facebookresearch#423)

Summary:
Pull Request resolved: facebookresearch#423

Move the functions `create_df_from_replay_buffer`, `set_seed`, `feature_transform`, and `validate_mdp_ids_seq_nums` from fblearner.flow.projects.rl to reagent, and move the class `ProblemDomain` from reagent.core.fb.parameters to reagent.core.parameters, so that the OSS code can call them in unit tests.

Reviewed By: czxttkl

Differential Revision: D27130180

fbshipit-source-id: a06b7e8d5d683bb82a214bdab67b7e7e0ea71f2e
  • Loading branch information
gji1 authored and facebook-github-bot committed Mar 19, 2021
1 parent 82484f7 commit 2cf5f63
Show file tree
Hide file tree
Showing 2 changed files with 275 additions and 3 deletions.
11 changes: 11 additions & 0 deletions reagent/core/parameters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import enum
from typing import Dict, List, Optional

from reagent.core.base_dataclass import BaseDataClass
Expand All @@ -18,6 +19,16 @@
CONTINUOUS_TRAINING_ACTION_RANGE = (-1.0, 1.0)


class ProblemDomain(enum.Enum):
    """Kinds of problem a dataset or model can target.

    Each member's value is the lowercase string identifier of the domain.
    """

    # Action-space based domains.
    CONTINUOUS_ACTION = "continuous_action"
    DISCRETE_ACTION = "discrete_action"
    PARAMETRIC_ACTION = "parametric_action"

    # NOTE: the data generated for the two domains below is probably not
    # generic (original author's remark; not verified here).
    SEQ_TO_REWARD = "seq2reward"
    MDN_RNN = "mdn_rnn"


@dataclass(frozen=True)
class RLParameters(BaseDataClass):
__hash__ = param_hash
Expand Down
267 changes: 264 additions & 3 deletions reagent/gym/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,16 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

import logging
from typing import Dict
import random
from typing import Dict, List, Optional

import gym
import numpy as np
import pandas as pd
import torch # @manual
import torch.nn.functional as F
from gym import spaces
from reagent.core.parameters import NormalizationData, NormalizationKey
from reagent.core.parameters import NormalizationData, NormalizationKey, ProblemDomain
from reagent.gym.agents.agent import Agent
from reagent.gym.agents.post_step import add_replay_buffer_post_step
from reagent.gym.envs import EnvWrapper
Expand All @@ -16,12 +22,14 @@
)
from reagent.gym.policies.random_policies import make_random_policy_for_env
from reagent.gym.runners.gymrunner import run_episode
from reagent.replay_memory.circular_replay_buffer import ReplayBuffer
from reagent.replay_memory import ReplayBuffer
from tqdm import tqdm


logger = logging.getLogger(__name__)

SEED = 0

try:
from reagent.gym.envs import RecSim # noqa

Expand Down Expand Up @@ -144,3 +152,256 @@ def build_normalizer(env: EnvWrapper) -> Dict[str, NormalizationData]:
dense_normalization_parameters=build_action_normalizer(env)
),
}


def create_df_from_replay_buffer(
    env: gym.Env,
    problem_domain: ProblemDomain,
    desired_size: int,
    multi_steps: Optional[int],
    ds: str,
) -> pd.DataFrame:
    """Fill a replay buffer from ``env`` and export its transitions as a DataFrame.

    The global RNGs and ``env`` are seeded with ``SEED``, a ``ReplayBuffer``
    of capacity ``desired_size`` is filled through ``fill_replay_buffer``,
    and every valid transition sampled from it becomes one row.

    Args:
        env: Environment to collect transitions from. Its action space must
            be ``gym.spaces.Discrete`` for the discrete, parametric and
            MDN-RNN domains.
        problem_domain: Selects how actions are encoded: ``str`` for
            ``DISCRETE_ACTION`` and ``MDN_RNN``, a one-hot sparse dict for
            ``PARAMETRIC_ACTION``, a sparse dict for ``CONTINUOUS_ACTION``.
        desired_size: Replay buffer capacity; also forwarded to
            ``fill_replay_buffer``.
        multi_steps: ``None`` for single-step rows. Otherwise it is used as
            the buffer's update horizon, the buffer returns timeline-format
            batches, and the "next" columns, ``reward``, ``metrics`` and
            ``time_diff`` hold one list per row.
        ds: Value written to every row of the ``ds`` column.

    Returns:
        A row-shuffled DataFrame with the columns ``state_features``,
        ``next_state_features``, ``action``, ``next_action``, ``reward``,
        ``action_probability``, ``metrics``, ``time_diff``, ``mdp_id``,
        ``sequence_number``, ``sequence_number_ordinal`` and ``ds``, plus
        ``possible_actions`` / ``possible_next_actions`` for every domain
        except ``CONTINUOUS_ACTION``.

    Raises:
        NotImplementedError: ``problem_domain`` is not one of the domains
            handled here (e.g. ``SEQ_TO_REWARD``).
        AssertionError: ``MDN_RNN`` is requested without ``multi_steps``, a
            Discrete action space is required but not provided, or the
            resulting rows fail ``validate_mdp_ids_seq_nums``.
    """
    # fill the replay buffer
    set_seed(env, SEED)
    is_multi_steps = multi_steps is not None
    update_horizon = multi_steps if is_multi_steps else 1

    replay_buffer = ReplayBuffer(
        replay_capacity=desired_size,
        batch_size=1,
        update_horizon=update_horizon,
        # timeline format is used exactly when multi-step data is requested
        return_as_timeline_format=is_multi_steps,
    )
    fill_replay_buffer(env, replay_buffer, desired_size)

    batch = replay_buffer.sample_all_valid_transitions()
    n = batch.state.shape[0]
    logger.info(f"Creating df of size {n}.")

    def discrete_feat_transform(elem) -> str:
        """ query data expects str format """
        return str(elem.item())

    def continuous_feat_transform(elem: torch.Tensor) -> Dict[int, float]:
        """ query data expects sparse format """
        assert isinstance(elem, torch.Tensor), f"{type(elem)} isn't tensor"
        assert len(elem.shape) == 1, f"{elem.shape} isn't 1-dimensional"
        return {i: s.item() for i, s in enumerate(elem)}

    def make_parametric_feat_transform(one_hot_dim: int):
        """ one-hot and then continuous_feat_transform """

        def transform(elem) -> Dict[int, float]:
            elem_tensor = torch.tensor(elem.item())
            one_hot_feat = F.one_hot(elem_tensor, one_hot_dim).float()
            return continuous_feat_transform(one_hot_feat)

        return transform

    state_features = feature_transform(batch.state, continuous_feat_transform)
    next_state_features = feature_transform(
        batch.next_state,
        continuous_feat_transform,
        is_next_with_multi_steps=is_multi_steps,
    )

    # Pick the per-element action encoder and the placeholder that replaces
    # the next action of a terminal transition.
    if problem_domain in (ProblemDomain.DISCRETE_ACTION, ProblemDomain.MDN_RNN):
        if problem_domain == ProblemDomain.MDN_RNN:
            # MDN-RNN rows are always multi-step sequences.
            assert multi_steps is not None
        # discrete action is str
        action_transform = discrete_feat_transform
        terminal_placeholder = ""
    elif problem_domain == ProblemDomain.PARAMETRIC_ACTION:
        # parametric action is a one-hot Dict[int, float]
        assert isinstance(env.action_space, gym.spaces.Discrete)
        action_transform = make_parametric_feat_transform(env.action_space.n)
        terminal_placeholder = {}
    elif problem_domain == ProblemDomain.CONTINUOUS_ACTION:
        # continuous action is Dict[int, float]
        action_transform = continuous_feat_transform
        terminal_placeholder = {}
    else:
        raise NotImplementedError(f"model type: {problem_domain}.")

    action = feature_transform(batch.action, action_transform)
    next_action = feature_transform(
        batch.next_action,
        action_transform,
        is_next_with_multi_steps=is_multi_steps,
        replace_when_terminal=terminal_placeholder,
        terminal=batch.terminal,
    )

    if multi_steps is None:
        time_diff = [1] * n
        reward = batch.reward.squeeze(1).tolist()
        metrics = [{"reward": r} for r in reward]
    else:
        time_diff = [[1] * len(ns) for ns in next_state_features]
        reward = [reward_list.tolist() for reward_list in batch.reward]
        metrics = [
            [{"reward": r.item()} for r in reward_list] for reward_list in batch.reward
        ]

    # TODO(T67265031): change this to int
    mdp_id = [str(i.item()) for i in batch.mdp_id]
    sequence_number = batch.sequence_number.squeeze(1).tolist()
    # In the product data every sequence_number_ordinal starts from 1, so
    # shift by one to stay consistent with it.
    sequence_number_ordinal = (batch.sequence_number.squeeze(1) + 1).tolist()
    action_probability = batch.log_prob.exp().squeeze(1).tolist()
    df_dict = {
        "state_features": state_features,
        "next_state_features": next_state_features,
        "action": action,
        "next_action": next_action,
        "reward": reward,
        "action_probability": action_probability,
        "metrics": metrics,
        "time_diff": time_diff,
        "mdp_id": mdp_id,
        "sequence_number": sequence_number,
        "sequence_number_ordinal": sequence_number_ordinal,
        "ds": [ds] * n,
    }

    if problem_domain in (
        ProblemDomain.DISCRETE_ACTION,
        ProblemDomain.PARAMETRIC_ACTION,
        ProblemDomain.MDN_RNN,
    ):
        assert isinstance(env.action_space, gym.spaces.Discrete)
        num_actions = env.action_space.n
        if problem_domain == ProblemDomain.PARAMETRIC_ACTION:
            # Possible actions are List[Dict[int, float]]
            possible_actions = [{i: 1.0} for i in range(num_actions)]
        else:
            # Possible actions are List[str]
            possible_actions = [str(i) for i in range(num_actions)]

        def pa_transform(x):
            # Every row gets the same full action list; x is ignored.
            return possible_actions

        # these are fillers, which should have correct shape
        pa_features = range(n)
        pna_features = time_diff
        df_dict["possible_actions"] = feature_transform(pa_features, pa_transform)
        df_dict["possible_next_actions"] = feature_transform(
            pna_features,
            pa_transform,
            is_next_with_multi_steps=is_multi_steps,
            replace_when_terminal=[],
            terminal=batch.terminal,
        )

    df = pd.DataFrame(df_dict)
    # validate df (must happen before shuffling: it relies on row order)
    validate_mdp_ids_seq_nums(df)
    # shuffling (sample the whole batch)
    df = df.reindex(np.random.permutation(df.index))
    return df


def set_seed(env: gym.Env, seed: int):
    """Seed the global random generators and ``env`` with ``seed``.

    Seeds, in order: NumPy's global generator, Python's ``random`` module,
    PyTorch (``torch.manual_seed``), the environment itself, and the
    environment's action space.
    """
    # Process-wide generators.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)

    # Environment-owned generators.
    env.seed(seed)
    env.action_space.seed(seed)


def feature_transform(
    features,
    single_elem_transform,
    is_next_with_multi_steps=False,
    replace_when_terminal=None,
    terminal=None,
):
    """Apply ``single_elem_transform`` to every feature in a batch.

    ``features`` is a batch: one entry per row. When
    ``is_next_with_multi_steps`` is set, each row is itself a sequence of
    features (one per step) and the result is a list of lists.

    If ``terminal`` is given, ``replace_when_terminal`` must not be None and
    rows whose ``terminal[idx]`` is truthy are masked:

    * single-step: the whole row becomes ``replace_when_terminal``;
    * multi-step: all steps but the last are transformed and
      ``replace_when_terminal`` is appended in place of the last one.
    """
    if terminal is not None:
        # a replacement value is mandatory whenever terminal masking is on
        assert replace_when_terminal is not None

    def convert_steps(steps):
        return [single_elem_transform(feat) for feat in steps]

    transformed = []
    if is_next_with_multi_steps:
        for idx, multi_steps_features in enumerate(features):
            if terminal is not None and terminal[idx]:
                # for next features where we replace the last step when terminal
                row = convert_steps(multi_steps_features[:-1])
                row = row + [replace_when_terminal]
            else:
                row = convert_steps(multi_steps_features)
            transformed.append(row)
        return transformed

    for idx, feat in enumerate(features):
        if terminal is not None and terminal[idx]:
            transformed.append(replace_when_terminal)
        else:
            transformed.append(single_elem_transform(feat))
    return transformed


def validate_mdp_ids_seq_nums(df):
    """Check that rows are grouped by MDP with consecutive sequence numbers.

    Args:
        df: Anything indexable by ``"mdp_id"`` and ``"sequence_number"``
            that yields equally ordered iterables (e.g. a DataFrame).

    Returns:
        None. The function only validates.

    Raises:
        AssertionError: Within a run of identical ``mdp_id`` values a
            sequence number is not exactly the previous one plus 1, or the
            rows of one MDP are not contiguous (the number of runs differs
            from the number of distinct mdp_ids). Raised explicitly so the
            checks stay active under ``python -O``.
    """
    mdp_ids = list(df["mdp_id"])
    sequence_numbers = list(df["sequence_number"])

    prev_mdp_id, prev_seq_num = None, None
    mdp_count = 0  # number of contiguous runs of the same mdp_id
    for row, (mdp_id, seq_num) in enumerate(zip(mdp_ids, sequence_numbers)):
        # Use the row index, not a None sentinel, to detect the first row so
        # that an mdp_id that is literally None is handled like any other.
        if row == 0 or mdp_id != prev_mdp_id:
            mdp_count += 1
            prev_mdp_id = mdp_id
        elif seq_num != prev_seq_num + 1:
            raise AssertionError(
                f"For mdp_id {mdp_id}, expected sequence number "
                f"{prev_seq_num + 1} but got {seq_num}. "
                f"Sequence numbers must increase by exactly 1 within an MDP.\n"
                f"Zip(mdp_id, seq_num): "
                f"{list(zip(mdp_ids, sequence_numbers))}"
            )
        prev_seq_num = seq_num

    unique_mdp_count = len(set(mdp_ids))
    if unique_mdp_count != mdp_count:
        raise AssertionError(
            "MDPs are broken up. {} vs {}".format(unique_mdp_count, mdp_count)
        )
    return

0 comments on commit 2cf5f63

Please sign in to comment.