forked from pytorch/examples
Commit
Showing 10 changed files with 876 additions and 1,096 deletions.
Binary file not shown.
@@ -1,136 +1,87 @@
local function buildEncoder(opt, dicts)
  local inputNetwork = onmt.WordEmbedding.new(dicts.words:size(), -- vocab size
                                              opt.word_vec_size,
                                              opt.pre_word_vecs_enc,
                                              opt.fix_word_vecs_enc)

  local inputSize = opt.word_vec_size

  -- Sequences with features.
  if #dicts.features > 0 then
    local srcFeatEmbedding = onmt.FeaturesEmbedding.new(dicts.features,
                                                        opt.feat_vec_exponent,
                                                        opt.feat_vec_size,
                                                        opt.feat_merge)

    inputNetwork = nn.Sequential()
      :add(nn.ParallelTable()
             :add(inputNetwork)
             :add(srcFeatEmbedding))
      :add(nn.JoinTable(2))

    inputSize = inputSize + srcFeatEmbedding.outputSize
  end

  if opt.brnn then
    -- Compute rnn hidden size depending on hidden states merge action.
    local rnnSize = opt.rnn_size
    if opt.brnn_merge == 'concat' then
      if opt.rnn_size % 2 ~= 0 then
        error('in concat mode, rnn_size must be divisible by 2')
      end
      rnnSize = rnnSize / 2
    elseif opt.brnn_merge == 'sum' then
      rnnSize = rnnSize
    else
      error('invalid merge action ' .. opt.brnn_merge)
    end

    local rnn = onmt.LSTM.new(opt.layers, inputSize, rnnSize, opt.dropout, opt.residual)

    return onmt.BiEncoder.new(inputNetwork, rnn, opt.brnn_merge)
  else
    local rnn = onmt.LSTM.new(opt.layers, inputSize, opt.rnn_size, opt.dropout, opt.residual)

    return onmt.Encoder.new(inputNetwork, rnn)
  end
end

local function buildDecoder(opt, dicts, verbose)
  local inputNetwork = onmt.WordEmbedding.new(dicts.words:size(), -- vocab size
                                              opt.word_vec_size,
                                              opt.pre_word_vecs_dec,
                                              opt.fix_word_vecs_dec)

  local inputSize = opt.word_vec_size

  local generator

  -- Sequences with features.
  if #dicts.features > 0 then
    local tgtFeatEmbedding = onmt.FeaturesEmbedding.new(dicts.features,
                                                        opt.feat_vec_exponent,
                                                        opt.feat_vec_size,
                                                        opt.feat_merge)

    inputNetwork = nn.Sequential()
      :add(nn.ParallelTable()
             :add(inputNetwork)
             :add(tgtFeatEmbedding))
      :add(nn.JoinTable(2))

    inputSize = inputSize + tgtFeatEmbedding.outputSize

    generator = onmt.FeaturesGenerator.new(opt.rnn_size, dicts.words:size(), dicts.features)
  else
    generator = onmt.Generator.new(opt.rnn_size, dicts.words:size())
  end

  if opt.input_feed == 1 then
    if verbose then
      print(" * using input feeding")
    end
    inputSize = inputSize + opt.rnn_size
  end

  local rnn = onmt.LSTM.new(opt.layers, inputSize, opt.rnn_size, opt.dropout, opt.residual)

  return onmt.Decoder.new(inputNetwork, rnn, generator, opt.input_feed == 1)
end

--[[ This is useful when training from a model in parallel mode: each thread must own its model. ]]
local function clonePretrained(model)
  local clone = {}

  for k, v in pairs(model) do
    if k == 'modules' then
      clone.modules = {}
      for i = 1, #v do
        table.insert(clone.modules, onmt.utils.Tensor.deepClone(v[i]))
      end
    else
      clone[k] = v
    end
  end

  return clone
end

local function loadEncoder(pretrained, clone)
  local brnn = #pretrained.modules == 2

  if clone then
    pretrained = clonePretrained(pretrained)
  end

  if brnn then
    return onmt.BiEncoder.load(pretrained)
  else
    return onmt.Encoder.load(pretrained)
  end
end

local function loadDecoder(pretrained, clone)
  if clone then
    pretrained = clonePretrained(pretrained)
  end

  return onmt.Decoder.load(pretrained)
end

return {
  buildEncoder = buildEncoder,
  buildDecoder = buildDecoder,
  loadEncoder = loadEncoder,
  loadDecoder = loadDecoder
}
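Everything above is the removed side of the diff: the original Lua/Torch module. One detail worth carrying into the rewrite below: buildEncoder halves rnn_size in 'concat' merge mode so that the two directions of the bidirectional encoder concatenate back to exactly rnn_size. PyTorch's nn.LSTM behaves the same way with bidirectional=True; a minimal standalone sketch (sizes hypothetical):

import torch
import torch.nn as nn

rnn_size = 500  # must be even for 'concat' merge, as the Lua code enforces
lstm = nn.LSTM(input_size=64, hidden_size=rnn_size // 2,
               num_layers=2, bidirectional=True)
x = torch.randn(7, 3, 64)   # (seq_len, batch, input_size)
out, (h, c) = lstm(x)
print(out.shape)            # torch.Size([7, 3, 500]): both directions concatenated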
def _makeFeatEmbedder(opt, dicts):
    return onmt.FeaturesEmbedding(dicts.features,
                                  opt.feat_vec_exponent,
                                  opt.feat_vec_size,
                                  opt.feat_merge)


class Encoder(nn.Container):

    def __init__(self, opt, dicts):
        input_size = opt.word_vec_size
        feat_lut = None
        # Sequences with features.
        if len(dicts.features) > 0:
            feat_lut = _makeFeatEmbedder(opt, dicts)
            input_size = input_size + feat_lut.outputSize

        super(Encoder, self).__init__(
            word_lut=nn.LookupTable(dicts.words.size(), opt.word_vec_size),
            rnn=nn.LSTM(input_size, opt.rnnSize,
                        num_layers=opt.layers,
                        dropout=opt.dropout,
                        bidirectional=opt.brnn)
        )

        if opt.pre_word_vecs_enc is not None:
            pretrained = torch.load(opt.pre_word_vecs_enc)
            self.word_lut.weight.copy_(pretrained)

        self.has_features = feat_lut is not None
        if self.has_features:
            self.add_module('feat_lut', feat_lut)

    def forward(self, input, hidden):
        if self.has_features:
            word_emb = self.word_lut(input[0])
            feat_emb = self.feat_lut(input[1])
            emb = torch.cat([word_emb, feat_emb], 1)
        else:
            emb = self.word_lut(input)

        # Run the RNN over the embeddings, not the raw token indices.
        outputs, next_hidden = self.rnn(emb, hidden)
        return outputs, next_hidden
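nn.Container and nn.LookupTable come from the early PyTorch alpha API; current PyTorch expresses the same wiring with nn.Module and nn.Embedding. A minimal, runnable sketch of the feature-less encoder path under that assumption (all names and sizes hypothetical):

import torch
import torch.nn as nn

VOCAB, WORD_VEC, RNN_SIZE, LAYERS, DROPOUT = 1000, 64, 128, 2, 0.2

class SimpleEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_lut = nn.Embedding(VOCAB, WORD_VEC)  # stands in for nn.LookupTable
        self.rnn = nn.LSTM(WORD_VEC, RNN_SIZE,
                           num_layers=LAYERS, dropout=DROPOUT)

    def forward(self, input, hidden=None):
        emb = self.word_lut(input)            # (seq_len, batch, WORD_VEC)
        outputs, next_hidden = self.rnn(emb, hidden)
        return outputs, next_hidden

enc = SimpleEncoder()
tokens = torch.randint(0, VOCAB, (7, 3))      # (seq_len, batch) of word ids
outputs, (h, c) = enc(tokens)
print(outputs.shape)                          # torch.Size([7, 3, 128])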
class Decoder(nn.Container):

    def __init__(self, opt, dicts):
        input_size = opt.word_vec_size
        feat_lut = None
        # Sequences with features.
        if len(dicts.features) > 0:
            feat_lut = _makeFeatEmbedder(opt, dicts)
            input_size = input_size + feat_lut.outputSize

        # Input feeding widens the RNN input by rnn_size, because each step's
        # embedding is concatenated with the previous attentional output
        # (mirrors the Lua buildDecoder above).
        if opt.input_feed == 1:
            input_size = input_size + opt.rnnSize

        super(Decoder, self).__init__(
            word_lut=nn.LookupTable(dicts.words.size(), opt.word_vec_size),
            rnn=nn.LSTM(input_size, opt.rnnSize,
                        num_layers=opt.layers,
                        dropout=opt.dropout),
            attn=GlobalAttention(opt.rnnSize),
            dropout=nn.Dropout(opt.dropout)
        )

        self.input_feed = opt.input_feed == 1

        if opt.pre_word_vecs_dec is not None:
            pretrained = torch.load(opt.pre_word_vecs_dec)
            self.word_lut.weight.copy_(pretrained)

        self.has_features = feat_lut is not None
        if self.has_features:
            self.add_module('feat_lut', feat_lut)

    # context (encoder states) and input_feed (previous attentional output)
    # come from the caller.
    def forward(self, input, hidden, context, input_feed):
        if self.has_features:
            word_emb = self.word_lut(input[0])
            feat_emb = self.feat_lut(input[1])
            emb = torch.cat([word_emb, feat_emb], 1)
        else:
            emb = self.word_lut(input)

        if self.input_feed:
            emb = torch.cat([emb, input_feed], 1)  # 1 step

        outputs, next_hidden = self.rnn(emb, hidden)

        attn = self.attn(outputs, context)  # FIXME: per timestep?
        attn = self.dropout(attn)
        return attn, next_hidden
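The torch.cat in the input-feed branch is the core of the trick: at every step the decoder consumes its previous attentional output alongside the current word embedding, which is exactly why __init__ widens input_size by rnn_size. A standalone sketch of that single step (names and sizes hypothetical):

import torch

word_vec_size, rnn_size, batch = 64, 128, 3
emb_t = torch.randn(batch, word_vec_size)   # embedding of the current target word
prev_attn = torch.randn(batch, rnn_size)    # attentional output from step t-1
rnn_input_t = torch.cat([emb_t, prev_attn], 1)
print(rnn_input_t.shape)                    # torch.Size([3, 192]) == word_vec_size + rnn_size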