Add packed dataset (Lightning-AI#192)
lantiga authored Apr 26, 2023
1 parent 5f484dd commit d3c0075
Showing 2 changed files with 353 additions and 0 deletions.
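
For orientation, here is a minimal sketch of how the two classes added in this commit fit together. The output directory, chunk size, vocab size, and token ids are illustrative, not taken from the diff.

import os

import numpy as np

from lit_llama.packed_dataset import PackedDataset, PackedDatasetBuilder

os.makedirs("data/packed", exist_ok=True)

# Pack tokenized documents into fixed-size binary chunk files.
builder = PackedDatasetBuilder(
    outdir="data/packed",
    prefix="train",
    chunk_size=2048 * 16,   # tokens per .bin chunk file
    sep_token=0,            # id used to pad the unused tail of a chunk
    dtype="auto",           # picks uint16 for vocabularies below 65500
    vocab_size=32000,
)
for doc in (np.array([1, 5, 9, 2], dtype=np.uint16),
            np.array([1, 7, 3, 3, 2], dtype=np.uint16)):
    builder.add_array(doc)
builder.write_reminder()    # flush the last, partially filled chunk

# Stream the data back as fixed-size blocks of token ids.
dataset = PackedDataset(builder.filenames, n_chunks=1, block_size=2048, shuffle=True)
block = next(iter(dataset))  # numpy array with 2048 token ids

In the test added below, sep_token is the tokenizer's BOS id, so the unused tail of each chunk is padded with BOS tokens.
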
229 changes: 229 additions & 0 deletions lit_llama/packed_dataset.py
@@ -0,0 +1,229 @@
# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py


import os
import struct

import numpy as np
from torch.utils.data import IterableDataset, get_worker_info


dtypes = {
1: np.uint8,
2: np.int8,
3: np.int16,
4: np.int32,
5: np.int64,
6: np.float32,
7: np.float64,
8: np.uint16,
}


def code(dtype):
    # Map a numpy dtype back to its 1-byte header code.
    for k, v in dtypes.items():
        if v == dtype:
            return k
    raise ValueError(dtype)


HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24 # bytes
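# The 24-byte header written by PackedDatasetBuilder is: the 7-byte magic above,
# a uint64 version, a uint8 dtype code, and a uint64 chunk size (all
# little-endian); chunk_size elements of dtype follow.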


class PackedDataset(IterableDataset):
def __init__(self, filenames, n_chunks, block_size, seed=12345, shuffle=True):
self._filenames = filenames
self._n_chunks = n_chunks
self._block_size = block_size
self._seed = seed
self._shuffle = shuffle

def __iter__(self):
worker_info = get_worker_info()
if worker_info is None:
return PackedDatasetIterator(
filenames=self._filenames,
n_chunks=self._n_chunks,
block_size=self._block_size,
seed=self._seed,
shuffle=self._shuffle,
)
else:
return PackedDatasetIterator(
filenames=[
el
for idx, el in enumerate(self._filenames)
if idx % worker_info.num_workers == worker_info.id
],
n_chunks=self._n_chunks,
block_size=self._block_size,
seed=self._seed,
shuffle=self._shuffle,
)


class PackedDatasetBuilder(object):
def __init__(
self,
outdir,
prefix,
chunk_size,
sep_token,
dtype="auto",
vocab_size=None,
):
if dtype == "auto":
if vocab_size is None:
raise ValueError("vocab_size cannot be None when dtype='auto'")
if vocab_size is not None and vocab_size < 65500:
self._dtype = np.uint16
else:
self._dtype = np.int32
else:
self._dtype = dtype
self._counter = 0
self._chunk_size = chunk_size
self._outdir = outdir
self._prefix = prefix
self._sep_token = sep_token
self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
self._arr.fill(self._sep_token)
self._idx = 0
self._version = 1
self._filenames = []

def _write_chunk(self):
filename = f"{self._prefix}_{self._counter:010d}.bin"
filename = os.path.join(self._outdir, filename)

with open(filename, "wb") as f:
f.write(HDR_MAGIC)
f.write(struct.pack("<Q", self._version))
f.write(struct.pack("<B", code(self._dtype)))
f.write(struct.pack("<Q", self._chunk_size))
f.write(self._arr.tobytes(order="C"))

self._filenames.append(filename)
self._counter += 1
self._arr.fill(self._sep_token)
self._idx = 0

@property
def dtype(self):
return self._dtype

@property
def filenames(self):
return self._filenames.copy()

def add_array(self, arr):
while self._idx + arr.shape[0] > self._chunk_size:
part_len = self._chunk_size - self._idx
self._arr[self._idx : self._idx + part_len] = arr[:part_len]
self._write_chunk()
arr = arr[part_len:]

arr_len = arr.shape[0]
self._arr[self._idx : self._idx + arr_len] = arr
self._idx += arr_len

    def write_reminder(self):
        # Flush the remaining, partially filled chunk (padded with sep_token) to disk.
        self._write_chunk()


class PackedDatasetIterator:
def __init__(self, filenames, n_chunks, block_size, seed, shuffle):
self._seed = seed
self._shuffle = shuffle
self._rng = np.random.default_rng(seed) if shuffle else None
self._block_idxs = None

# TODO: instead of filenames, we could have a single text stream
# (or text file) with the sequence of all files to be
# fetched/loaded.
self._filenames = filenames
self._file_idx = 0

self._n_chunks = n_chunks

self._dtype = None
self._block_size = block_size
self._n_blocks = None

self._mmaps = []
self._buffers = []

self._block_idxs = []
self._curr_idx = 0

self._load_n_chunks()

def _read_header(self, path):
with open(path, "rb") as f:
magic = f.read(len(HDR_MAGIC))
assert magic == HDR_MAGIC, "File doesn't match expected format."
version = struct.unpack("<Q", f.read(8))
assert (1,) == version
(dtype_code,) = struct.unpack("<B", f.read(1))
dtype = dtypes[dtype_code]
(chunk_size,) = struct.unpack("<Q", f.read(8))
return dtype, chunk_size

def _close_mmaps(self):
for mmap in self._mmaps:
mmap._mmap.close()

    def _load_n_chunks(self):
        self._close_mmaps()
        # Reset the buffers from the previous window so chunk indices start at 0 again.
        self._mmaps = []
        self._buffers = []

        if self._n_chunks > len(self._filenames[self._file_idx :]):
            raise StopIteration

for i in range(self._n_chunks):
filename = self._filenames[self._file_idx + i]
if self._dtype is None:
self._dtype, self._chunk_size = self._read_header(
filename
)
self._n_blocks = self._chunk_size // self._block_size
# TODO: check header matches with previous files
mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
self._mmaps.append(mmap)
self._buffers.append(memoryview(mmap))

self._file_idx += self._n_chunks
n_all_blocks = self._n_chunks * self._n_blocks

        # Shuffle block order, but only within the window of n_chunks files
        # that are currently memory-mapped.
        self._block_idxs = (
            self._rng.permutation(n_all_blocks)
            if self._shuffle
            else range(n_all_blocks)
        )

self._curr_idx = 0

def __del__(self):
self._close_mmaps()
del self._mmaps

def __iter__(self):
return self

def __next__(self):
if self._curr_idx >= len(self._block_idxs):
self._load_n_chunks()
# TODO: trigger fetching next next n_chunks if remote
block_idx = self._block_idxs[self._curr_idx]
chunk_id = block_idx // self._n_blocks
buffer = self._buffers[chunk_id]
elem_id = (block_idx % self._n_blocks) * self._block_size
offset = np.dtype(self._dtype).itemsize * elem_id
arr = np.frombuffer(
buffer, dtype=self._dtype, count=self._block_size, offset=offset
)
self._curr_idx += 1
return arr
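
Since PackedDataset is an IterableDataset whose __iter__ shards the file list across DataLoader workers, it can be consumed directly with torch.utils.data.DataLoader. A minimal sketch, with illustrative file names and sizes:

from torch.utils.data import DataLoader

from lit_llama.packed_dataset import PackedDataset

# Illustrative chunk files; each worker iterates every num_workers-th file.
filenames = [f"data/packed/train_{i:010d}.bin" for i in range(8)]

dataset = PackedDataset(filenames, n_chunks=4, block_size=2048, seed=12345, shuffle=True)
loader = DataLoader(dataset, batch_size=8, num_workers=2)

for batch in loader:
    # batch is an (8, 2048) tensor of token ids collated from numpy blocks
    break

Each iterator keeps at most n_chunks chunk files memory-mapped at a time and shuffles blocks only within that window, so memory use stays bounded regardless of the total dataset size.
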


124 changes: 124 additions & 0 deletions tests/test_packed_dataset.py
@@ -0,0 +1,124 @@
import pytest
import os
import requests


def train_tokenizer(destination_path):
destination_path.mkdir(parents=True, exist_ok=True)

# download the tiny shakespeare dataset
input_file_path = destination_path / "input.txt"
if not input_file_path.exists():
data_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
with open(input_file_path, "w") as f:
f.write(requests.get(data_url).text)

from lit_llama import Tokenizer
Tokenizer.train(
input=input_file_path,
destination=destination_path,
vocab_size=100,
)

return destination_path / "tokenizer.model"


def test_packed_dataset(tmp_path):
tokenizer_path = train_tokenizer(tmp_path)

from lit_llama import Tokenizer
tokenizer = Tokenizer(tokenizer_path)

texts = [
"The moment of truth is upon us.",
"Time to open the fridge."
]

from lit_llama.packed_dataset import PackedDatasetBuilder, PackedDataset, HDR_SIZE

block_size = 10
n_blocks = 2
chunk_size = block_size * n_blocks

builder = PackedDatasetBuilder(
outdir=tmp_path,
prefix="packed_dataset",
chunk_size=chunk_size,
sep_token=tokenizer.bos_id,
dtype="auto",
vocab_size=100,
)

    for text in texts:
        text_ids = tokenizer.encode(text)
        assert text_ids[0] == tokenizer.bos_id
        builder.add_array(text_ids)

filenames = builder.filenames

assert len(filenames) == 2
assert os.path.basename(filenames[0]) == "packed_dataset_0000000000.bin"
assert os.path.basename(filenames[1]) == "packed_dataset_0000000001.bin"

import numpy as np

ex_tokenized = [
tokenizer.encode(text).numpy().astype(builder.dtype)
for text in texts
]
ex_tokenized = np.concatenate(ex_tokenized)
ex_tokenized = ex_tokenized[:2 * chunk_size]

for filename, el in zip(filenames, np.array_split(ex_tokenized, 2)):
mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
count = len(mmap) // np.dtype(builder.dtype).itemsize
arr = np.frombuffer(
mmap, dtype=builder.dtype, count=count, offset=0
)
        # np.where returns a tuple of index arrays; take the indices themselves
        where_bos = np.where(arr == tokenizer.bos_id)[0]
        # we expect two BOS tokens in total, one per file
        assert len(where_bos) == 1
assert np.array_equal(arr, el)

dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, shuffle=False)

ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size)

for item, el in zip(dataset, ex_split):
assert np.array_equal(item, el)

dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size, seed=12345)

for i, item in enumerate(dataset):
block_idxs = iter(dataset)._block_idxs
assert np.array_equal(item, ex_split[block_idxs[i]])

dataset = PackedDataset(filenames=filenames, n_chunks=1, block_size=block_size, seed=12345)

for i, item in enumerate(dataset):
block_idxs = iter(dataset)._block_idxs
chunk_idx = i // n_blocks * n_blocks
assert np.array_equal(item, ex_split[chunk_idx + block_idxs[i % n_blocks]])

block_size_ = block_size // 2
ex_split = np.array_split(ex_tokenized, ex_tokenized.shape[0] // block_size_)
dataset = PackedDataset(filenames=filenames, n_chunks=2, block_size=block_size_, seed=12345)

for i, item in enumerate(dataset):
block_idxs = iter(dataset)._block_idxs
assert np.array_equal(item, ex_split[block_idxs[i]])

block_size_ = block_size // 3
n_chunks = 2
ex_chunks = np.split(ex_tokenized, n_chunks)
n_splits = ex_tokenized.shape[0] // n_chunks // block_size_
ex_splits = [np.split(el[:n_splits * block_size_], n_splits) for el in ex_chunks]
ex_split = sum(ex_splits, [])

dataset = PackedDataset(filenames=filenames, n_chunks=n_chunks, block_size=block_size_, seed=12345)

for i, item in enumerate(dataset):
block_idxs = iter(dataset)._block_idxs
assert np.array_equal(item, ex_split[block_idxs[i]])
