Skip to content

Commit

Permalink
fix: pyarrow filter
Browse files Browse the repository at this point in the history
  • Loading branch information
zanussbaum committed Apr 7, 2023
1 parent 7a9f6d1 commit 4b51e6e
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from transformers import DefaultDataCollator
from torch.utils.data import DataLoader, DistributedSampler
import numpy as np
import pyarrow as pa
from pyarrow import compute as pc


def calc_cross_entropy_no_reduction(lm_logits, labels):
Expand Down Expand Up @@ -116,7 +118,13 @@ def inference(config):
df_train = df_train.sort("index")
curr_idx = df_train["index"]

filtered_train = train_dataset.filter(lambda example: example["index"] in curr_idx)
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = train_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_train = Dataset.from_dict(filtered_table.to_pydict())

filtered_train = filtered_train.add_column("embeddings", df_train["embeddings"])
filtered_train = filtered_train.add_column("loss", df_train["loss"])
Expand Down Expand Up @@ -167,7 +175,13 @@ def inference(config):
df_val = df_val.sort("index")
curr_idx = df_val["index"]

filtered_val = val_dataset.filter(lambda example: example["index"] in curr_idx)
# compute mask in pyarrow since it's super fast
# ty @bmschmidt for showing me this!
table = val_dataset.data
mask = pc.is_in(table['index'], value_set=pa.array(curr_idx, pa.int32()))
filtered_table = table.filter(mask)
# convert from pyarrow to Dataset
filtered_val = Dataset.from_dict(filtered_table.to_pydict())

filtered_val = filtered_val.add_column("embeddings", df_val["embeddings"])
filtered_val = filtered_val.add_column("loss", df_val["loss"])
Expand Down

0 comments on commit 4b51e6e

Please sign in to comment.