Commit

fix: data for inference
zanussbaum committed Apr 7, 2023
1 parent fb9ff9c commit 1b14b1f
Showing 2 changed files with 49 additions and 2 deletions.
data.py: 48 changes (47 additions & 1 deletion)
@@ -75,7 +75,7 @@ def load_data(config, tokenizer):
dataset = load_dataset("json", data_files=files, split="train")

else:
dataset = load_dataset(dataset_path)
dataset = load_dataset(dataset_path, split="train")

dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])

@@ -118,3 +118,49 @@ def load_data(config, tokenizer):
     )
 
     return train_dataloader, val_dataloader
+
+
+def load_data_for_inference(config, tokenizer):
+    dataset_path = config["dataset_path"]
+
+    if os.path.exists(dataset_path):
+        # check if path is a directory
+        if os.path.isdir(dataset_path):
+            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
+        else:
+            files = [dataset_path]
+
+        print(f"Reading files {files}")
+
+        dataset = load_dataset("json", data_files=files, split="train")
+
+    else:
+        dataset = load_dataset(dataset_path, split="train")
+
+    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
+
+    train_dataset, val_dataset = dataset["train"], dataset["test"]
+
+    train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
+    val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
+
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    # tokenize inputs and return labels and attention mask
+    train_dataset = train_dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        batched=True,
+        **kwargs
+    )
+    val_dataset = val_dataset.map(
+        lambda ele: tokenize_inputs(config, tokenizer, ele),
+        batched=True,
+        **kwargs
+    )
+    train_dataset = train_dataset.with_format("torch")
+    val_dataset = val_dataset.with_format("torch")
+
+    return train_dataset, val_dataset
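
For reference, a minimal sketch of how the output of the new load_data_for_inference might be consumed; it is not part of the commit. It assumes a config dict with the "dataset_path", "seed", "streaming", "num_proc", and "batch_size" keys read above, and that tokenize_inputs produces "input_ids", "attention_mask", and "labels" columns (per the comment in data.py). The model name, example path, and batching loop are illustrative only.

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DefaultDataCollator

from data import load_data_for_inference

# illustrative config; keys mirror those read by load_data_for_inference
config = {
    "dataset_path": "path/to/data_clean.jsonl",  # hypothetical path
    "seed": 42,
    "streaming": False,
    "num_proc": 4,
    "batch_size": 8,
}

model_name = "EleutherAI/gpt-j-6B"  # hypothetical model choice
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).eval()

train_dataset, val_dataset = load_data_for_inference(config, tokenizer)

# the "index" column added by this commit lets each output be mapped back to its source row
val_loader = DataLoader(
    val_dataset,
    batch_size=config["batch_size"],
    collate_fn=DefaultDataCollator(),
)

with torch.no_grad():
    for batch in val_loader:
        out = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
        )
        # out.logits has shape (batch, seq_len, vocab); pair results with batch["index"]
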
requirements.txt: 3 changes (2 additions & 1 deletion)
@@ -9,4 +9,5 @@ peft
 nodelist-inflator
 deepspeed
 sentencepiece
-jsonlines
+jsonlines
+nomic
