Commit
train code close to done
adamlerer committed Dec 23, 2016
1 parent 45bd1a9 commit 286a8cb
Showing 10 changed files with 876 additions and 1,096 deletions.
Binary file removed OpenNMT/.preprocess.lua.swp
Binary file not shown.
223 changes: 87 additions & 136 deletions OpenNMT/onmt/Models.py
@@ -1,136 +1,87 @@
local function buildEncoder(opt, dicts)
  local inputNetwork = onmt.WordEmbedding.new(dicts.words:size(), -- vocab size
                                              opt.word_vec_size,
                                              opt.pre_word_vecs_enc,
                                              opt.fix_word_vecs_enc)

  local inputSize = opt.word_vec_size

  -- Sequences with features.
  if #dicts.features > 0 then
    local srcFeatEmbedding = onmt.FeaturesEmbedding.new(dicts.features,
                                                        opt.feat_vec_exponent,
                                                        opt.feat_vec_size,
                                                        opt.feat_merge)

    inputNetwork = nn.Sequential()
      :add(nn.ParallelTable()
             :add(inputNetwork)
             :add(srcFeatEmbedding))
      :add(nn.JoinTable(2))

    inputSize = inputSize + srcFeatEmbedding.outputSize
  end

  if opt.brnn then
    -- Compute rnn hidden size depending on hidden states merge action.
    local rnnSize = opt.rnn_size
    if opt.brnn_merge == 'concat' then
      if opt.rnn_size % 2 ~= 0 then
        error('in concat mode, rnn_size must be divisible by 2')
      end
      rnnSize = rnnSize / 2
    elseif opt.brnn_merge == 'sum' then
      rnnSize = rnnSize
    else
      error('invalid merge action ' .. opt.brnn_merge)
    end

    local rnn = onmt.LSTM.new(opt.layers, inputSize, rnnSize, opt.dropout, opt.residual)

    return onmt.BiEncoder.new(inputNetwork, rnn, opt.brnn_merge)
  else
    local rnn = onmt.LSTM.new(opt.layers, inputSize, opt.rnn_size, opt.dropout, opt.residual)

    return onmt.Encoder.new(inputNetwork, rnn)
  end
end

local function buildDecoder(opt, dicts, verbose)
  local inputNetwork = onmt.WordEmbedding.new(dicts.words:size(), -- vocab size
                                              opt.word_vec_size,
                                              opt.pre_word_vecs_dec,
                                              opt.fix_word_vecs_dec)

  local inputSize = opt.word_vec_size

  local generator

  -- Sequences with features.
  if #dicts.features > 0 then
    local tgtFeatEmbedding = onmt.FeaturesEmbedding.new(dicts.features,
                                                        opt.feat_vec_exponent,
                                                        opt.feat_vec_size,
                                                        opt.feat_merge)

    inputNetwork = nn.Sequential()
      :add(nn.ParallelTable()
             :add(inputNetwork)
             :add(tgtFeatEmbedding))
      :add(nn.JoinTable(2))

    inputSize = inputSize + tgtFeatEmbedding.outputSize

    generator = onmt.FeaturesGenerator.new(opt.rnn_size, dicts.words:size(), dicts.features)
  else
    generator = onmt.Generator.new(opt.rnn_size, dicts.words:size())
  end

  if opt.input_feed == 1 then
    if verbose then
      print(" * using input feeding")
    end
    inputSize = inputSize + opt.rnn_size
  end

  local rnn = onmt.LSTM.new(opt.layers, inputSize, opt.rnn_size, opt.dropout, opt.residual)

  return onmt.Decoder.new(inputNetwork, rnn, generator, opt.input_feed == 1)
end

--[[ This is useful when training from a model in parallel mode: each thread must own its model. ]]
local function clonePretrained(model)
  local clone = {}

  for k, v in pairs(model) do
    if k == 'modules' then
      clone.modules = {}
      for i = 1, #v do
        table.insert(clone.modules, onmt.utils.Tensor.deepClone(v[i]))
      end
    else
      clone[k] = v
    end
  end

  return clone
end

local function loadEncoder(pretrained, clone)
  local brnn = #pretrained.modules == 2

  if clone then
    pretrained = clonePretrained(pretrained)
  end

  if brnn then
    return onmt.BiEncoder.load(pretrained)
  else
    return onmt.Encoder.load(pretrained)
  end
end

local function loadDecoder(pretrained, clone)
  if clone then
    pretrained = clonePretrained(pretrained)
  end

  return onmt.Decoder.load(pretrained)
end

return {
  buildEncoder = buildEncoder,
  buildDecoder = buildDecoder,
  loadEncoder = loadEncoder,
  loadDecoder = loadDecoder
}
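Before the Python port below, it may help to spell out the size bookkeeping that buildEncoder above enforces: the LSTM input width is the word embedding size plus the feature embedding width, and under brnn_merge = 'concat' each direction receives half of rnn_size so that the concatenated forward and backward states add back up to rnn_size. A minimal sketch of that arithmetic, with a hypothetical helper and illustrative numbers that are not part of the commit:

def encoder_sizes(word_vec_size, feat_emb_size, rnn_size, brnn, brnn_merge='concat'):
    # LSTM input width: word embedding plus (optional) feature embedding.
    input_size = word_vec_size + feat_emb_size
    hidden_size = rnn_size
    if brnn and brnn_merge == 'concat':
        # Forward and backward states are concatenated, so each direction
        # gets half of rnn_size; rnn_size must therefore be even.
        assert rnn_size % 2 == 0, 'in concat mode, rnn_size must be divisible by 2'
        hidden_size = rnn_size // 2
    return input_size, hidden_size

# e.g. 500-dim word vectors, a 20-dim feature embedding, rnn_size=500, brnn on:
# encoder_sizes(500, 20, 500, True) -> (520, 250)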
import torch
import torch.nn as nn

import onmt
from onmt.modules import GlobalAttention  # assumed import path for the attention module


def _makeFeatEmbedder(opt, dicts):
    return onmt.FeaturesEmbedding(dicts.features,
                                  opt.feat_vec_exponent,
                                  opt.feat_vec_size,
                                  opt.feat_merge)


class Encoder(nn.Container):

    def __init__(self, opt, dicts):
        input_size = opt.word_vec_size
        feat_lut = None
        # Sequences with features: word and feature embeddings are concatenated.
        if len(dicts.features) > 0:
            feat_lut = _makeFeatEmbedder(opt, dicts)
            input_size += feat_lut.outputSize

        super(Encoder, self).__init__(
            word_lut=nn.Embedding(dicts.words.size(), opt.word_vec_size),
            rnn=nn.LSTM(input_size, opt.rnn_size,
                        num_layers=opt.layers,
                        dropout=opt.dropout,
                        bidirectional=opt.brnn)
        )

        # Optionally initialize the word embeddings from pretrained vectors.
        if opt.pre_word_vecs_enc is not None:
            pretrained = torch.load(opt.pre_word_vecs_enc)
            self.word_lut.weight.copy_(pretrained)

        self.has_features = feat_lut is not None
        if self.has_features:
            self.add_module('feat_lut', feat_lut)

    def forward(self, input, hidden):
        if self.has_features:
            word_emb = self.word_lut(input[0])
            feat_emb = self.feat_lut(input[1])
            emb = torch.cat([word_emb, feat_emb], 1)
        else:
            emb = self.word_lut(input)

        outputs, next_hidden = self.rnn(emb, hidden)
        return outputs, next_hidden

class Decoder(nn.Container):

    def __init__(self, opt, dicts):
        input_size = opt.word_vec_size
        feat_lut = None
        # Sequences with features: word and feature embeddings are concatenated.
        if len(dicts.features) > 0:
            feat_lut = _makeFeatEmbedder(opt, dicts)
            input_size += feat_lut.outputSize

        # With input feeding, the previous attention output is concatenated
        # to the input embedding at every step.
        if opt.input_feed == 1:
            input_size += opt.rnn_size

        super(Decoder, self).__init__(
            word_lut=nn.Embedding(dicts.words.size(), opt.word_vec_size),
            rnn=nn.LSTM(input_size, opt.rnn_size,
                        num_layers=opt.layers,
                        dropout=opt.dropout),
            attn=GlobalAttention(opt.rnn_size),
            dropout=nn.Dropout(opt.dropout)
        )
        self.input_feed = opt.input_feed == 1

        # Optionally initialize the word embeddings from pretrained vectors.
        if opt.pre_word_vecs_dec is not None:
            pretrained = torch.load(opt.pre_word_vecs_dec)
            self.word_lut.weight.copy_(pretrained)

        self.has_features = feat_lut is not None
        if self.has_features:
            self.add_module('feat_lut', feat_lut)

    def forward(self, input, hidden, context, input_feed):
        if self.has_features:
            word_emb = self.word_lut(input[0])
            feat_emb = self.feat_lut(input[1])
            emb = torch.cat([word_emb, feat_emb], 1)
        else:
            emb = self.word_lut(input)

        if self.input_feed:
            emb = torch.cat([emb, input_feed], 1)  # 1 step

        outputs, next_hidden = self.rnn(emb, hidden)

        attn = self.attn(outputs, context)  # FIXME: per timestep?
        attn = self.dropout(attn)
        return attn, next_hidden
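For orientation, a hedged usage sketch of the new classes, assuming an argparse-style opt namespace and vocabulary objects like those referenced above; every field value, batch name, and initial state below is illustrative, not taken from the commit:

from types import SimpleNamespace

# Illustrative options; the field names mirror those read by the classes above.
opt = SimpleNamespace(word_vec_size=500, rnn_size=500, layers=2, dropout=0.3,
                      brnn=False, input_feed=1,
                      pre_word_vecs_enc=None, pre_word_vecs_dec=None,
                      feat_vec_exponent=0.7, feat_vec_size=20, feat_merge='concat')

encoder = Encoder(opt, src_dicts)  # src_dicts/tgt_dicts: hypothetical vocab objects
decoder = Decoder(opt, tgt_dicts)

# One step of the pipeline; batch shapes and the per-timestep attention loop
# are still open questions in this commit (see the FIXME above):
context, enc_hidden = encoder(src_batch, enc_hidden0)
out, dec_hidden = decoder(tgt_batch, enc_hidden, context, init_feed)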