diff --git a/data.py b/data.py
index ff519abb5447..6375584d9fba 100644
--- a/data.py
+++ b/data.py
@@ -75,7 +75,7 @@ def load_data(config, tokenizer):
         dataset = load_dataset("json", data_files=files, split="train")
 
     else:
-        dataset = load_dataset(dataset_path)
+        dataset = load_dataset(dataset_path, split="train")
 
     dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
 
@@ -118,3 +118,67 @@ def load_data(config, tokenizer):
     )
 
     return train_dataloader, val_dataloader
+
+
+def load_data_for_inference(config, tokenizer):
+    """Load, split and tokenize the configured dataset for inference.
+
+    When ``config["dataset_path"]`` exists on disk it is read with the
+    ``json`` loader: a directory contributes every ``*_clean.jsonl`` file
+    directly inside it, any other path is used as a single file. A path that
+    does not exist locally is handed to ``load_dataset`` unchanged.
+
+    Returns:
+        ``(train_dataset, val_dataset)`` from a 95/5 split seeded with
+        ``config["seed"]``. Each split gets an ``index`` column holding the
+        row position within that split, is tokenized with
+        ``tokenize_inputs`` and is returned in ``torch`` format.
+
+    Raises:
+        FileNotFoundError: if ``dataset_path`` is a directory without any
+            ``*_clean.jsonl`` file.
+    """
+    dataset_path = config["dataset_path"]
+
+    if os.path.exists(dataset_path):
+        # check if path is a directory
+        if os.path.isdir(dataset_path):
+            files = glob.glob(os.path.join(dataset_path, "*_clean.jsonl"))
+            if not files:
+                # fail early with a clear message rather than passing an
+                # empty file list on to load_dataset
+                raise FileNotFoundError(
+                    f"No *_clean.jsonl files found in {dataset_path}"
+                )
+        else:
+            files = [dataset_path]
+
+        print(f"Reading files {files}")
+
+        dataset = load_dataset("json", data_files=files, split="train")
+
+    else:
+        dataset = load_dataset(dataset_path, split="train")
+
+    dataset = dataset.train_test_split(test_size=.05, seed=config["seed"])
+
+    # num_proc is only passed when streaming is explicitly disabled
+    if config["streaming"] is False:
+        kwargs = {"num_proc": config["num_proc"]}
+    else:
+        kwargs = {}
+
+    prepared = []
+    for split in (dataset["train"], dataset["test"]):
+        # number the rows of each split independently, starting at 0
+        split = split.add_column("index", list(range(len(split))))
+        # tokenize inputs and return labels and attention mask
+        split = split.map(
+            lambda ele: tokenize_inputs(config, tokenizer, ele),
+            batched=True,
+            **kwargs
+        )
+        prepared.append(split.with_format("torch"))
+
+    train_dataset, val_dataset = prepared
+    return train_dataset, val_dataset
diff --git a/requirements.txt b/requirements.txt
index 8a91fd74b60d..a30fc8a555cf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@ peft
 nodelist-inflator
 deepspeed
 sentencepiece
-jsonlines
\ No newline at end of file
+jsonlines
+nomic
\ No newline at end of file