Skip to content

Commit

Permalink
Falcon3 (Lightning-AI#1881)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysjprojects authored Dec 23, 2024
1 parent fe96c63 commit 1811ecc
Show file tree
Hide file tree
Showing 6 changed files with 176 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ Every model is written from scratch to maximize performance and remove layers of
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
Expand Down
89 changes: 89 additions & 0 deletions litgpt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,95 @@ def norm_class(self) -> Type:
copy["hf_config"]["name"] = falcon180b["hf_config"]["name"].format(kind)
configs.append(copy)

falcon3 = [
# https://huggingface.co/tiiuae/Falcon3-1B-Base/blob/main/config.json
dict(
name="Falcon3-1B{}",
hf_config=dict(org="tiiuae", name="Falcon3-1B{}"),
block_size=4096,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=18,
n_head=8,
n_query_groups=4,
n_embd=2048,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=8192,
),
# https://huggingface.co/tiiuae/Falcon3-3B-Base/blob/main/config.json
dict(
name="Falcon3-3B{}",
hf_config=dict(org="tiiuae", name="Falcon3-3B{}"),
block_size=32768,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=22,
n_head=12,
n_query_groups=4,
n_embd=3072,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=9216,
),
# https://huggingface.co/tiiuae/Falcon3-7B-Base/blob/main/config.json
dict(
name="Falcon3-7B{}",
hf_config=dict(org="tiiuae", name="Falcon3-7B{}"),
block_size=32768,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=28,
n_head=12,
n_query_groups=4,
n_embd=3072,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=23040,
),
# https://huggingface.co/tiiuae/Falcon3-10B-Base/blob/main/config.json
dict(
name="Falcon3-10B{}",
hf_config=dict(org="tiiuae", name="Falcon3-10B{}"),
block_size=32768,
vocab_size=131072,
padded_vocab_size=131072,
n_layer=40,
n_head=12,
n_query_groups=4,
n_embd=3072,
rotary_percentage=1.0,
parallel_residual=False,
rope_base=1000042,
norm_eps=1e-6,
bias=False,
norm_class_name="RMSNorm",
mlp_class_name="LLaMAMLP",
intermediate_size=23040,
),
]
for c in falcon3:
for kind in ("-Base", "-Instruct"):
copy = deepcopy(c)
copy["name"] = c["name"].format(kind)
copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
configs.append(copy)


#############################
# OpenLM Research Open LLaMA
Expand Down
13 changes: 13 additions & 0 deletions litgpt/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,17 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
)


class Falcon3(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
return f"<|user|>\n{prompt}<|endoftext|>\n<|assistant|>\n"

def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
return (
[tokenizer.eos_id],
[tokenizer.token_to_id("<|endoftext|>")],
)


class Llama2FunctionCalling(PromptStyle):
def apply(self, prompt: str, **kwargs: str) -> str:
# Has to be before the llama config
Expand Down Expand Up @@ -344,6 +355,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
return StableLMZephyr()
if re.search("stablecode-instruct", model_name):
return StableCode()
if re.search(r"Falcon3.*-Instruct", model_name):
return Falcon3()
if re.search(r"falcon.*-instruct", model_name):
return Falcon()
if re.search("Llama-2-7b-chat-hf-function-calling-v2", model_name):
Expand Down
59 changes: 59 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,65 @@ def test_against_original_smollm2(model_name, device, dtype):
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)

@torch.inference_mode()
@pytest.mark.parametrize("model_name", ("Falcon3-1B-Base", "Falcon3-7B-Base"))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
# the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
# is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
RunIf(min_cuda_gpus=1),
],
),
],
)
def test_against_hf_falcon3(model_name, device, dtype):
torch.set_default_dtype(dtype)

ours_config = Config.from_name(
model_name,
padded_vocab_size=10000,
n_layer=2,
n_head=8,
n_embd=32,
n_query_groups=2,
intermediate_size=86,
)
T = 5
theirs_config = LlamaConfig(
vocab_size=ours_config.padded_vocab_size,
hidden_size=ours_config.n_embd,
num_attention_heads=ours_config.n_head,
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
rms_norm_eps=ours_config.norm_eps,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@RunIf(dynamo=True)
@torch.inference_mode()
Expand Down
5 changes: 5 additions & 0 deletions tests/test_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def test_tokenizer_against_hf(config):
# even though their config defines it, it's set as None in HF
assert isinstance(ours.bos_id, int)
assert theirs.bos_token_id is None
elif config.name.startswith("Falcon3"):
if isinstance(ours.bos_id, int):
assert theirs.bos_token_id is None
else:
assert ours.bos_id == theirs.bos_token_id == None
else:
assert ours.bos_id == theirs.bos_token_id

Expand Down
9 changes: 9 additions & 0 deletions tutorials/download_model_weights.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| Falcon 3 | 1B, 3B, 7B, 10B | TII UAE | [TII 2024](https://huggingface.co/blog/falcon3) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
Expand Down Expand Up @@ -233,6 +234,14 @@ tiiuae/falcon-40b
tiiuae/falcon-40b-instruct
tiiuae/falcon-7b
tiiuae/falcon-7b-instruct
tiiuae/Falcon3-1B-Base
tiiuae/Falcon3-1B-Instruct
tiiuae/Falcon3-3B-Base
tiiuae/Falcon3-3B-Instruct
tiiuae/Falcon3-7B-Base
tiiuae/Falcon3-7B-Instruct
tiiuae/Falcon3-10B-Base
tiiuae/Falcon3-10B-Instruct
TinyLlama/TinyLlama-1.1B-Chat-v1.0
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
togethercomputer/LLaMA-2-7B-32K
Expand Down

0 comments on commit 1811ecc

Please sign in to comment.