add an RNN and a GRU language model
karpathy committed Aug 22, 2022
1 parent 6694b67 commit b697f43
Showing 1 changed file with 108 additions and 8 deletions.
116 changes: 108 additions & 8 deletions makemore.py
@@ -111,8 +111,8 @@ def forward(self, x):
x = x + self.mlpf(self.ln_2(x))
return x

-class GPT(nn.Module):
-    """ GPT Language Model """
class Transformer(nn.Module):
""" Transformer Language Model, exactly as seen in GPT-2 """

def __init__(self, config):
super().__init__()
@@ -155,6 +155,99 @@ def forward(self, idx, targets=None):

return logits, loss

# -----------------------------------------------------------------------------
"""
Recurrent Neural Net language model: either a vanilla RNN recurrence or a GRU.
Did not implement an LSTM because its API is a bit more annoying as it has
both a hidden state and a cell state, but it's very similar to GRU and in
practice works just as well.
"""

class RNNCell(nn.Module):
"""
the job of a 'Cell' is to:
take input at current time step x_{t} and the hidden state at the
previous time step h_{t-1} and return the resulting hidden state
h_{t} at the current timestep
"""
def __init__(self, config):
super().__init__()
self.xh_to_h = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)

def forward(self, xt, hprev):
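# concatenate the input with the previous hidden state, then apply a single
# linear layer followed by tanh: h_t = tanh(W [x_t, h_{t-1}] + b)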
xh = torch.cat([xt, hprev], dim=1)
ht = F.tanh(self.xh_to_h(xh))
return ht

class GRUCell(nn.Module):
"""
same job as RNN cell, but a bit more complicated recurrence formula
that makes the GRU more expressive and easier to optimize.
"""
def __init__(self, config):
super().__init__()
# update gate (z), reset gate (r), and candidate hidden state (hbar)
self.xh_to_z = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
self.xh_to_r = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
self.xh_to_hbar = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)

def forward(self, xt, hprev):
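# the recurrence implemented below (biases omitted):
#   r_t    = sigmoid(W_r [x_t, h_{t-1}])         reset gate
#   hbar_t = tanh(W_h [x_t, r_t * h_{t-1}])      candidate hidden state
#   z_t    = sigmoid(W_z [x_t, h_{t-1}])         update gate
#   h_t    = (1 - z_t) * h_{t-1} + z_t * hbar_t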
# first use the reset gate to wipe some channels of the hidden state to zero
xh = torch.cat([xt, hprev], dim=1)
r = F.sigmoid(self.xh_to_r(xh))
hprev_reset = r * hprev
# calculate the candidate new hidden state hbar
xhr = torch.cat([xt, hprev_reset], dim=1)
hbar = F.tanh(self.xh_to_hbar(xhr))
# calculate the switch gate that determines if each channel should be updated at all
z = F.sigmoid(self.xh_to_z(xh))
# blend the previous hidden state and the new candidate hidden state
ht = (1 - z) * hprev + z * hbar
return ht
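
# -----------------------------------------------------------------------------
# Editorial sketch (not part of this commit): the module docstring above notes
# that an LSTM was skipped because its cell carries both a hidden state and a
# cell state. A minimal LSTMCell in the same style as RNNCell/GRUCell could look
# like the following; note it returns two tensors, so RNN.forward would need a
# small change to thread the extra cell state through the loop.

class LSTMCell(nn.Module):
    """
    illustrative only: same job as the cells above, but the recurrence keeps a
    separate cell state c_{t} in addition to the hidden state h_{t}
    """
    def __init__(self, config):
        super().__init__()
        # forget, input, and output gates, plus the candidate cell state
        self.xh_to_f = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
        self.xh_to_i = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
        self.xh_to_o = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
        self.xh_to_cbar = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)

    def forward(self, xt, hprev, cprev):
        xh = torch.cat([xt, hprev], dim=1)
        f = F.sigmoid(self.xh_to_f(xh))      # which cell channels to keep
        i = F.sigmoid(self.xh_to_i(xh))      # which cell channels to write to
        o = F.sigmoid(self.xh_to_o(xh))      # which channels to expose in h_t
        cbar = F.tanh(self.xh_to_cbar(xh))   # candidate new cell contents
        ct = f * cprev + i * cbar
        ht = o * F.tanh(ct)
        return ht, ct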

class RNN(nn.Module):

def __init__(self, config, cell_type):
super().__init__()
self.block_size = config.block_size
self.vocab_size = config.vocab_size
self.start = nn.Parameter(torch.zeros(1, config.n_embd2)) # the starting hidden state
self.wte = nn.Embedding(config.vocab_size, config.n_embd) # token embeddings table
if cell_type == 'rnn':
self.cell = RNNCell(config)
elif cell_type == 'gru':
self.cell = GRUCell(config)
self.lm_head = nn.Linear(config.n_embd2, self.vocab_size)

def get_block_size(self):
return self.block_size

def forward(self, idx, targets=None):
device = idx.device
b, t = idx.size()

# embed all the integers up front and all at once for efficiency
emb = self.wte(idx) # (b, t, n_embd)

# sequentially iterate over the inputs and update the RNN state each tick
hprev = self.start.expand((b, -1)) # expand out the batch dimension
hiddens = []
for i in range(t):
xt = emb[:, i, :] # (b, n_embd)
ht = self.cell(xt, hprev) # (b, n_embd2)
hprev = ht # the new hidden state becomes the previous one for the next time step
hiddens.append(ht)

# decode the outputs
hidden = torch.stack(hiddens, 1) # (b, t, n_embd2)
logits = self.lm_head(hidden)

# if we are given some desired targets also calculate the loss
loss = None
if targets is not None:
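# targets set to -1 are excluded from the loss via ignore_index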
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)

return logits, loss

# -----------------------------------------------------------------------------
# MLP language model

Expand Down Expand Up @@ -418,7 +511,7 @@ def next(self):
# sampling
parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
# model
-parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|transformer")
parser.add_argument('--type', type=str, default='bigram', help="model class type to use, bigram|mlp|rnn|gru|transformer")
parser.add_argument('--n-layer', type=int, default=4, help="number of layers")
parser.add_argument('--n-head', type=int, default=4, help="number of heads (in a transformer)")
parser.add_argument('--n-embd', type=int, default=64, help="number of feature channels in the model")
@@ -446,11 +539,18 @@ def next(self):
config = ModelConfig(vocab_size=vocab_size, block_size=block_size,
n_layer=args.n_layer, n_head=args.n_head,
n_embd=args.n_embd, n_embd2=args.n_embd2)
-model = {
-    'transformer': GPT,
-    'bigram': Bigram,
-    'mlp': MLP,
-}[args.type](config)
if args.type == 'transformer':
model = Transformer(config)
elif args.type == 'bigram':
model = Bigram(config)
elif args.type == 'mlp':
model = MLP(config)
elif args.type == 'rnn':
model = RNN(config, cell_type='rnn')
elif args.type == 'gru':
model = RNN(config, cell_type='gru')
else:
raise ValueError(f'model type {args.type} is not recognized')
model.to(args.device)
print(f"model #params: {sum(p.numel() for p in model.parameters())}")
if args.resume or args.sample_only: # note: if we sample-only then we also assume we are resuming
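Usage note: with this change the two new models are selected through the --type flag updated above, e.g.

python makemore.py --type rnn
python makemore.py --type gru

All other flags shown in the script, such as --n-layer and --n-embd, behave as before; nothing outside the model classes and the model selection code is touched by this commit.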
