Skip to content

Commit

Permalink
fix TPU changes
Browse files Browse the repository at this point in the history
  • Loading branch information
minimaxir committed Jan 2, 2021
1 parent 6f0f497 commit 94b6eeb
Showing 1 changed file with 2 additions and 14 deletions.
16 changes: 2 additions & 14 deletions aitextgen/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import get_linear_schedule_with_warmup
import os
import shutil
def forward(self, inputs):
    """Delegate the forward pass to the wrapped model.

    `inputs` is a dict of keyword arguments (e.g. input_ids, labels)
    expanded directly into the model call; `return_dict=False` keeps
    the output as a plain tuple.
    """
    model_output = self.model(**inputs, return_dict=False)
    return model_output

def training_step(self, batch, batch_num):
    """Run a single training step and report the loss.

    The same batch tensor is passed as both ``input_ids`` and
    ``labels``; the loss is the first element of the model output.
    """
    loss = self({"input_ids": batch, "labels": batch})[0]
    return {"loss": loss}

def train_dataloader(self):
    """Build the training DataLoader over ``self.dataset``.

    Called after data preparation. Batch size, memory pinning, and
    worker count are read from ``self.hparams``; batches are shuffled
    each epoch.

    NOTE(review): the scraped diff interleaved pre- and post-commit
    lines, yielding a duplicate ``shuffle=`` keyword (a SyntaxError)
    and a TPU ``DistributedSampler`` path referencing an undefined
    ``xm``. This is the reconciled post-commit form: the sampler
    logic is removed and shuffling is always enabled.
    """
    return DataLoader(
        self.dataset,
        batch_size=self.hparams["batch_size"],
        shuffle=True,
        pin_memory=self.hparams["pin_memory"],
        num_workers=self.hparams["num_workers"],
    )
Expand Down

0 comments on commit 94b6eeb

Please sign in to comment.