Fix training_epochs meaning in Pearl offline learning
Summary: The parameter `training_epochs` in Pearl's offline learning was being interpreted as the number of batches. This change fixes that, so the parameter now really means the number of training epochs. The actual number of batches is computed from the training epochs, the replay buffer size, and the agent's batch size, although it is still possible to specify the number of batches directly (via `number_of_batches`).

Reviewed By: PavlosApo

Differential Revision: D66943904

fbshipit-source-id: 9568dae67a45464654a12b13012522534df3f787
rodrigodesalvobraz authored and facebook-github-bot committed Dec 12, 2024
1 parent 1f50db2 commit c9576ed
Showing 6 changed files with 51 additions and 20 deletions.
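For callers, the practical effect is that training length can now be requested in passes over the offline dataset instead of a raw batch count. Below is a minimal sketch of both calling styles; the import path and the `offline_agent` / `data_buffer` objects are assumptions (a `PearlAgent` and a populated `ReplayBuffer` constructed elsewhere), not part of this commit.

# Import path assumed from Pearl's layout; only offline_learning's signature
# is confirmed by this diff.
from pearl.utils.functional_utils.train_and_eval.offline_learning_and_evaluation import (
    offline_learning,
)

# New meaning: two full passes over data_buffer; the batch count is derived
# internally as ceil(training_epochs * len(data_buffer) / batch_size).
offline_learning(
    offline_agent=offline_agent,  # assumed: a PearlAgent built elsewhere
    data_buffer=data_buffer,      # assumed: a ReplayBuffer holding offline data
    training_epochs=2,
    seed=100,
)

# Still supported: give the batch count directly. Passing both arguments
# raises a ValueError.
offline_learning(
    offline_agent=offline_agent,
    data_buffer=data_buffer,
    number_of_batches=2000,
    seed=100,
)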
22 changes: 18 additions & 4 deletions pearl/utils/functional_utils/train_and_eval/learning_logger.py
@@ -2,7 +2,9 @@

# pyre-strict

from typing import Any, Protocol
from typing import Any, Optional, Protocol

from pearl.replay_buffers.transition import TransitionBatch


class LearningLogger(Protocol):
@@ -12,18 +14,30 @@ class LearningLogger(Protocol):
Args:
results: A dictionary of results.
step: The current step of the learning process.
batch: The batch of data used for the current step.
prefix: A prefix to add to the logged results.
"""

def __init__(self) -> None:
pass

def __call__(self, results: dict[str, Any], step: int, prefix: str = "") -> None:
def __call__(
self,
results: dict[str, Any],
step: int,
batch: Optional[TransitionBatch] = None,
prefix: str = "",
) -> None:
pass


def null_learning_logger(results: dict[str, str], step: int, prefix: str = "") -> None:
def null_learning_logger(
results: dict[str, Any],
step: int,
batch: Optional[TransitionBatch] = None,
prefix: str = "",
) -> None:
"""
A null learning logger that does nothing.
A learning logger that does nothing.
"""
pass
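Since the protocol's `__call__` now also receives the sampled batch, a custom logger can inspect batch contents alongside the loss dictionary. A minimal sketch of a logger conforming to the new signature follows; the printing behavior and the use of `batch.state.shape[0]` as the batch size are illustrative assumptions, not part of this commit.

from typing import Any, Optional

from pearl.replay_buffers.transition import TransitionBatch


def print_learning_logger(
    results: dict[str, Any],
    step: int,
    batch: Optional[TransitionBatch] = None,
    prefix: str = "",
) -> None:
    # Log every 100th step; batch is optional, so guard before touching it.
    if step % 100 == 0:
        batch_size = batch.state.shape[0] if batch is not None else "n/a"
        print(f"{prefix} step={step} batch_size={batch_size} results={results}")

Such a function can then be passed to offline_learning through the renamed `learning_logger` argument.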
29 changes: 24 additions & 5 deletions pearl/utils/functional_utils/train_and_eval/offline_learning_and_evaluation.py
@@ -8,6 +8,7 @@
# pyre-strict

import io
import math
import os
import time

@@ -141,8 +142,9 @@ def get_offline_data_in_buffer(
def offline_learning(
offline_agent: PearlAgent,
data_buffer: ReplayBuffer,
training_epochs: int = 1000,
logger: LearningLogger = null_learning_logger,
training_epochs: Optional[float] = None,
number_of_batches: Optional[int] = None,
learning_logger: LearningLogger = null_learning_logger,
seed: Optional[int] = None,
) -> None:
"""
@@ -153,7 +155,12 @@
Args:
offline_agent (PearlAgent): a Pearl agent (typically a conservative one such as CQL or IQL).
data_buffer (ReplayBuffer): a replay buffer to sample a batch of transition data.
training_epochs (int): number of sampled batches used for offline learning.
number_of_batches (Optional[int], default None): number of batches to sample
from the replay buffer.
Mutually exclusive with training_epochs.
training_epochs (Optional[float], default 1): number of passes over training data.
Fractional values result in a rounded-up number of sampled batches.
Mutually exclusive with number_of_batches.
learning_logger (LearningLogger, optional): a LearningLogger to log the training loss
(default is no-op logger).
seed (int, optional): random seed (default is `int(time.time())`).
@@ -162,15 +169,27 @@
seed = int(time.time())
set_seed(seed=seed)

if number_of_batches is None:
if training_epochs is None:
training_epochs = 1
number_of_batches = math.ceil(
training_epochs * len(data_buffer) / offline_agent.policy_learner.batch_size
)
elif training_epochs is not None:
raise ValueError(
f"{offline_learning.__name__} must receive at most one of number_of_batches and "
+ "training_epochs, but got both."
)

# move replay buffer to device of the offline agent
data_buffer.device_for_batches = offline_agent.device

# training loop
for i in range(training_epochs):
for i in range(number_of_batches):
batch = data_buffer.sample(offline_agent.policy_learner.batch_size)
assert isinstance(batch, TransitionBatch)
loss = offline_agent.learn_batch(batch=batch)
logger(loss, i, TRAINING_TAG)
learning_logger(loss, i, batch, TRAINING_TAG)


def offline_evaluation(
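The epochs-to-batches conversion introduced above is a plain ceiling division over the buffer size. A small self-contained illustration of the same arithmetic, using made-up sizes:

import math

# Hypothetical sizes, chosen only to show the rounding behavior.
replay_buffer_size = 10_000
batch_size = 128

for training_epochs in (1, 1.5, 2):
    number_of_batches = math.ceil(training_epochs * replay_buffer_size / batch_size)
    print(f"epochs={training_epochs} -> batches={number_of_batches}")

# epochs=1   -> batches=79
# epochs=1.5 -> batches=118
# epochs=2   -> batches=157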
2 changes: 0 additions & 2 deletions pearl/utils/replay_buffer_utils.py
@@ -1,8 +1,6 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict
from typing import List, Optional, Type

import torch

from pearl.api.action import Action
12 changes: 6 additions & 6 deletions pearl/utils/scripts/benchmark_offline_rl.py
@@ -114,7 +114,7 @@ def evaluate_offline_rl(
is_action_continuous: bool,
offline_agent: PearlAgent,
method_name: str,
training_epochs: int = 1000,
number_of_batches: int = 1000,
evaluation_episodes: int = 500,
url: str | None = None,
data_path: str | None = None,
@@ -137,7 +137,7 @@ def evaluate_offline_rl(
compatible with a Pearl Agent (see class TensorBasedReplayBuffer for details).
offline_agent: an offline agent to train and evaluate (for example, IQL or CQL based agent).
method_name: name of the agent's policy learner (used for saving results).
training_epochs: number of epochs to train the offline agent for.
number_of_batches: number of batches to sample for training the offline agent.
evaluation_episodes: number of episodes to evaluate the offline agent for.
url: url to download data from.
data_path: path to a local file containing offline data to use for training.
@@ -196,8 +196,8 @@ def evaluate_offline_rl(
offline_learning(
offline_agent=offline_agent,
data_buffer=offline_data_replay_buffer,
training_epochs=training_epochs,
seed=seed,
number_of_batches=number_of_batches,
seed=seed if seed is not None else 0,
)

print("\n")
@@ -224,7 +224,7 @@ def evaluate_offline_rl(
+ "returns_offline_agent_"
+ dataset
+ "_"
+ str(training_epochs)
+ str(number_of_batches)
+ ".pickle",
"wb",
) as handle:
@@ -310,7 +310,7 @@ def evaluate_offline_rl(
is_action_continuous=is_action_continuous,
offline_agent=offline_agent,
method_name="Implicit Q learning",
training_epochs=10000,
number_of_batches=10000,
# data_path=data_path,
data_collection_agent=data_collection_agent,
file_name=file_name,
4 changes: 2 additions & 2 deletions test/integration/test_integration.py
@@ -938,7 +938,7 @@ def test_cql_offline_training(self) -> None:
offline_learning(
offline_agent=conservativeDQN_agent,
data_buffer=offline_data_replay_buffer,
training_epochs=2000,
number_of_batches=2000,
seed=100,
)

@@ -997,7 +997,7 @@ def test_iql_offline_training(self) -> None:
offline_learning(
offline_agent=IQLAgent,
data_buffer=offline_data_replay_buffer,
training_epochs=2000,
number_of_batches=2000,
seed=100,
)

2 changes: 1 addition & 1 deletion test/unit/with_pytorch/test_offline_cql.py
@@ -121,7 +121,7 @@ def test_create_offline_data_and_learn_cql(self) -> None:
offline_learning(
offline_agent=conservative_agent,
data_buffer=offline_data_replay_buffer,
training_epochs=10,
number_of_batches=10,
seed=100,
)

