Commit

generalize makemore into other types of language models, and add bigram LM and an MLP LM
karpathy committed Aug 22, 2022
1 parent 50617fa commit 6694b67
Showing 1 changed file with 134 additions and 42 deletions.
176 changes: 134 additions & 42 deletions makemore.py
@@ -28,16 +28,19 @@
from torch.utils.tensorboard import SummaryWriter

# -----------------------------------------------------------------------------

@dataclass
class ModelConfig:
    block_size: int = None # length of the input sequences of integers
    vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
    # parameters below control the sizes of each model slightly differently
    n_layer: int = 4
    n_embd: int = 64
    n_embd2: int = 64
    n_head: int = 4
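# An illustrative instantiation (a sketch only; the script builds its config from
# command-line args further below). For a character-level dataset with, say, 26
# letters plus one special token, this might look like:
#   config = ModelConfig(vocab_size=27, block_size=16, n_layer=4, n_head=4, n_embd=64, n_embd2=64)
# where block_size=16 is a hypothetical maximum sequence length.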

# -----------------------------------------------------------------------------
# Transformer Language Model (*exactly* as used in GPT-2)

class NewGELU(nn.Module):
"""
Expand Down Expand Up @@ -127,6 +130,9 @@ def __init__(self, config):
n_params = sum(p.numel() for p in self.transformer.parameters())
print("number of parameters: %.2fM" % (n_params/1e6,))

def get_block_size(self):
return self.block_size

def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()
@@ -149,45 +155,124 @@ def forward(self, idx, targets=None):

return logits, loss

# -----------------------------------------------------------------------------
# MLP language model

class MLP(nn.Module):
"""
takes the previous block_size tokens, encodes them with a lookup table,
concatenates the vectors and predicts the next token with an MLP.
Reference:
Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
"""

def __init__(self, config):
super().__init__()
self.block_size = config.block_size
self.vocab_size = config.vocab_size
self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd) # token embeddings table
# +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
# before the beginning of the input sequence
self.mlp = nn.Sequential(
nn.Linear(self.block_size * config.n_embd, config.n_embd2), # TODO: option to vary this
nn.Tanh(),
nn.Linear(config.n_embd2, self.vocab_size)
)

def get_block_size(self):
return self.block_size

def forward(self, idx, targets=None):

        # gather the word embeddings of the previous block_size words
embs = []
for k in range(self.block_size):
tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
idx = torch.roll(idx, 1, 1)
idx[:, 0] = self.vocab_size # special <BLANK> token
embs.append(tok_emb)

# concat all of the embeddings together and pass through an MLP
x = torch.cat(embs, -1) # (b, t, n_embd * block_size)
logits = self.mlp(x)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

return logits, loss
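# A small worked example of the roll-based context gathering in MLP.forward above
# (a sketch, not part of the script). With block_size=3 and a single input row
# idx = [5, 2, 7], the loop embeds, at every position, the current token plus the
# two tokens before it, substituting the <BLANK> index (== vocab_size) wherever the
# context would reach past the start of the row:
#   position 0 -> concat(emb(5), emb(<BLANK>), emb(<BLANK>))
#   position 1 -> concat(emb(2), emb(5),       emb(<BLANK>))
#   position 2 -> concat(emb(7), emb(2),       emb(5))
# so x has shape (1, 3, 3 * n_embd) and the MLP maps it to (1, 3, vocab_size) logits.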

# -----------------------------------------------------------------------------
# Bigram language model

class Bigram(nn.Module):
"""
Bigram Language Model 'neural net', simply a lookup table of logits for the
next character given a previous character.
"""

def __init__(self, config):
super().__init__()
n = config.vocab_size
self.logits = nn.Parameter(torch.zeros((n, n)))

def get_block_size(self):
return 1 # this model only needs one previous character to predict the next

def forward(self, idx, targets=None):

# 'forward pass', lol
logits = self.logits[idx]

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

return logits, loss
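# A minimal illustration of the lookup above (a sketch): with vocab_size=3 the model
# is just a 3x3 table of logits, and indexing it with idx = [[0, 2]] returns a tensor
# of shape (1, 2, 3), i.e. row 0 of the table as next-token scores after token 0 and
# row 2 after token 2. Training nudges each row toward (the log of) the empirical
# next-character distribution via the cross-entropy loss.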

# -----------------------------------------------------------------------------
# helper functions for evaluating and sampling from the model

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
"""
Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
the sequence max_new_tokens times, feeding the predictions back into the model each time.
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
"""
block_size = model.get_block_size()
for _ in range(max_new_tokens):
# if the sequence context is growing too long we must crop it at block_size
idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
# forward the model to get the logits for the index in the sequence
logits, _ = model(idx_cond)
# pluck the logits at the final step and scale by desired temperature
logits = logits[:, -1, :] / temperature
# optionally crop the logits to only the top k options
if top_k is not None:
v, _ = torch.topk(logits, top_k)
logits[logits < v[:, [-1]]] = -float('Inf')
# apply softmax to convert logits to (normalized) probabilities
probs = F.softmax(logits, dim=-1)
# either sample from the distribution or take the most likely element
if do_sample:
idx_next = torch.multinomial(probs, num_samples=1)
else:
_, idx_next = torch.topk(probs, k=1, dim=-1)
# append sampled index to the running sequence and continue
idx = torch.cat((idx, idx_next), dim=1)

return idx
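# Example use of generate() (a sketch; the script's actual sampling entry point is
# print_samples() below). Assuming a trained `model` and index 0 as the <START> token:
#   idx = torch.zeros(1, 1, dtype=torch.long)  # on the model's device
#   out = generate(model, idx, max_new_tokens=10, do_sample=True, top_k=10)
# `out` then has shape (1, 11): the conditioning token followed by 10 sampled token
# indices, which the caller still has to decode back into characters (top_k must not
# exceed the vocabulary size).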

def print_samples(num=10):
""" samples from the model and pretty prints the decoded samples """
X_init = torch.zeros(num, 1, dtype=torch.long).to(args.device)
top_k = args.top_k if args.top_k != -1 else None
steps = train_dataset.get_output_length() - 1 # -1 because we already start with <START> token (index 0)
X_samp = generate(model, X_init, steps, top_k=top_k, do_sample=True).to('cpu')
train_samples, test_samples, new_samples = [], [], []
for i in range(X_samp.size(0)):
# get the i'th row of sampled integers, as python list
Expand Down Expand Up @@ -333,9 +418,11 @@ def next(self):
# sampling
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
# model
parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|transformer")
parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
parser.add_argument('--n-embd2', type=int, default=64, help="number of feature channels elsewhere in the model")
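    # e.g. (a sketch, not in the script itself): the bigram model is the default, and
    # the other classes are selected on the command line, roughly as
    #   python makemore.py --type mlp --n-embd 64 --n-embd2 64
    # (input/output flags defined elsewhere in this parser are omitted here).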
# optimization
parser.add_argument('--batch-size', '-b', type=int, default=32, help="batch size during optimization")
parser.add_argument('--learning-rate', '-l', type=float, default=5e-4, help="learning rate")
@@ -356,9 +443,14 @@ def next(self):
print(f"dataset determined that: {vocab_size=}, {block_size=}")

# init model
config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
n_layer=args.n_layer, n_head=args.n_head,
n_embd=args.n_embd, n_embd2=args.n_embd2)
model = {
'transformer': GPT,
'bigram': Bigram,
'mlp': MLP,
}[args.type](config)
model.to(args.device)
print(f"model #params: {sum(p.numel() for p in model.parameters())}")
if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
