fix: add eos
zanussbaum committed Mar 26, 2023
1 parent 2daecd6 commit eac7734
Showing 1 changed file with 7 additions and 4 deletions.
data.py: 11 changes (7 additions & 4 deletions)
@@ -15,21 +15,24 @@ def tokenize_inputs(config, tokenizer, examples):

     out = {"labels": [], "attention_mask": []}
     for i, (prompt, response) in enumerate(zip(examples["prompt"], examples["response"])):
         # HACK to get 512 to work for now
-        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length //2, return_tensors="pt")["input_ids"].squeeze()
+        input_tokens = tokenizer(prompt, truncation=True, max_length=max_length // 2, return_tensors="pt")["input_ids"].squeeze()
         input_len = len(input_tokens)
 
-        # plus one since we remove bos from response
-        remaining_tokens = max_length - input_len - len(newline_tokens) + 1
+        # but we subtract one since we want to add eos token
+        remaining_tokens = max_length - input_len - len(newline_tokens)
         # remove bos
         target_tokens = tokenizer(response, truncation=True, max_length=remaining_tokens, return_tensors="pt")["input_ids"].squeeze()[1:]
 
         input_ids[i, :input_len] = input_tokens
         # add newline between prompt and response
         newline_plus_inputs = input_len + len(newline_tokens)
         input_ids[i, input_len: newline_plus_inputs] = newline_tokens
 
         # add target tokens, remove bos
         input_ids[i, newline_plus_inputs: newline_plus_inputs + len(target_tokens)] = target_tokens
+        # add eos token, enforce stopping
+        input_ids[i, newline_plus_inputs + len(target_tokens)] = tokenizer.eos_token_id
 
         labels = input_ids[i].clone()
         labels[: newline_plus_inputs] = -100
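The arithmetic behind the new comment is easier to see with concrete numbers: once the response's leading bos is sliced off, the truncated response occupies at most remaining_tokens - 1 positions, which leaves exactly one slot free for the eos token the commit appends. Below is a minimal, self-contained sketch of the resulting layout and label masking; the token ids are made up for illustration and do not come from the repository.

import torch

# Toy illustration of the layout the commit enforces (token ids are made up):
#   input_ids: [prompt][newline][response minus bos][eos][pad ...]
#   labels:    [-100 over prompt + newline][response minus bos][eos][pad ...]
max_length = 16
bos_id, eos_id, pad_id = 1, 2, 0

prompt_tokens = torch.tensor([bos_id, 11, 12, 13])         # stand-in for the tokenized prompt
newline_tokens = torch.tensor([198])                        # stand-in for the tokenized "\n"
response_with_bos = torch.tensor([bos_id, 21, 22, 23, 24])  # stand-in for the tokenized response

input_len = len(prompt_tokens)

# Budget for the response. Because the response's leading bos is dropped below,
# the truncated response fills at most remaining_tokens - 1 slots, leaving one
# slot free for the eos token appended afterwards.
remaining_tokens = max_length - input_len - len(newline_tokens)
target_tokens = response_with_bos[:remaining_tokens][1:]    # truncate, then drop bos

input_ids = torch.full((max_length,), pad_id)
input_ids[:input_len] = prompt_tokens
newline_plus_inputs = input_len + len(newline_tokens)
input_ids[input_len:newline_plus_inputs] = newline_tokens
input_ids[newline_plus_inputs:newline_plus_inputs + len(target_tokens)] = target_tokens
input_ids[newline_plus_inputs + len(target_tokens)] = eos_id  # the fix: enforce stopping

labels = input_ids.clone()
labels[:newline_plus_inputs] = -100                          # mask prompt + newline out of the loss

print(input_ids.tolist())
print(labels.tolist())

Masking everything up to and including the newline with -100 keeps the prompt out of the loss, so training only supervises the response followed by eos, which is what teaches the model to stop generating.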
