Commit

Refactor language modeling example

apaszke committed Jan 17, 2017
1 parent c053f31 commit d5a75ba
Showing 5 changed files with 204 additions and 212 deletions.
40 changes: 20 additions & 20 deletions word_language_model/README.md
@@ -1,34 +1,34 @@
# Word-level language modeling RNN

This example trains a multi-layer RNN (Elman, GRU, or LSTM) on a language modeling task.
By default, the training script uses the provided PTB dataset.
The trained model can then be used by the generate script to generate new text.

```bash
python main.py -cuda # Train an LSTM on ptb with cuda (cuDNN). Should reach perplexity of 116
python generate.py # Generate samples from the trained LSTM model.
python main.py --cuda # Train an LSTM on ptb with cuda (cuDNN). Should reach perplexity of 113
python generate.py # Generate samples from the trained LSTM model.
```

The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`) which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.
The model uses the `nn.RNN` module (and its sister modules `nn.GRU` and `nn.LSTM`)
which will automatically use the cuDNN backend if run on CUDA with cuDNN installed.

The `main.py` script accepts the following arguments:

```bash
optional arguments:
  -h, --help            show this help message and exit
  -data DATA            Location of the data corpus
  -model MODEL          Type of recurrent net. RNN_TANH, RNN_RELU, LSTM, or
                        GRU.
  -emsize EMSIZE        Size of word embeddings
  -nhid NHID            Number of hidden units per layer.
  -nlayers NLAYERS      Number of layers.
  -lr LR                Initial learning rate.
  -clip CLIP            Gradient clipping.
  -maxepoch MAXEPOCH    Upper epoch limit.
  -batchsize BATCHSIZE  Batch size.
  -bptt BPTT            Sequence length.
  -seed SEED            Random seed.
  -cuda                 Use CUDA.
  -reportint REPORTINT  Report interval.
  -save SAVE            Path to save the final model.
  -h, --help            show this help message and exit
  --data DATA           location of the data corpus
  --model MODEL         type of recurrent net (RNN_TANH, RNN_RELU, LSTM, GRU)
  --emsize EMSIZE       size of word embeddings
  --nhid NHID           number of hidden units per layer
  --nlayers NLAYERS     number of layers
  --lr LR               initial learning rate
  --clip CLIP           gradient clipping
  --epochs EPOCHS       upper epoch limit
  --batch-size N        batch size
  --bptt BPTT           sequence length
  --seed SEED           random seed
  --cuda                use CUDA
  --log-interval N      report interval
  --save SAVE           path to save the final model
```
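
The README's point that `nn.RNN`, `nn.GRU`, and `nn.LSTM` share one module interface (and dispatch to cuDNN automatically when run on CUDA with cuDNN installed) is what lets the `--model` flag switch cell types freely. A minimal sketch of that interface, not part of this commit, written against a recent PyTorch API and with made-up sizes:

```python
import torch
import torch.nn as nn

# Illustrative sizes only; in the example they come from --emsize, --nhid,
# --nlayers, --bptt and --batch-size.
seq_len, batch, ninp, nhid, nlayers = 35, 20, 200, 200, 2

rnn = nn.LSTM(ninp, nhid, nlayers)       # nn.GRU / nn.RNN take the same arguments
inp = torch.randn(seq_len, batch, ninp)  # (seq_len, batch, features)
h0 = torch.zeros(nlayers, batch, nhid)   # initial hidden state
c0 = torch.zeros(nlayers, batch, nhid)   # initial cell state (LSTM only)

# Moving rnn and the tensors to a CUDA device makes the same call use cuDNN.
out, (hn, cn) = rnn(inp, (h0, c0))       # out: (seq_len, batch, nhid)
print(out.shape)
```
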
43 changes: 19 additions & 24 deletions word_language_model/data.py
@@ -1,53 +1,48 @@
########################################
# Data Fetching Script for PTB
########################################

import os
import torch
import os.path

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def addword(self, word):
    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1

        return self.word2idx[word]

    def ntokens(self):
    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dic = Dictionary()
        self.train=self._loadfile(os.path.join(path, 'train.txt'))
        self.valid=self._loadfile(os.path.join(path, 'valid.txt'))
        self.test =self._loadfile(os.path.join(path, 'test.txt'))

    # | Tokenize a text file.
    def _loadfile(self, path):
        # Read words from file.
        assert(os.path.exists(path))
        tokens = 0
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r') as f:
            tokens = 0
            for line in f:
                words = line.split() + ['<eos>']
                tokens += len(words)
                for word in words:
                    self.dic.addword(word)
                    tokens += 1

                    self.dictionary.add_word(word)

        # Tokenize file content
        with open(path, 'r') as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for line in f:
                words = line.split() + ['<eos>']
                for word in words:
                    ids[token] = self.dic.word2idx[word]
                    ids[token] = self.dictionary.word2idx[word]
                    token += 1

        # Final dataset.

        return ids
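
In the refactored `data.py`, `Dictionary.add_word` and `__len__` replace `addword` and `ntokens`, `self.dic` becomes `self.dictionary`, and `Corpus.tokenize` (formerly `_loadfile`) makes two passes over each file: one to count tokens so a `torch.LongTensor` can be preallocated, and one to fill it with word ids. A usage sketch, not part of the commit, assuming the scripts' default `./data/penn` directory containing `train.txt`, `valid.txt`, and `test.txt`:

```python
import data  # word_language_model/data.py, as refactored above

corpus = data.Corpus('./data/penn')

# Vocabulary size, via the new Dictionary.__len__.
ntokens = len(corpus.dictionary)
print(ntokens)

# Each split is a flat 1-D LongTensor holding one id per token (including <eos>).
print(corpus.train.size())
print([corpus.dictionary.idx2word[int(i)] for i in corpus.train[:10]])
```
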
67 changes: 36 additions & 31 deletions word_language_model/generate.py
@@ -18,22 +18,34 @@
parser = argparse.ArgumentParser(description='PyTorch PTB Language Model')

# Model parameters.
parser.add_argument('-data' , type=str, default='./data/penn', help='Location of the data corpus' )
parser.add_argument('-checkpoint', type=str, default='./model.pt' , help='Checkpoint file path' )
parser.add_argument('-outf' , type=str, default='generated.out', help='Output file for generated text.' )
parser.add_argument('-nwords' , type=int, default='1000' , help='Number of words of text to generate' )
parser.add_argument('-seed' , type=int, default=1111 , help='Random seed.' )
parser.add_argument('-cuda' , action='store_true' , help='Use CUDA.' )
parser.add_argument('-temperature', type=float, default=1.0 , help='Temperature. Higher will increase diversity')
parser.add_argument('-reportinterval', type=int, default=100 , help='Reporting interval' )
parser.add_argument('--data', type=str, default='./data/penn',
                    help='location of the data corpus')
parser.add_argument('--checkpoint', type=str, default='./model.pt',
                    help='model checkpoint to use')
parser.add_argument('--outf', type=str, default='generated.txt',
                    help='output file for generated text')
parser.add_argument('--words', type=int, default='1000',
                    help='number of words to generate')
parser.add_argument('--seed', type=int, default=1111,
                    help='random seed')
parser.add_argument('--cuda', action='store_true',
                    help='use CUDA')
parser.add_argument('--temperature', type=float, default=1.0,
                    help='temperature - higher will increase diversity')
parser.add_argument('--log-interval', type=int, default=100,
                    help='reporting interval')
args = parser.parse_args()

# Set the random seed manually for reproducibility.
torch.manual_seed(args.seed)
# If the GPU is enabled, do some plumbing.
if torch.cuda.is_available():
    if not args.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")
    else:
        torch.cuda.manual_seed(args.seed)

if torch.cuda.is_available() and not args.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with -cuda")
if args.temperature < 1e-3:
    parser.error("--temperature has to be greater or equal 1e-3")

with open(args.checkpoint, 'rb') as f:
    model = torch.load(f)
@@ -44,28 +56,21 @@
    model.cpu()

corpus = data.Corpus(args.data)
ntokens = corpus.dic.ntokens()

hidden = model.initHidden(1)

input = torch.LongTensor(1,1).fill_(int(math.floor(torch.rand(1)[0] * ntokens)))
ntokens = len(corpus.dictionary)
hidden = model.init_hidden(1)
input = Variable(torch.rand(1, 1).mul(ntokens).long(), volatile=True)
if args.cuda:
    input = input.cuda()
    input.data = input.data.cuda()

temperature = max(args.temperature, 1e-3)
with open(args.outf, 'w') as outf:
    for i in range(args.nwords):

        output, hidden = model(Variable(input, volatile=True), hidden)
        gen = torch.multinomial(output[0].data.div(temperature).exp().cpu(), 1)[0][0]  # FIXME: multinomial is only for CPU
        input.fill_(gen)
        word = corpus.dic.idx2word[gen]
        outf.write(word)
    for i in range(args.words):
        output, hidden = model(input, hidden)
        word_weights = output.squeeze().data.div(args.temperature).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        input.data.fill_(word_idx)
        word = corpus.dictionary.idx2word[word_idx]

        if i % 20 == 19:
            outf.write("\n")
        else:
            outf.write(" ")
        outf.write(word + ('\n' if i % 20 == 19 else ' '))

        if i % args.reportinterval == 0:
            print('| Generated {}/{} words'.format(i, args.nwords))
        if i % args.log_interval == 0:
            print('| Generated {}/{} words'.format(i, args.words))
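
Both the old and new generation loops sample the next word with `torch.multinomial`: the output scores for the current step are divided by the temperature, exponentiated, and moved to the CPU (the removed FIXME notes that `multinomial` was CPU-only at the time), one id is drawn, and that id is written back into `input` for the next step. A standalone sketch of just the sampling step, not taken from this commit and using a made-up toy distribution:

```python
import torch

def sample_next(scores, temperature=1.0):
    """Draw one word id from a vector of per-word scores, as in the loop above.

    `scores` stands in for the model's squeezed output over the vocabulary
    (log-probabilities here); the helper name and sizes are illustrative.
    """
    # Dividing by T before exponentiating flattens (T > 1) or sharpens (T < 1)
    # the distribution; torch.multinomial accepts unnormalized weights.
    weights = scores.div(temperature).exp()
    return torch.multinomial(weights, 1).item()

# Toy 5-word vocabulary.
scores = torch.log_softmax(torch.randn(5), dim=0)
print(sample_next(scores, temperature=0.8))
```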