Skip to content

Commit

Permalink
Add llama config error check
Browse files Browse the repository at this point in the history
  • Loading branch information
gordicaleksa committed Aug 8, 2024
1 parent b25e325 commit b7c98c9
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions train_gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ class LlamaConfig:
max_gen_batch_size: int = 4
use_kv: bool = True

def __init__(self, **kwargs):
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
assert self.n_kv_head <= self.n_head
assert self.n_head % self.n_kv_head == 0
assert self.n_embd % self.n_head == 0

class LLaMA(nn.Module):

def __init__(self, config):
Expand Down

0 comments on commit b7c98c9

Please sign in to comment.