Skip to content

Commit

Permalink
[train] Fix regression where large Trainer attributes get serialized …
Browse files Browse the repository at this point in the history
…along with actor class (ray-project#43234)

Remove the unintentional reference to `self` that gets pickled with the train coordinator function trainable. Also, restructure the code to make it harder to make such a mistake in the future.

---------

Signed-off-by: Justin Yu <[email protected]>
  • Loading branch information
justinvyu authored Feb 16, 2024
1 parent 71d37ff commit e1f39ea
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 28 deletions.
68 changes: 40 additions & 28 deletions python/ray/train/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type, Union

Expand Down Expand Up @@ -68,6 +69,40 @@ class TrainingFailedError(RuntimeError):
)


def _train_coordinator_fn(
config: dict, trainer_cls: Type["BaseTrainer"], metadata: dict
):
"""This is the function that defines the logic of the Ray Train coordinator.
This is responsible for setting up a remote instance of the `trainer_cls`
(a different instance than the one calling `trainer.fit` on the driver!)
and running the training loop.
"""
assert metadata is not None, metadata
# Propagate user metadata from the Trainer constructor.
_get_session().metadata = metadata

# config already contains merged values.
# Instantiate new Trainer in Trainable.
trainer = trainer_cls(**config)

# Get the checkpoint from Tune and pass it to workers later on.
checkpoint = ray.train.get_checkpoint()
if checkpoint:
# Set `starting_checkpoint` for auto-recovery fault-tolerance
# as well as manual restoration.
trainer.starting_checkpoint = checkpoint
# else: Train will restore from the user-provided
# `resume_from_checkpoint` == `starting_checkpoint`.

# Evaluate datasets if they are wrapped in a factory.
trainer.datasets = {
k: d() if callable(d) else d for k, d in trainer.datasets.items()
}

trainer.setup()
trainer.training_loop()


@DeveloperAPI
class BaseTrainer(abc.ABC):
"""Defines interface for distributed training on Ray.
Expand Down Expand Up @@ -656,38 +691,15 @@ def _generate_trainable_cls(self) -> Type["Trainable"]:
scaling_config = self.scaling_config
metadata = self.metadata

def train_func(config):
assert metadata is not None, metadata
# Propagate user metadata from the Trainer constructor.
_get_session().metadata = metadata

# config already contains merged values.
# Instantiate new Trainer in Trainable.
trainer = trainer_cls(**config)

# Get the checkpoint from Tune and pass it to workers later on.
checkpoint = ray.train.get_checkpoint()
if checkpoint:
# Set `starting_checkpoint` for auto-recovery fault-tolerance
# as well as manual restoration.
trainer.starting_checkpoint = checkpoint
# else: Train will restore from the user-provided
# `resume_from_checkpoint` == `starting_checkpoint`.

# Evaluate datasets if they are wrapped in a factory.
trainer.datasets = {
k: d() if callable(d) else d for k, d in self.datasets.items()
}

trainer.setup()
trainer.training_loop()

train_coordinator_fn = partial(
_train_coordinator_fn, trainer_cls=trainer_cls, metadata=metadata
)
# Change the name of the training function to match the name of the Trainer
# class. This will mean the Tune trial name will match the name of Trainer on
# stdout messages and the results directory.
train_func.__name__ = trainer_cls.__name__
train_coordinator_fn.__name__ = trainer_cls.__name__

trainable_cls = wrap_function(train_func)
trainable_cls = wrap_function(train_coordinator_fn)
has_base_dataset = bool(self.datasets)
if has_base_dataset:
from ray.data.context import DataContext
Expand Down
13 changes: 13 additions & 0 deletions python/ray/train/tests/test_base_trainer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import tempfile

import numpy as np
import pytest

import ray
Expand Down Expand Up @@ -187,6 +188,18 @@ def training_loop(self):
trainer.fit()


def test_large_params(ray_start_4_cpus):
"""Tests that large params are not serialized with the trainer actor
and are instead put into the object store separately."""
huge_array = np.zeros(shape=int(1e8))

def training_loop(self):
huge_array

trainer = DummyTrainer(training_loop)
trainer.fit()


if __name__ == "__main__":
import sys

Expand Down

0 comments on commit e1f39ea

Please sign in to comment.