
Update utils.py
fix predictor cuda settings
marscrazy authored Nov 20, 2022
1 parent c430736 commit 7d04894
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions flagai/model/predictor/utils.py
@@ -1006,7 +1006,7 @@ def glm_generate_sample(
     eod_token=50000,
     temperature=0.9,
 ):
-    device = torch.cuda.current_device()
+    device = next(model.parameters()).device
     model.eval()
 
     generation_mask = '[gMASK]'
@@ -1017,9 +1017,9 @@ def glm_generate_sample(
     if not text.endswith('[gMASK]'):
         context_tokens = context_tokens + [tokenizer.get_command_id('eos')]
     context_length = len(context_tokens)
-    context_length_tensor = torch.cuda.LongTensor([context_length])
+    context_length_tensor = torch.LongTensor([context_length])
     context_length = context_length_tensor[0].item()
-    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
+    context_tokens_tensor = torch.LongTensor(context_tokens)
     text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
 
     start_time = time.time()
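
The two hunks above make glm_generate_sample device-agnostic: the device is read from the model's own parameters rather than from torch.cuda.current_device() (which fails on CPU-only machines), and the context tensors are created on CPU instead of being forced onto the GPU. A minimal sketch of the pattern, with a stand-in model (illustrative only, not part of this commit):

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)  # stand-in for the GLM model
    model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Read the device off the weights; works whether the model is on CPU or GPU.
    device = next(model.parameters()).device

    # Build tensors on CPU, then move them to wherever the model lives.
    context_tokens_tensor = torch.LongTensor([101, 102, 103]).to(device)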
@@ -1576,6 +1576,7 @@ def cpm_beam_search(model,
                     min_len=None,
                     **kwags):
     print('tokenizer is', tokenizer)
+    device = next(model.parameters()).device
     vocab_size = tokenizer.vocab_size
 
     ids, info = encode(tokenizer, instance, target_span_len)
@@ -1613,12 +1614,12 @@ def cpm_beam_search(model,
     segment_input = segment_input.contiguous().view(batch_size * beam_size,
                                                     max_length)
 
-    input_tokens = input_tokens.int().cuda()
-    input_length = input_length.int().cuda()
-    context_input = context_input.bool().cuda()
-    position_input = position_input.float().cuda()
-    segment_input = segment_input.int().cuda()
-    span_input = span_input.int().cuda()
+    input_tokens = input_tokens.int().to(device)
+    input_length = input_length.int().to(device)
+    context_input = context_input.bool().to(device)
+    position_input = position_input.float().to(device)
+    segment_input = segment_input.int().to(device)
+    span_input = span_input.int().to(device)
 
     done = [False for _ in range(batch_size)]
     # (batch_size * beam_size, 0)
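
In cpm_beam_search the same idea applies to the batch tensors: .cuda() always targets the current CUDA device and raises when none is available, while .to(device) follows whatever device was derived from the model, so one code path serves CPU and GPU. A small sketch of the batch-preparation step (tensor names mirror the diff; shapes are made up):

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    input_tokens = torch.zeros(4, 16)
    context_input = torch.ones(4, 16)

    # .to(device) is a no-op when the tensor is already on the target device,
    # and works for a CPU target where .cuda() would fail.
    input_tokens = input_tokens.int().to(device)
    context_input = context_input.bool().to(device)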
