Global imports (Lightning-AI#1170)
carmocca authored Mar 20, 2024
1 parent 12da7e7 commit 9b6475d
Showing 47 changed files with 237 additions and 517 deletions.
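The change is mechanical and repeated across the tree: imports that previously sat inside function bodies (mostly in the tests) are hoisted to module scope, and the import blocks are ordered isort-style — standard library, then third-party, then first-party, each group alphabetized. A minimal sketch of the before/after shape, on a hypothetical test module rather than any file in this commit (it assumes litgpt and torch are importable; the test name is made up):

# Before (the pattern this commit removes): imports resolved inside each test body.
#
#     def test_alpaca(mock_tokenizer, alpaca_path):
#         from litgpt.data import Alpaca
#         from litgpt.prompts import Alpaca as AlpacaPromptStyle
#         ...
#
# After: one module-level import block, grouped and alphabetized.
from pathlib import Path  # standard library

import torch  # third-party

from litgpt.data import Alpaca  # first-party


def test_alpaca_is_importable(tmp_path: Path) -> None:
    # Hypothetical smoke test: the hoisted names are visible to every test in the module.
    assert Alpaca is not None
    assert torch.tensor([1.0]).numel() == 1
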
5 changes: 3 additions & 2 deletions litgpt/data/alpaca.py
@@ -7,9 +7,10 @@
 from typing import Optional, Union

 import torch
-from torch.utils.data import random_split, DataLoader
 from lightning_utilities.core.imports import RequirementCache
-from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
+from torch.utils.data import DataLoader, random_split
+
+from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.prompts import PromptStyle
 from litgpt.tokenizer import Tokenizer

3 changes: 2 additions & 1 deletion litgpt/data/alpaca_2k.py
@@ -3,8 +3,9 @@

 from dataclasses import dataclass, field
 from pathlib import Path
-from litgpt.data.alpaca import Alpaca

 from litgpt.data import SFTDataset
+from litgpt.data.alpaca import Alpaca
+

 @dataclass

2 changes: 1 addition & 1 deletion litgpt/data/alpaca_gpt4.py
@@ -3,8 +3,8 @@

 from dataclasses import dataclass, field
 from pathlib import Path
-from litgpt.data.alpaca import Alpaca

+from litgpt.data.alpaca import Alpaca

 _URL = "https://raw.githubusercontent.com/Instruction-Tuning-with-GPT-4/GPT-4-LLM/main/data/alpaca_gpt4_data.json"

4 changes: 2 additions & 2 deletions litgpt/data/base.py
@@ -1,13 +1,13 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 from abc import abstractmethod
 from functools import partial
-from typing import List, Dict, Union, Optional, Callable, Any
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
+from lightning import LightningDataModule
 from torch import Tensor
 from torch.utils.data import Dataset

-from lightning import LightningDataModule
 from litgpt import Tokenizer
 from litgpt.prompts import PromptStyle

5 changes: 2 additions & 3 deletions litgpt/data/deita.py
@@ -1,9 +1,8 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 """Implementation derived from https://github.com/tloen/alpaca-lora"""
-from pathlib import Path
 from dataclasses import dataclass, field
-
-from typing import Optional, List, Union
+from pathlib import Path
+from typing import List, Optional, Union

 import torch
 from torch.utils.data import DataLoader

2 changes: 1 addition & 1 deletion litgpt/data/dolly.py
@@ -9,7 +9,7 @@
 from torch.utils.data import random_split

 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, Alpaca
+from litgpt.data import Alpaca, SFTDataset

 _URL: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl"

4 changes: 2 additions & 2 deletions litgpt/data/flan.py
@@ -3,13 +3,13 @@
 import json
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional, Dict, List, Set, Union
+from typing import Dict, List, Optional, Set, Union

 import torch
 from torch.utils.data import DataLoader

 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
+from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.data.alpaca import download_if_missing
 from litgpt.tokenizer import Tokenizer

6 changes: 3 additions & 3 deletions litgpt/data/json_data.py
@@ -3,13 +3,13 @@
 import json
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Optional, Union, Tuple, Any
+from typing import Any, Optional, Tuple, Union

 import torch
-from torch.utils.data import random_split, DataLoader
+from torch.utils.data import DataLoader, random_split

 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
+from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.tokenizer import Tokenizer


5 changes: 2 additions & 3 deletions litgpt/data/lima.py
@@ -2,11 +2,10 @@
 """Implementation derived from https://github.com/tloen/alpaca-lora"""
 import os
 from dataclasses import dataclass, field
-
-from typing import Optional, List, Union
+from typing import List, Optional, Union

 import torch
-from torch.utils.data import random_split, DataLoader
+from torch.utils.data import DataLoader, random_split

 from litgpt import PromptStyle
 from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn

2 changes: 1 addition & 1 deletion litgpt/data/lit_data.py
@@ -2,7 +2,7 @@
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union, Optional, Tuple
+from typing import Optional, Tuple, Union

 from torch.utils.data import DataLoader

3 changes: 1 addition & 2 deletions litgpt/data/longform.py
@@ -9,11 +9,10 @@
 from torch.utils.data import DataLoader

 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
+from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.data.alpaca import download_if_missing
 from litgpt.tokenizer import Tokenizer

-
 _URL = "https://raw.githubusercontent.com/akoksal/LongForm/main/dataset"


2 changes: 1 addition & 1 deletion litgpt/data/openwebtext.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path
-from typing import Union, Optional
+from typing import Optional, Union

 from torch.utils.data import DataLoader

2 changes: 1 addition & 1 deletion litgpt/data/prepare_slimpajama.py
@@ -6,8 +6,8 @@
 from pathlib import Path

 from litgpt import Tokenizer
-from litgpt.utils import CLI
 from litgpt.data.prepare_starcoder import DataChunkRecipe
+from litgpt.utils import CLI


 class SlimPajamaDataRecipe(DataChunkRecipe):

2 changes: 1 addition & 1 deletion litgpt/data/tinyllama.py
@@ -1,7 +1,7 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union, Optional
+from typing import Optional, Union

 from torch.utils.data import DataLoader

4 changes: 2 additions & 2 deletions litgpt/pretrain.py
@@ -30,9 +30,9 @@
     copy_config_files,
     num_parameters,
     parse_devices,
+    reset_parameters,
     save_config,
     save_hyperparameters,
-    reset_parameters,
 )


@@ -153,7 +153,7 @@ def main(
t0 = time.perf_counter()
with fabric.init_module(empty_init=True):
model = GPT(config)

initialize_weights(fabric, model, n_layer=config.n_layer, n_embd=config.n_embd)

if train.tie_embeddings:
9 changes: 0 additions & 9 deletions tests/conftest.py
@@ -2,22 +2,13 @@

 import os
 import shutil
-import sys
 from pathlib import Path
 from typing import List

 import pytest
 import torch
 from lightning.fabric.utilities.testing import _runif_reasons

-wd = Path(__file__).parent.parent.absolute()
-
-
-@pytest.fixture(autouse=True)
-def add_wd_to_path():
-    # this adds support for running tests without the package installed
-    sys.path.append(str(wd))
-

 @pytest.fixture()
 def fake_checkpoint_dir(tmp_path):

5 changes: 2 additions & 3 deletions tests/data/test_alpaca.py
@@ -1,10 +1,9 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from litgpt.data import Alpaca
+from litgpt.prompts import Alpaca as AlpacaPromptStyle


 def test_alpaca(mock_tokenizer, alpaca_path):
-    from litgpt.data import Alpaca
-    from litgpt.prompts import Alpaca as AlpacaPromptStyle
-
     alpaca = Alpaca(val_split_fraction=0.5, download_dir=alpaca_path.parent, file_name=alpaca_path.name, num_workers=0)
     assert isinstance(alpaca.prompt_style, AlpacaPromptStyle)
     alpaca.connect(mock_tokenizer, batch_size=2, max_seq_length=10)

12 changes: 4 additions & 8 deletions tests/data/test_base.py
@@ -1,17 +1,17 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
-from unittest.mock import Mock

 import pytest
 import torch

+from litgpt.data import SFTDataset
+from litgpt.data import get_sft_collate_fn
+from litgpt.prompts import PromptStyle
+

 @pytest.mark.parametrize("mask_prompt", [True, False])
 @pytest.mark.parametrize("ignore_index", [-1, -100])
 @pytest.mark.parametrize("max_seq_length", [1000, 5])
 def test_sft_dataset(max_seq_length, ignore_index, mask_prompt, mock_tokenizer):
-    from litgpt.data import SFTDataset
-    from litgpt.prompts import PromptStyle
-
     class Style(PromptStyle):
         def apply(self, prompt, **kwargs):
             return f"In: {prompt} Out:"

@@ -42,8 +42,6 @@ def apply(self, prompt, **kwargs):
 @pytest.mark.parametrize("ignore_index", [-1, -100])
 @pytest.mark.parametrize("pad_id", [0, 100])
 def test_sft_collate_fn_padding(pad_id, ignore_index):
-    from litgpt.data import get_sft_collate_fn
-
     collate = get_sft_collate_fn(pad_id=pad_id, ignore_index=ignore_index)
     samples = [
         {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 20, 30])},

@@ -58,8 +56,6 @@ def test_sft_collate_fn_padding(pad_id, ignore_index):


 def test_sft_collate_fn_truncation():
-    from litgpt.data import get_sft_collate_fn
-
     collate = get_sft_collate_fn(max_seq_length=2)
     samples = [
         {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([10, 20, 30])},

9 changes: 4 additions & 5 deletions tests/data/test_deita.py
@@ -1,10 +1,12 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 from unittest import mock

+from litgpt.data import Deita, SFTDataset
+from litgpt.data.deita import format_dataset
+from litgpt.prompts import Alpaca as AlpacaPromptStyle

-def test_format_dataset():
-    from litgpt.data.deita import format_dataset

+def test_format_dataset():
     data = [
         {
             "prompt": "prompt1",

@@ -43,9 +45,6 @@ def test_format_dataset():
 @mock.patch("litgpt.data.deita.format_dataset")
 @mock.patch("datasets.load_dataset")
 def test_deita(_, format_dataset_mock, mock_tokenizer, tmp_path):
-    from litgpt.data import Deita, SFTDataset
-    from litgpt.prompts import Alpaca as AlpacaPromptStyle
-
     format_dataset_mock.return_value = [
         {"instruction": "inst1", "output": "out1"},
         {"instruction": "inst2", "output": "out2"},

6 changes: 3 additions & 3 deletions tests/data/test_dolly.py
@@ -1,10 +1,10 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

+from litgpt.data import Dolly
+from litgpt.prompts import Alpaca as AlpacaPromptStyle

-def test_dolly(mock_tokenizer, dolly_path):
-    from litgpt.data import Dolly
-    from litgpt.prompts import Alpaca as AlpacaPromptStyle

+def test_dolly(mock_tokenizer, dolly_path):
     alpaca = Dolly(val_split_fraction=0.5, download_dir=dolly_path.parent, file_name=dolly_path.name, num_workers=0)
     assert isinstance(alpaca.prompt_style, AlpacaPromptStyle)
     alpaca.connect(mock_tokenizer, batch_size=2, max_seq_length=10)

11 changes: 4 additions & 7 deletions tests/data/test_json.py
@@ -1,13 +1,14 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 import json

 import pytest

+from litgpt.data import JSON
+from litgpt.prompts import PromptStyle
+
+
 @pytest.mark.parametrize("as_jsonl", [False, True])
 def test_json(as_jsonl, tmp_path, mock_tokenizer):
-    from litgpt.data import JSON
-    from litgpt.prompts import PromptStyle
-
     class Style(PromptStyle):
         def apply(self, prompt, **kwargs):
             return f"X: {prompt} {kwargs['input']} Y:"

@@ -62,8 +63,6 @@ def apply(self, prompt, **kwargs):


 def test_json_input_validation(tmp_path):
-    from litgpt.data import JSON
-
     with pytest.raises(FileNotFoundError, match="The `json_path` must be a file or a directory"):
         JSON(tmp_path / "not exist")


@@ -85,8 +84,6 @@ def test_json_input_validation(tmp_path):

 @pytest.mark.parametrize("as_jsonl", [False, True])
 def test_json_with_splits(as_jsonl, tmp_path, mock_tokenizer):
-    from litgpt.data import JSON
-
     mock_train_data = [
         {"instruction": "Add", "input": "2+2", "output": "4"},
         {"instruction": "Subtract", "input": "5-3", "output": "2"},

5 changes: 2 additions & 3 deletions tests/data/test_longform.py
@@ -1,10 +1,9 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+from litgpt.data import LongForm
+from litgpt.prompts import Longform as LongFormPromptStyle


 def test_longform(mock_tokenizer, longform_path):
-    from litgpt.data import LongForm
-    from litgpt.prompts import Longform as LongFormPromptStyle
-
     alpaca = LongForm(download_dir=longform_path, num_workers=0)
     assert isinstance(alpaca.prompt_style, LongFormPromptStyle)
     alpaca.connect(mock_tokenizer, batch_size=2, max_seq_length=10)

8 changes: 4 additions & 4 deletions tests/data/test_openwebtext.py
@@ -4,16 +4,16 @@
 from unittest.mock import ANY, call

 import pytest
+from litdata.streaming import StreamingDataLoader, StreamingDataset
 from torch.utils.data import DataLoader

+from litgpt.data import OpenWebText
+

 @pytest.mark.skipif(sys.platform == "win32", reason="Not in the mood to add Windows support right now.")
 @mock.patch("litdata.optimize")
 @mock.patch("datasets.load_dataset")
-def test_openwebtext(_, optimize_mock, tmp_path, monkeypatch, mock_tokenizer):
-    from litgpt.data import OpenWebText
-    from litdata.streaming import StreamingDataLoader, StreamingDataset
-
+def test_openwebtext(_, optimize_mock, tmp_path, mock_tokenizer):
     data = OpenWebText(data_path=(tmp_path / "openwebtext"))
     assert data.seq_length == 2048
     assert data.batch_size == 1

6 changes: 3 additions & 3 deletions tests/data/test_tinyllama.py
@@ -1,13 +1,13 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.

 import pytest
+from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset
 from torch.utils.data import DataLoader

+from litgpt.data import TinyLlama

-def test_tinyllama(tmp_path, monkeypatch):
-    from litgpt.data import TinyLlama
-    from litdata.streaming import StreamingDataLoader, StreamingDataset, CombinedStreamingDataset

+def test_tinyllama(tmp_path):
     data = TinyLlama(data_path=(tmp_path / "data"))
     assert data.seq_length == 2048
     assert data.batch_size == 1
