Skip to content

Commit

Permalink
load distributed config
Browse files Browse the repository at this point in the history
  • Loading branch information
rparundekar committed May 5, 2024
1 parent 4f1f3b6 commit 42482ee
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 15 deletions.
10 changes: 2 additions & 8 deletions launch.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,14 @@
"""run script for fine-tuning a model."""
-import yaml # type: ignore
from aihero.research.config.schema import BatchInferenceJob, TrainingJob
from aihero.research.finetuning.infer import BatchInferenceJobRunner
from aihero.research.finetuning.train import TrainingJobRunner
from fire import Fire


-def train(training_config_file: str = "/mnt/config/training/config.yaml", distributed_config_file: str = "") -> None:
+def train(training_config_file: str = "/mnt/config/training/config.yaml") -> None:
"""Run Training."""
training_config = TrainingJob.load(training_config_file)
-if distributed_config_file:
-with open(distributed_config_file) as f:
-distributed_training_config = yaml.safe_load(f)
-else:
-distributed_training_config = ""
-TrainingJobRunner(training_config, distributed_config=distributed_training_config).run()
+TrainingJobRunner(training_config).run()


def infer(batch_inference_config_file: str = "/mnt/config/batch_inference/config.yaml") -> None:
Expand Down
9 changes: 2 additions & 7 deletions src/aihero/research/finetuning/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Launch the training job inside a container."""
import os
-from typing import Any, Optional, Tuple
+from typing import Any, Tuple

import torch
from datasets import Dataset, DatasetDict
Expand All @@ -25,14 +25,9 @@
class TrainingJobRunner:
"""Class to run a training job."""

-def __init__(self, training_job: TrainingJob, distributed_config: Optional[dict[str, Any]] = None):
+def __init__(self, training_job: TrainingJob):
"""Initialize the training job runner."""
self.training_job = training_job
-if distributed_config:
-self.distributed_config = distributed_config
-print("Training LOCAL RANK: {} ...".format(os.getenv("LOCAL_RANK", "Unknown")))
-print("Training RANK: {} ...".format(os.getenv("RANK", "Unknown")))
-print("Training LOCAL WORLD SIZE: {} ...".format(os.getenv("LOCAL_WORLD_SIZE", "Unknown")))

print("Loading model")
self.model, self.tokenizer = self.load_model()
Expand Down

0 comments on commit 42482ee

Please sign in to comment.