[CodeLlama]: simplify infilling with <FILL_ME> (huggingface#1424)
* CodeLlama: simplify infilling with `<FILL_ME>`.

Co-authored-by: Arthur <[email protected]>

* Use `AutoTokenizer`

It now works after [these PRs](https://huggingface.co/codellama/CodeLlama-7b-hf/discussions/11) have been merged.

---------

Co-authored-by: Arthur <[email protected]>
pcuenca and ArthurZucker authored Aug 28, 2023
1 parent 644d267 commit dff73a8
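
The commit message notes that `AutoTokenizer` can now be used for this workflow. As a quick illustration (a hedged sketch, not part of the commit), the snippet below checks that the auto class resolves to a CodeLlama tokenizer exposing the `<FILL_ME>` fill token; the exact class name printed is an assumption.

```python
# Hypothetical check, not part of this commit: after the linked Hub PRs, AutoTokenizer
# should resolve to a CodeLlama tokenizer whose fill token is <FILL_ME>.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
print(type(tokenizer).__name__)  # expected: a CodeLlama tokenizer class (assumption)
print(tokenizer.fill_token)      # expected: "<FILL_ME>"
```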
Showing 1 changed file with 21 additions and 19 deletions.
codellama.md (21 additions, 19 deletions)

@@ -151,7 +151,7 @@ This is a specialized task particular to code models. The model is trained to ge
 
 This task is available in the **base** and **instruction** variants of the 7B and 13B models. It is _not_ available for any of the 34B models or the Python versions.
 
-To use this feature successfully, you need to pay close attention to the format used to train the model for this task, as it uses special separators to identify the different parts of the prompt. Let's see an example:
+To use this feature successfully, you need to pay close attention to the format used to train the model for this task, as it uses special separators to identify the different parts of the prompt. Fortunately, transformers' `CodeLlamaTokenizer` makes this very easy, as demonstrated below:
 
 ```python
 from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -165,38 +165,40 @@ model = AutoModelForCausalLM.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda")
 
-prefix = 'def remove_non_ascii(s: str) -> str:\n    """ '
-suffix = "\n    return result\n"
-
-prompt = f"<PRE> {prefix} <SUF>{suffix} <MID>"
-inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
+prompt = '''def remove_non_ascii(s: str) -> str:
+    """ <FILL_ME>
+    return result
+'''
+
+input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
 output = model.generate(
-    inputs["input_ids"],
+    input_ids,
     max_new_tokens=200,
     do_sample=False,
 )
 output = output[0].to("cpu")
-print(tokenizer.decode(output))
+
+filling = tokenizer.decode(output[input_ids.shape[1]:], skip_special_tokens=True)
+print(prompt.replace("<FILL_ME>", filling))
 ```
 
-```
-<s> <PRE> def remove_non_ascii(s: str) -> str:
-    """ <SUF>
-    return result
- <MID>
-Remove non-ASCII characters from a string.
-
-    :param s: The string to remove non-ASCII characters from.
-    :return: The string with non-ASCII characters removed.
+```Python
+def remove_non_ascii(s: str) -> str:
+    """ Remove non-ASCII characters from a string.
+Args:
+    s: The string to remove non-ASCII characters from.
+Returns:
+    The string with non-ASCII characters removed.
     """
     result = ""
     for c in s:
         if ord(c) < 128:
-            result += c <EOT></s>
+            result += c
+    return result
 ```
 
-In order to use the completion, you’ll need to process the output to cut the text between the `<MID>` and `<EOT>` tokens – that’s what goes between the prefix and suffix we supplied.
+Under the hood, the tokenizer [automatically splits by `<FILL_ME>`](https://huggingface.co/docs/transformers/main/model_doc/code_llama#transformers.CodeLlamaTokenizer.fill_token) to create a formatted input string that follows [the original training pattern](https://github.com/facebookresearch/codellama/blob/cb51c14ec761370ba2e2bc351374a79265d0465e/llama/generation.py#L402). This is more robust than preparing the pattern yourself: it avoids pitfalls, such as token glueing, that are very hard to debug.
 
 #### Conversational Instructions
 
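To illustrate the "under the hood" note in the changed paragraph, here is a minimal sketch (not part of the commit) that encodes a `<FILL_ME>` prompt and decodes it back; this should reveal the `<PRE>`/`<SUF>`/`<MID>` pattern the tokenizer builds, though the exact decoded string is an assumption.

```python
# Hypothetical sketch, not part of this commit: show that the tokenizer expands a
# <FILL_ME> prompt into the <PRE> ... <SUF> ... <MID> infilling pattern.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf")
prompt = 'def remove_non_ascii(s: str) -> str:\n    """ <FILL_ME>\n    return result\n'
input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]

# Decoding without skipping special tokens should show the rearranged prompt rather than
# the raw text, i.e. the prefix and suffix separated by the special infilling tokens.
print(tokenizer.decode(input_ids[0], skip_special_tokens=False))
```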
