Half precision fixes (Lightning-AI#606)
carmocca authored Oct 24, 2023
1 parent 6ed6a15 commit 6178c7c
Showing 4 changed files with 73 additions and 76 deletions.
17 changes: 5 additions & 12 deletions lit_gpt/model.py
@@ -100,7 +100,6 @@ def rope_cache(self, device: Optional[torch.device] = None) -> Tuple[torch.Tenso
return build_rope_cache(
seq_len=self.max_seq_length,
n_elem=self.config.rope_n_elem,
dtype=torch.get_default_dtype(),
device=device,
condense_ratio=self.config.rope_condense_ratio,
base=self.config.rope_base,
@@ -158,15 +157,15 @@ def forward(
h = self.attn(n_1, cos, sin, mask, input_pos)
if self.config.parallel_residual:
n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
x = x + h + self.mlp(n_2)
x = self.mlp(n_2) + h + x
else:
if self.config.shared_attention_norm:
raise NotImplementedError(
"No checkpoint amongst the ones we support uses this configuration"
" (non-parallel residual and shared attention norm)."
)
x = x + h
x = x + self.mlp(self.norm_2(x))
x = h + x
x = self.mlp(self.norm_2(x)) + x
return x


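The reordering above matters because half-precision addition is not associative: `x + h + self.mlp(n_2)` and `self.mlp(n_2) + h + x` can round differently, and the new left-to-right order mirrors the reference GPT-NeoX accumulation (`mlp_output + attn_output + hidden_states`). A minimal sketch of the effect, with purely illustrative tensors standing in for the residual stream, attention output, and MLP output:

```python
import torch

# Illustrative only: fp16 addition is not associative, so summation order can
# change the result by a few ULPs. These tensors merely stand in for the
# residual stream, attention output, and MLP output in Block.forward.
torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float16)        # residual stream
h = torch.randn(4, 8, dtype=torch.float16)        # attention output
mlp_out = torch.randn(4, 8, dtype=torch.float16)  # MLP output

old_order = x + h + mlp_out   # previous accumulation order
new_order = mlp_out + h + x   # order used after this commit
print((old_order - new_order).abs().max())  # frequently non-zero in fp16
```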
@@ -308,7 +307,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
def build_rope_cache(
seq_len: int,
n_elem: int,
dtype: torch.dtype,
device: Optional[torch.device] = None,
base: int = 10000,
condense_ratio: int = 1,
@@ -320,20 +318,15 @@ def build_rope_cache(
https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
"""
# $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device) / n_elem))
theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

# Create position indexes `[0, 1, ..., seq_len - 1]`
seq_idx = torch.arange(seq_len, device=device) / condense_ratio

# Calculate the product of position index and $\theta_i$
idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

cos, sin = torch.cos(idx_theta), torch.sin(idx_theta)

# this is to mimic the behaviour of complex32, else we will get different results
if dtype in (torch.float16, torch.bfloat16, torch.int8):
return cos.half(), sin.half()
return cos, sin
return torch.cos(idx_theta), torch.sin(idx_theta)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
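The net effect of the `build_rope_cache` change is that the cos/sin cache is no longer downcast to half precision to mimic complex32: the `.float()` on `torch.arange` pins theta to fp32, so the cache is always computed in full precision regardless of the global default dtype, and downstream ops cast as needed. A self-contained sketch of that behaviour (the helper name is ours; the math mirrors the updated function):

```python
import torch

def rope_cache_fp32(seq_len: int, n_elem: int, base: int = 10000,
                    condense_ratio: int = 1, device=None):
    # mirrors the updated build_rope_cache: theta is pinned to fp32 via .float(),
    # so cos/sin end up in full precision even under a fp16/bf16 default dtype
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)
    return torch.cos(idx_theta), torch.sin(idx_theta)

torch.set_default_dtype(torch.float16)
cos, sin = rope_cache_fp32(seq_len=8, n_elem=4)
assert cos.dtype == torch.float32  # full-precision cache; callers downcast if they want to
torch.set_default_dtype(torch.float32)
```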
4 changes: 3 additions & 1 deletion lit_gpt/rmsnorm.py
@@ -15,10 +15,12 @@ def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
self.dim = dim

def forward(self, x: torch.Tensor) -> torch.Tensor:
dtype = x.dtype
x = x.float()
# NOTE: the original RMSNorm paper implementation is not equivalent
norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + self.eps)
return self.weight * x_normed
return (self.weight * x_normed).to(dtype=dtype)

def reset_parameters(self) -> None:
torch.nn.init.ones_(self.weight)
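Consolidated, the updated forward implements the usual mixed-precision RMSNorm pattern: do the mean/rsqrt reduction in fp32, then cast back to the caller's dtype so a fp16/bf16 model sees no dtype change. A hedged, runnable sketch (the class skeleton is assumed to match lit_gpt/rmsnorm.py):

```python
import torch

class RMSNormSketch(torch.nn.Module):
    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.dim = dim
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        x = x.float()  # upcast so the reduction is done in fp32
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return (self.weight * x_normed).to(dtype=dtype)  # cast back to the input dtype

out = RMSNormSketch(4)(torch.randn(2, 4, dtype=torch.float16))
assert out.dtype == torch.float16
```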
126 changes: 64 additions & 62 deletions tests/test_model.py
@@ -21,28 +21,28 @@ def restore_default_dtype():
@pytest.mark.parametrize("batch_size", (1, 3))
@pytest.mark.parametrize("n_embd", (16, 32))
@pytest.mark.parametrize("parallel_residual", (False, True))
@pytest.mark.parametrize("kv_cache", (False, True))
@pytest.mark.parametrize(
("device", "dtype"),
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=True),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
# the reference upcasts the softmax to fp32 during attention; additionally, the
# final layernorm input is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
),
],
)
def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, kv_cache, device, dtype) -> None:
def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residual, device, dtype) -> None:
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_gpt_neox

batch_size = 3
torch.set_default_dtype(dtype)

ours_config = Config(
block_size=64,
vocab_size=100,
@@ -66,23 +66,21 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua
rotary_pct=ours_config.rotary_percentage,
vocab_size=ours_config.padded_vocab_size,
use_parallel_residual=ours_config.parallel_residual,
use_cache=kv_cache,
torch_dtype=dtype,
)

state_dict = {}
theirs_model = GPTNeoXForCausalLM(theirs_config).to(device)
# load the hf initialization into our model
copy_weights_gpt_neox(state_dict, theirs_model.state_dict())
ours_model = GPT(ours_config).to(device, dtype)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

token_sample = torch.randint(
0, ours_config.padded_vocab_size, size=(batch_size, ours_config.block_size), dtype=torch.int64, device=device
)

theirs = theirs_model(token_sample)["logits"]
ours = ours_model(token_sample).float() # HF converts logits to float
ours = ours_model(token_sample)
torch.testing.assert_close(ours, theirs)


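The tests now rely on the global default dtype rather than `.to(device, dtype)`: `torch.set_default_dtype(dtype)` makes both the HF reference and our GPT materialize their parameters directly in the target precision, and the `restore_default_dtype` fixture referenced in the hunk header resets the global state afterwards. A hedged sketch of what such a fixture typically looks like (the reset value is an assumption):

```python
import pytest
import torch

@pytest.fixture(autouse=True)
def restore_default_dtype():
    # run the test, then undo any torch.set_default_dtype(...) it performed
    yield
    torch.set_default_dtype(torch.float32)  # assumed reset value
```

Note also that the CUDA/fp16 parametrizations switch to `xfail(strict=False)`: the half-precision comparison is still allowed to fail for the reasons in the comments, but an unexpected pass no longer fails the suite the way `strict=True` would.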
@@ -99,21 +97,23 @@ def test_against_gpt_neox_model(rotary_pct, batch_size, n_embd, parallel_residua
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=True),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
# the reference upcasts the softmax to fp32 during attention; additionally, the
# final layernorm input is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
),
],
)
def test_against_original_falcon_180b(kwargs, device, dtype):
def test_against_hf_falcon(kwargs, device, dtype):
from transformers.models.falcon import FalconConfig, FalconForCausalLM

from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_falcon

torch.set_default_dtype(dtype)

ours_config = Config.from_name(**kwargs)
theirs_config = FalconConfig(
hidden_size=ours_config.n_embd,
@@ -124,19 +124,18 @@ def test_against_original_falcon_180b(kwargs, device, dtype):
vocab_size=ours_config.padded_vocab_size,
bias=ours_config.bias,
new_decoder_architecture=True,
torch_dtype=dtype,
)

theirs_model = FalconForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_falcon(kwargs["name"], state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device, dtype)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
ours_y = ours_model(x).float() # HF converts logits to float
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"]
torch.testing.assert_close(ours_y, theirs_y)

@@ -147,12 +146,12 @@ def test_against_original_falcon_180b(kwargs, device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=True),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
# the reference upcasts the softmax to fp32 during attention; additionally, the
# final layernorm input is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
),
],
)
@@ -163,6 +162,8 @@ def test_against_original_open_llama_3b(device, dtype):
from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama

torch.set_default_dtype(dtype)

ours_config = Config.from_name("open_llama_3b", n_layer=2, n_head=8, n_embd=32, intermediate_size=86)
T = 5
theirs_config = LlamaConfig(
@@ -171,22 +172,21 @@ def test_against_original_open_llama_3b(device, dtype):
num_hidden_layers=ours_config.n_layer,
intermediate_size=ours_config.intermediate_size,
max_position_embeddings=T,
torch_dtype=dtype,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device, dtype)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x).float() # HF converts logits to float
theirs_y = theirs_model(x)["logits"]
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@@ -210,12 +210,12 @@ def test_against_original_open_llama_3b(device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=True),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
# the reference upcasts the softmax to fp32 during attention; additionally, the
# final layernorm input is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
),
],
)
@@ -226,6 +226,8 @@ def test_against_hf_llama2(ours_kwargs, device, dtype):
from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama

torch.set_default_dtype(dtype)

ours_config = Config.from_name(
padded_vocab_size=10000, n_layer=2, n_head=8, n_embd=32, intermediate_size=86, **ours_kwargs
)
@@ -241,22 +243,21 @@ def test_against_hf_llama2(ours_kwargs, device, dtype):
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
attention_bias=ours_config.bias,
torch_dtype=dtype,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = LlamaForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device, dtype)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x).float() # HF converts logits to float
theirs_y = theirs_model(x)["logits"]
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@@ -266,12 +267,10 @@ def test_against_hf_llama2(ours_kwargs, device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=True),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
),
],
)
@@ -285,6 +284,8 @@ def test_against_hf_phi(device, dtype):
from scripts.convert_hf_checkpoint import copy_weights_phi
from tests.original_phi_1_5 import MixFormerSequentialConfig, MixFormerSequentialForCausalLM

torch.set_default_dtype(dtype)

ours_config = Config.from_name(
"phi-1_5", padded_vocab_size=10000, n_layer=2, n_head=4, n_embd=256, rotary_percentage=0.5
)
@@ -304,14 +305,14 @@ def test_against_hf_phi(device, dtype):
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_phi(ours_config, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device, dtype)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x).float() # HF converts logits to float
theirs_y = theirs_model(x)["logits"]
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


@@ -324,12 +325,12 @@ def test_against_hf_phi(device, dtype):
[
(torch.device("cpu"), torch.float32),
pytest.param(
torch.device("cuda"),
torch.float16,
marks=[
pytest.mark.xfail(raises=AssertionError, strict=True),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA"),
],
torch.device("cuda"), torch.float16, marks=[
# the reference upcasts the softmax to fp32 during attention; additionally, the
# final layernorm input is slightly different
pytest.mark.xfail(raises=AssertionError, strict=False),
pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA")
]
),
],
)
Expand All @@ -340,6 +341,8 @@ def test_against_hf_mistral(device, dtype):
from lit_gpt import GPT, Config
from scripts.convert_hf_checkpoint import copy_weights_hf_llama

torch.set_default_dtype(dtype)

ours_config = Config.from_name(
"Mistral-7B-Instruct-v0.1",
padded_vocab_size=10000,
@@ -360,22 +363,21 @@ def test_against_hf_mistral(device, dtype):
rms_norm_eps=1e-5,
num_key_value_heads=ours_config.n_query_groups,
rope_theta=ours_config.rope_base,
torch_dtype=dtype,
)
assert ours_config.intermediate_size == theirs_config.intermediate_size

theirs_model = MistralForCausalLM(theirs_config).to(device)
theirs_state_dict = theirs_model.state_dict()
state_dict = {}
copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
ours_model = GPT(ours_config).to(device, dtype)
ours_model = GPT(ours_config).to(device)
ours_model.load_state_dict(state_dict)

# test end to end
x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
assert x.size(1) == T
ours_y = ours_model(x).float()
theirs_y = theirs_model(x)["logits"]
ours_y = ours_model(x)
theirs_y = theirs_model(x)["logits"].to(dtype) # HF converts logits to float
torch.testing.assert_close(ours_y, theirs_y)


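Several of these tests also flip the comparison direction: instead of upcasting our fp16 logits with `.float()`, the HF logits (which transformers returns as float, per the inline comments) are cast down with `.to(dtype)` so the assertion happens in the model's own precision. Illustrative only, with stand-in tensors:

```python
import torch

# Stand-in tensors, not real model outputs: demonstrate comparing in the
# model's own precision rather than upcasting ours to fp32.
dtype = torch.float16
ours_y = torch.randn(1, 5, 100, dtype=dtype)  # stand-in for ours_model(x)
theirs_y = ours_y.float()                     # stand-in: HF returns fp32 logits
torch.testing.assert_close(ours_y, theirs_y.to(dtype))  # compare in fp16, not fp32
```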