Skip to content

Commit

Permalink
Refactor current_features to selected_feature_format (huggingface#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
mathemakitten authored Oct 5, 2022
1 parent 8e76263 commit 9463447
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions src/evaluate/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def __init__(
self.add.__func__.__doc__ += self.info.inputs_description

# self.arrow_schema = pa.schema(field for field in self.info.features.type)
self.current_features = None
self.selected_feature_format = None
self.buf_writer = None
self.writer = None
self.writer_batch_size = None
Expand Down Expand Up @@ -377,7 +377,7 @@ def _finalize(self):

if self.keep_in_memory:
# Read the predictions and references
reader = ArrowReader(path=self.data_dir, info=DatasetInfo(features=self.current_features))
reader = ArrowReader(path=self.data_dir, info=DatasetInfo(features=self.selected_feature_format))
self.data = Dataset.from_buffer(self.buf_writer.getvalue())

elif self.process_id == 0:
Expand All @@ -386,7 +386,7 @@ def _finalize(self):

# Read the predictions and references
try:
reader = ArrowReader(path="", info=DatasetInfo(features=self.current_features))
reader = ArrowReader(path="", info=DatasetInfo(features=self.selected_feature_format))
self.data = Dataset(**reader.read_files([{"filename": f} for f in file_paths]))
except FileNotFoundError:
raise ValueError(
Expand Down Expand Up @@ -434,7 +434,7 @@ def compute(self, *, predictions=None, references=None, **kwargs) -> Optional[di

self.cache_file_name = None
self.filelock = None
self.current_features = None
self.selected_feature_format = None

if self.process_id == 0:
self.data.set_format(type=self.info.format)
Expand Down Expand Up @@ -477,13 +477,13 @@ def add_batch(self, *, predictions=None, references=None, **kwargs):
batch = {"predictions": predictions, "references": references, **kwargs}
batch = {input_name: batch[input_name] for input_name in self._feature_names()}
if self.writer is None:
self.current_features = self._infer_feature_from_batch(batch)
self.selected_feature_format = self._infer_feature_from_batch(batch)
self._init_writer()
try:
for key, column in batch.items():
if len(column) > 0:
self._enforce_nested_string_type(self.current_features[key], column[0])
batch = self.current_features.encode_batch(batch)
self._enforce_nested_string_type(self.selected_feature_format[key], column[0])
batch = self.selected_feature_format.encode_batch(batch)
self.writer.write_batch(batch)
except (pa.ArrowInvalid, TypeError):
if any(len(batch[c]) != len(next(iter(batch.values()))) for c in batch):
Expand All @@ -492,19 +492,20 @@ def add_batch(self, *, predictions=None, references=None, **kwargs):
error_msg = (
f"Mismatch in the number of {col0} ({len(batch[col0])}) and {bad_col} ({len(batch[bad_col])})"
)
elif set(self.current_features) != {"references", "predictions"}:
elif set(self.selected_feature_format) != {"references", "predictions"}:
error_msg = (
f"Module inputs don't match the expected format.\n" f"Expected format: {self.current_features },\n"
f"Module inputs don't match the expected format.\n"
f"Expected format: {self.selected_feature_format },\n"
)
error_msg_inputs = ",\n".join(
f"Input {input_name}: {summarize_if_long_list(batch[input_name])}"
for input_name in self.current_features
for input_name in self.selected_feature_format
)
error_msg += error_msg_inputs
else:
error_msg = (
f"Predictions and/or references don't match the expected format.\n"
f"Expected format: {self.current_features },\n"
f"Expected format: {self.selected_feature_format },\n"
f"Input predictions: {summarize_if_long_list(predictions)},\n"
f"Input references: {summarize_if_long_list(references)}"
)
Expand All @@ -525,7 +526,7 @@ def add(self, *, prediction=None, reference=None, **kwargs):
example = {"predictions": prediction, "references": reference, **kwargs}
example = {input_name: example[input_name] for input_name in self._feature_names()}
if self.writer is None:
self.current_features = self._infer_feature_from_example(example)
self.selected_feature_format = self._infer_feature_from_example(example)
self._init_writer()
try:
self._enforce_nested_string_type(self.info.features, example)
Expand All @@ -534,11 +535,11 @@ def add(self, *, prediction=None, reference=None, **kwargs):
except (pa.ArrowInvalid, TypeError):
error_msg = (
f"Evaluation module inputs don't match the expected format.\n"
f"Expected format: {self.current_features},\n"
f"Expected format: {self.selected_feature_format},\n"
)
error_msg_inputs = ",\n".join(
f"Input {input_name}: {summarize_if_long_list(example[input_name])}"
for input_name in self.current_features
for input_name in self.selected_feature_format
)
error_msg += error_msg_inputs
raise ValueError(error_msg) from None
Expand Down Expand Up @@ -594,7 +595,7 @@ def _init_writer(self, timeout=1):
if self.keep_in_memory:
self.buf_writer = pa.BufferOutputStream()
self.writer = ArrowWriter(
features=self.current_features, stream=self.buf_writer, writer_batch_size=self.writer_batch_size
features=self.selected_feature_format, stream=self.buf_writer, writer_batch_size=self.writer_batch_size
)
else:
self.buf_writer = None
Expand All @@ -606,7 +607,9 @@ def _init_writer(self, timeout=1):
self.filelock = filelock

self.writer = ArrowWriter(
features=self.current_features, path=self.cache_file_name, writer_batch_size=self.writer_batch_size
features=self.selected_feature_format,
path=self.cache_file_name,
writer_batch_size=self.writer_batch_size,
)
# Setup rendez-vous here if
if self.num_process > 1:
Expand Down

0 comments on commit 9463447

Please sign in to comment.