
Commit

fix: drop uneven batch size
zanussbaum committed Apr 7, 2023
1 parent 985da51 commit 0bd6acb
Showing 2 changed files with 10 additions and 5 deletions.
4 changes: 4 additions & 0 deletions data.py
@@ -142,7 +142,11 @@ def load_data_for_inference(config, tokenizer):
     train_dataset, val_dataset = dataset["train"], dataset["test"]
 
     train_dataset = train_dataset.add_column("index", list(range(len(train_dataset))))
+    # select the first N examples, where N is the largest multiple of batch_size
+    # gather is a bit annoying (or the way I'm using it) with uneven batches, since it duplicates data
+    train_dataset = train_dataset.select(range((len(train_dataset) // config["batch_size"]) * config["batch_size"]))
     val_dataset = val_dataset.add_column("index", list(range(len(val_dataset))))
+    val_dataset = val_dataset.select(range((len(val_dataset) // config["batch_size"]) * config["batch_size"]))
 
     if config["streaming"] is False:
         kwargs = {"num_proc": config["num_proc"]}
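The data.py change above truncates each split to the largest length that is an exact multiple of config["batch_size"], so every process only ever sees full batches and the distributed gather later in inference never has to duplicate examples to even things out. A minimal sketch of that arithmetic (standalone Python, with a hypothetical batch size chosen only for illustration; not part of the commit):

batch_size = 8  # hypothetical value, for illustration only

for n in (1003, 1000, 7):
    kept = (n // batch_size) * batch_size          # largest multiple of batch_size that fits
    print(f"{n} examples -> keep first {kept}, drop {n - kept}")

# 1003 examples -> keep first 1000, drop 3
# 1000 examples -> keep first 1000, drop 0
# 7 examples -> keep first 0, drop 7

One side effect worth noting: a split smaller than a single batch would be truncated to zero examples.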
11 changes: 6 additions & 5 deletions inference.py
@@ -46,20 +46,22 @@ def inference(config):
     num_processes = dist.get_world_size()
     local_rank = dist.get_rank()
 
-    train_sampler = ShardSampler(train_dataset, config["batch_size"], num_processes=num_processes, process_index=local_rank)
+    train_sampler = ShardSampler(train_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
     train_dataloader = DataLoader(
         train_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
-        sampler=train_sampler
+        sampler=train_sampler,
+        drop_last=True
     )
 
-    val_sampler = ShardSampler(val_dataset, config["batch_size"], num_processes=num_processes, process_index=local_rank)
+    val_sampler = ShardSampler(val_dataset, config["batch_size"], drop_last=True, num_processes=num_processes, process_index=local_rank)
     val_dataloader = DataLoader(
         val_dataset,
         collate_fn=DefaultDataCollator(),
         batch_size=config["batch_size"],
-        sampler=val_sampler
+        sampler=val_sampler,
+        drop_last=True
     )


@@ -113,7 +115,6 @@ def inference(config):
 
     df_train = Dataset.from_dict(gathered_train)
     df_train = df_train.sort("index")
-
     train_dataset = train_dataset.add_column("embeddings", df_train["embeddings"])
     train_dataset = train_dataset.add_column("loss", df_train["loss"])
     train_dataset = train_dataset.add_column("is_train", [True] * len(train_dataset))
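For context on the inference.py change (not part of the diff): drop_last=True tells a sampler or DataLoader to discard the trailing incomplete batch instead of emitting it, so every rank produces the same number of equally sized batches and the cross-process gather lines up. A minimal sketch of the behavior with a plain PyTorch DataLoader and a small made-up dataset (ShardSampler is assumed to apply the same rule per shard, mirroring the call in the diff):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))                  # 10 examples

keep = DataLoader(ds, batch_size=4, drop_last=False)  # batches of 4, 4, 2
drop = DataLoader(ds, batch_size=4, drop_last=True)   # batches of 4, 4; the partial batch of 2 is discarded

print(len(keep), len(drop))                           # 3 2

Combined with the truncation in data.py, this makes it doubly certain that no process ever contributes a partial batch to the gather.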
