Skip to content

Commit

Permalink
remove mpu dependency in zeroshot script
Browse files Browse the repository at this point in the history
  • Loading branch information
skyw committed Dec 21, 2022
1 parent 52e6368 commit 8ed3887
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions tasks/zeroshot_gpt/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""GPT zero-shot evaluation."""

Expand All @@ -9,7 +9,7 @@
from megatron import get_args
from megatron import print_rank_0, is_last_rank
from megatron import get_tokenizer
-from megatron.core import mpu
+from megatron.core import parallel_state, tensor_parallel
from megatron.checkpointing import load_checkpoint
from megatron.model import GPTModel
from megatron.training import get_model
Expand Down Expand Up @@ -90,10 +90,10 @@ def forward_step(batch, model, eval_metric):

send_forward(output)

-if mpu.is_pipeline_last_stage():
+if parallel_state.is_pipeline_last_stage():
# For loss, return the unreduced loss.
if eval_metric == 'loss':
-losses = mpu.tensor_parallel.vocab_parallel_cross_entropy(
+losses = tensor_parallel.vocab_parallel_cross_entropy(
output.contiguous().float(), labels.contiguous())
loss = torch.sum(
losses.view(-1) * loss_mask.contiguous().view(-1).float())
Expand Down Expand Up @@ -129,9 +129,9 @@ def evaluate(data_loader, model, eval_metric):
output = forward_step(batch, model, eval_metric)

# Reduce across processes.
-if mpu.is_pipeline_last_stage():
+if parallel_state.is_pipeline_last_stage():
torch.distributed.all_reduce(output,
-group=mpu.get_data_parallel_group())
+group=parallel_state.get_data_parallel_group())

total_output += output

Expand Down

0 comments on commit 8ed3887

Please sign in to comment.