[Examples] Added predict stage and Updated Example Template (huggingface#10868)

* added predict stage

* added test keyword in exception message

* removed example specific saving predictions

* fixed f-string error

* removed extra line

Co-authored-by: Stas Bekman <[email protected]>

bhadreshpsavani and stas00 authored Mar 23, 2021
1 parent fb2b898 commit 7ef4012
Showing 2 changed files with 103 additions and 14 deletions.
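For orientation, the diff below wires a predict stage into the XNLI example: load the test split only when --do_predict is set, preprocess it like the other splits, call Trainer.predict, then save metrics and per-example labels. A minimal self-contained sketch of that flow follows; the checkpoint name, language, sample count, and output directory are placeholder assumptions, not values from the commit.

# Minimal sketch of the predict flow added by this commit; checkpoint, language,
# sample count and output dir below are assumptions, not values from the diff.
import os

import numpy as np
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "bert-base-multilingual-cased"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

# Load the XNLI test split and keep a handful of examples (mirrors --max_test_samples).
test_dataset = load_dataset("xnli", "en", split="test").select(range(16))
label_list = test_dataset.features["label"].names

def preprocess_function(examples):
    return tokenizer(examples["premise"], examples["hypothesis"], truncation=True, padding="max_length")

test_dataset = test_dataset.map(preprocess_function, batched=True)

training_args = TrainingArguments(output_dir="out", do_predict=True)
trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)

# Trainer.predict returns a (predictions, label_ids, metrics) named tuple; argmax turns logits into class ids.
predictions, _, metrics = trainer.predict(test_dataset)
predictions = np.argmax(predictions, axis=1)

os.makedirs(training_args.output_dir, exist_ok=True)
with open(os.path.join(training_args.output_dir, "test_predictions.txt"), "w") as writer:
    writer.write("index\tprediction\n")
    for index, item in enumerate(predictions):
        writer.write(f"{index}\t{label_list[item]}\n")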
56 changes: 46 additions & 10 deletions examples/text-classification/run_xnli.py
@@ -207,14 +207,22 @@ def main():
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
# Downloading and loading xnli dataset from the hub.
if model_args.train_language is None:
train_dataset = load_dataset("xnli", model_args.language, split="train")
else:
train_dataset = load_dataset("xnli", model_args.train_language, split="train")
if training_args.do_train:
if model_args.train_language is None:
train_dataset = load_dataset("xnli", model_args.language, split="train")
else:
train_dataset = load_dataset("xnli", model_args.train_language, split="train")
label_list = train_dataset.features["label"].names

if training_args.do_eval:
eval_dataset = load_dataset("xnli", model_args.language, split="validation")
label_list = eval_dataset.features["label"].names

if training_args.do_predict:
test_dataset = load_dataset("xnli", model_args.language, split="test")
label_list = test_dataset.features["label"].names

eval_dataset = load_dataset("xnli", model_args.language, split="validation")
# Labels
label_list = train_dataset.features["label"].names
num_labels = len(label_list)

# Load pretrained model and tokenizer
@@ -271,6 +279,9 @@ def preprocess_function(examples):
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)
# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")

if training_args.do_eval:
if data_args.max_val_samples is not None:
@@ -281,9 +292,14 @@ def preprocess_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
)

# Log a few random samples from the training set:
for index in random.sample(range(len(train_dataset)), 3):
logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
if training_args.do_predict:
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
test_dataset = test_dataset.map(
preprocess_function,
batched=True,
load_from_cache_file=not data_args.overwrite_cache,
)

# Get the metric function
metric = load_metric("xnli")
@@ -307,7 +323,7 @@ def compute_metrics(p: EvalPrediction):
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
train_dataset=train_dataset if training_args.do_train else None,
eval_dataset=eval_dataset if training_args.do_eval else None,
compute_metrics=compute_metrics,
tokenizer=tokenizer,
@@ -346,6 +362,26 @@ def compute_metrics(p: EvalPrediction):
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)

max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))

trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)

predictions = np.argmax(predictions, axis=1)
output_test_file = os.path.join(training_args.output_dir, "test_predictions.txt")
if trainer.is_world_process_zero():
with open(output_test_file, "w") as writer:
writer.write("index\tprediction\n")
for index, item in enumerate(predictions):
item = label_list[item]
writer.write(f"{index}\t{item}\n")


if __name__ == "__main__":
main()
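The prediction block above writes a tab-separated predictions file next to the metrics. A small hypothetical follow-up for reading it back, assuming run_xnli.py was launched with --output_dir out:

# Hypothetical follow-up, not part of the commit: read back test_predictions.txt
# written by run_xnli.py above ("out" is an assumed --output_dir).
import csv
import os

output_dir = "out"
with open(os.path.join(output_dir, "test_predictions.txt"), newline="") as f:
    rows = list(csv.DictReader(f, delimiter="\t"))
print(rows[:3])  # e.g. [{"index": "0", "prediction": "entailment"}, ...]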
@@ -139,6 +139,10 @@ class DataTrainingArguments:
default=None,
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
)
test_file: Optional[str] = field(
default=None,
metadata={"help": "An optional input test data file to predict the label on (a text file)."},
)
overwrite_cache: bool = field(
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
)
@@ -160,17 +164,32 @@ class DataTrainingArguments:
"value if set."
},
)
max_test_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of test examples to this "
"value if set."
},
)

def __post_init__(self):
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
if (
self.dataset_name is None
and self.train_file is None
and self.validation_file is None
and self.test_file is None
):
raise ValueError("Need either a dataset name or a training/validation/test file.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
if self.validation_file is not None:
extension = self.validation_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
if self.test_file is not None:
extension = self.test_file.split(".")[-1]
assert extension in ["csv", "json", "txt"], "`test_file` should be a csv, a json or a txt file."


def main():
@@ -238,9 +257,13 @@ def main():
data_files = {}
if data_args.train_file is not None:
data_files["train"] = data_args.train_file
extension = data_args.train_file.split(".")[-1]
if data_args.validation_file is not None:
data_files["validation"] = data_args.validation_file
extension = data_args.train_file.split(".")[-1]
extension = data_args.validation_file.split(".")[-1]
if data_args.test_file is not None:
data_files["test"] = data_args.test_file
extension = data_args.test_file.split(".")[-1]
if extension == "txt":
extension = "text"
datasets = load_dataset(extension, data_files=data_files)
@@ -326,8 +349,10 @@ def main():
# First we tokenize all the texts.
if training_args.do_train:
column_names = datasets["train"].column_names
else:
elif training_args.do_eval:
column_names = datasets["validation"].column_names
elif training_args.do_predict:
column_names = datasets["test"].column_names
text_column_name = "text" if "text" in column_names else column_names[0]

def tokenize_function(examples):
@@ -365,6 +390,22 @@ def tokenize_function(examples):
load_from_cache_file=not data_args.overwrite_cache,
)

if training_args.do_predict:
if "test" not in datasets:
raise ValueError("--do_predict requires a test dataset")
test_dataset = datasets["test"]
# Selecting samples from dataset
if data_args.max_test_samples is not None:
test_dataset = test_dataset.select(range(data_args.max_test_samples))
# tokenize test dataset
test_dataset = test_dataset.map(
tokenize_function,
batched=True,
num_proc=data_args.preprocessing_num_workers,
remove_columns=[text_column_name],
load_from_cache_file=not data_args.overwrite_cache,
)

# Data collator
data_collator = default_data_collator if not training_args.fp16 else DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)

@@ -420,6 +461,18 @@ def tokenize_function(examples):
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Prediction
if training_args.do_predict:
logger.info("*** Predict ***")
predictions, labels, metrics = trainer.predict(test_dataset)

max_test_samples = data_args.max_test_samples if data_args.max_test_samples is not None else len(test_dataset)
metrics["test_samples"] = min(max_test_samples, len(test_dataset))

trainer.log_metrics("test", metrics)
trainer.save_metrics("test", metrics)

# write custom code for saving predictions according to task

def _mp_fn(index):
# For xla_spawn (TPUs)
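Both files route the new test metrics through trainer.log_metrics and trainer.save_metrics, which in this version of Trainer writes a test_results.json under the output directory. A short sketch for loading those saved metrics afterwards; the directory name is again an assumption:

# Hypothetical follow-up: load the metrics written by trainer.save_metrics("test", metrics),
# assuming the script ran with --output_dir out.
import json
import os

output_dir = "out"
with open(os.path.join(output_dir, "test_results.json")) as f:
    test_metrics = json.load(f)
print(test_metrics)  # includes test_samples plus Trainer's test_runtime etc.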
