feat: training on dataloader lite
sohamtiwari3120 committed Sep 16, 2024
1 parent 9e0238f commit ad1f00d
Showing 1 changed file with 45 additions and 12 deletions.
57 changes: 45 additions & 12 deletions train_gpt2.py
@@ -213,7 +213,42 @@ def get_device():
return device

# ---------------------------------------------------------------------------------------
import tiktoken

class DataLoaderLite:
def __init__(self, B, T) -> None:
self.B = B
self.T = T

        # at init, load the tokens from disk and store them in memory
with open('input.txt', 'r') as f:
text = f.read()

enc = tiktoken.get_encoding('gpt2')
tokens = enc.encode(text)
self.tokens = torch.tensor(tokens)
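        # keep the full token tensor on CPU; each batch is moved to the device in the training loop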
print(f"Number of characters in text {len(text)} characters")
print(f"loaded {len(self.tokens)} tokens")
print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        # state: position of the next batch in the token stream
self.current_position = 0


def next_batch(self):
B, T = self.B, self.T
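        # take B*T + 1 tokens: the extra token lets inputs and targets be offset by one position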
buf = self.tokens[self.current_position : self.current_position + (B * T) + 1]
x = buf[:-1].view(B, T) # inputs
y = buf[1:].view(B, T) # targets

# advance the position in the tensor
self.current_position += B * T

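        # if the next batch would run past the end of the tokens, wrap around to the start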
if self.current_position + (B * T + 1) > len(self.tokens):
self.current_position = 0
return x, y


def test_code():
seed = 1337
torch.manual_seed(seed)
@@ -226,18 +261,8 @@ def test_code():
device = get_device()

# get a data batch
import tiktoken
enc = tiktoken.get_encoding('gpt2')
with open('input.txt', 'r') as f:
text = f.read()

text = text[:1000]
tokens = enc.encode(text)

B, T = 4, 32
buf = torch.tensor(tokens[:B*T + 1]).to(device)
x = buf[:-1].view(B, T)
y = buf[1:].view(B, T)
train_loader = DataLoaderLite(B, T)
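    # DataLoaderLite streams fresh batches over the full input.txt instead of reusing one fixed 1,000-character chunk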

# get logits
model = GPT(GPTConfig())
@@ -246,12 +271,20 @@

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
for i in range(50):
x, y = train_loader.next_batch()
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
logits, loss = model(x, y)
        loss.backward() # .backward() accumulates (+=) into existing gradients, so zeroing grads each step is essential (unless doing gradient accumulation)
optimizer.step()
print(f"step {i}, loss: {loss.item()}")


# NOTE: IMPORTANT INSIGHT
# 1. We expect the loss to keep decreasing even on this small dataset.
# 2. Two things to note: 1) the dataset is heavily biased and covers only a small fraction of the 50,257-token vocabulary;
#    2) hence during training the model largely learns to eliminate/"forget" the tokens that never occur in the dataset,
#       e.g. by driving the biases/logits for those tokens towards -inf. This is where the easy early gains come from.
# 3. The compression ratio is about 3:1 (roughly 3 characters per token).
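
# Optional sanity check for points 2 and 3 above: a minimal sketch (the helper name is made up
# and nothing in the training path calls it) that reports how many of the 50,257 vocabulary
# tokens actually occur in input.txt and the characters-per-token compression ratio.
def _report_token_stats(path: str = 'input.txt'):
    enc = tiktoken.get_encoding('gpt2')
    with open(path, 'r') as f:
        text = f.read()
    tokens = enc.encode(text)
    print(f"distinct tokens used: {len(set(tokens))} / {enc.n_vocab} in the vocabulary")
    print(f"compression ratio: {len(text) / len(tokens):.2f} characters per token")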

if __name__ == "__main__":
test_code()
