updated beam decoder to new structure
frankseide committed Apr 9, 2016
1 parent 239633b commit b0bca78
Showing 4 changed files with 81 additions and 53 deletions.
1 change: 1 addition & 0 deletions CNTK.sln
@@ -666,6 +666,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Config", "Config", "{850008
ProjectSection(SolutionItems) = preProject
Examples\Text\PennTreebank\Config\rnn.cntk = Examples\Text\PennTreebank\Config\rnn.cntk
Examples\Text\PennTreebank\Config\S2SAutoEncoder.cntk = Examples\Text\PennTreebank\Config\S2SAutoEncoder.cntk
Examples\Text\PennTreebank\Config\S2SLib.bs = Examples\Text\PennTreebank\Config\S2SLib.bs
EndProjectSection
EndProject
Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "SLU", "SLU", "{E6DC3B7D-303D-4A54-B040-D8DCF8C56E17}"
87 changes: 37 additions & 50 deletions Examples/Text/PennTreebank/Config/S2SAutoEncoder.cntk
@@ -18,7 +18,8 @@ ExpRootDir = "$RunRootDir$"
#ExpId = _run

deviceId = 0
ExpId = 68-$deviceId$-s2sae-bigmodel
#ExpId = 68-$deviceId$-s2sae-bigmodel
ExpId = 67-1-s2sae-bigmodel # for decoding a different model

# directories
ExpDir = "$ExpRootDir$/$ExpId$"
@@ -42,7 +43,7 @@ command = writeWordAndClassInfo:train:test:write
precision = "float"
traceLevel = 1
modelPath = "$ModelDir$/S2SAutoEncoder.dnn"
decodeModelPath = "$modelPath$.13" # epoch to decode. Has best CV WER
decodeModelPath = "$modelPath$" # epoch to decode. Has best CV WER

confVocabSize = 10000
confClassSize = 50
@@ -55,6 +56,7 @@ validFile = "ptb.valid.txt"
#validFile = "ptb.small.valid.txt"
testFile = "ptb.test.txt"
#testFile = "ptb.test.txt-econ1"
#testFile = "ptb.small.train.txt" # test on train, to see whether model makes sense at all

#######################################
# network definition #
@@ -63,43 +65,10 @@ testFile = "ptb.test.txt"
BrainScriptNetworkBuilder = (new ComputationNetwork [

# TODO: move this somewhere shared
Trace (node, say='', logFrequency=traceFrequency, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[], tag='') = new ComputationNode [
operation = 'Trace' ; inputs = node
]

formatDense = [
type = "real"
transpose = false
precisionFormat = ".4"
]
formatOneHot = [
type = "category"
transpose = false
labelMappingFile = "$ModelDir$/vocab.wl"
]
formatSparse = [
type = "sparse"
transpose = false
labelMappingFile = "$ModelDir$/vocab.wl"
]
enableTracing = true
traceFrequency = 1000
TraceState (h, what) =
if enableTracing
then Transpose (Trace (Transpose (h), say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, onlyUpToRow=beamDepth*beamDepth, onlyUpToT=3, format=formatDense))
else h
TraceDense (h, what) =
if enableTracing
then Trace (h, say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, onlyUpToRow=beamDepth*beamDepth, onlyUpToT=25, format=formatDense)
else h
TraceOneHot (h, what) =
if enableTracing
then Trace (h, say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, /*onlyUpToRow=beamDepth*beamDepth, onlyUpToT=15,*/ format=formatOneHot)
else h
TraceSparse (h, what) =
if enableTracing
then Trace (h, say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, /*onlyUpToRow=beamDepth*beamDepth, onlyUpToT=3,*/ format=formatSparse)
else h
tracingLabelMappingFile = "$ModelDir$/vocab.wl"
include "S2SLib.bs"
beamDepth=3 // for above Trace macros only

# import general config options from outside config values
@@ -537,6 +506,11 @@ write = [
# We need to make a change:
BrainScriptNetworkBuilder = ([

enableTracing = true
traceFrequency = 1000
tracingLabelMappingFile = "$ModelDir$/vocab.wl"
include "S2SLib.bs"

beamDepth = 3 // 0=predict; 1=greedy; >1=beam

# import some names
@@ -575,8 +549,7 @@ write = [
logP = LogSoftmax (model.z)

offset = Constant (10000)
top1a = Hardmax (logP) .* (logP + offset)/*for tracing*/
top1b = top1a
top1b = Hardmax (logP) .* (logP + offset)/*for tracing*/
top1 = TraceSparse (top1b, 'logP') # TODO: get the accumulated logP out, it's a little more involved

topN = 10
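
Aside (not part of the commit): the Hardmax trick above pairs the winning word with its score for tracing. Hardmax produces a one-hot vector over the vocabulary; multiplying it elementwise by (logP + offset) attaches the winner's log-probability, shifted positive so the sparse printout does not drop it as a zero. A minimal numpy sketch under those assumptions:

import numpy as np

log_p = np.array([-9.2, -0.3, -4.7])            # toy logP over a 3-word vocabulary
offset = 10000.0
one_hot = (log_p == log_p.max()).astype(float)  # Hardmax: [0, 1, 0]
top1 = one_hot * (log_p + offset)               # [0, 9999.7, 0]: winner plus shifted score
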
@@ -591,10 +564,21 @@ write = [
]

# replace old decoderFeedback node by newDecoderFeedback
EmbedLabels (x) = TransposeTimes (modelAsTrained.labelsEmbedded.TransposeTimesArgs[0], x)
decoderFeedback = EmbedLabels (Hardmax (modelAsTrained.z)) # in training, this is decoderFeedback = labelsEmbedded

decoderFeedback = modelAsTrained.decoderOutputEmbedded # in training, this is decoderFeedback = labelsEmbedded
sentenceStartEmbedded = Boolean.If (Loop.IsFirst (decoderFeedback), modelAsTrained.inputEmbedded, Previous (sentenceStartEmbedded)) # enforces no leaking of labels
delayedDecoderFeedback = Boolean.If (Loop.IsFirst (decoderFeedback), sentenceStartEmbedded, Loop.Previous (decoderFeedback)) # same expression as in training
# TODO: fold this in
PreviousOrDefault1 (x, defaultValue=Constant (0)) = # a delay node with initial value --TODO: merge the two, then do in C++
[
flags = BS.Loop.IsFirst (defaultValue/*x*/)
out = BS.Boolean.If (flags,
/*then*/ defaultValue,
/*else*/ Previous (x))
].out

labelSentenceStartEmbeddedScattered = TraceDense (BS.Sequences.Scatter (BS.Loop.IsFirst (modelAsTrained.labelsEmbedded), modelAsTrained.labelSentenceStartEmbedded), 'sest')

delayedDecoderFeedback = TraceDense (/*Loop.*/PreviousOrDefault1 (defaultValue=labelSentenceStartEmbeddedScattered, TraceDense (decoderFeedback, 'lemb')) , 'prev lemb')

greedyDecodingModel = BS.Network.Edit (modelAsTrained,
BS.Network.Editing.ReplaceLinksToNode (modelAsTrained.delayedDecoderFeedback, delayedDecoderFeedback),
@@ -630,17 +614,19 @@ write = [
# - traceback is a right-to-left recurrence
# - output best hypo conditioned on the path (it is already known)

propagationEdits[i:0..2] = // TODO: implement and use { } syntax
propagationEdits[i:0..6] = // TODO: implement and use { } syntax
if i == 0 then (node => if node.name == 'decoder[0].prevState.h' then TraceState (Previous (PropagateTopN (node.PastValueArgs[0])), 'propagated') else node) # inject reshuffling of hypotheses
else if i == 1 then (node => if node.name == 'decoder[0].prevState.c' then TraceState (Previous (PropagateTopN (node.PastValueArgs[0])), 'propagated') else node)
else if i == 2 then (node => if node.name == 'decoder[1].prevState.h' then TraceState (Previous (PropagateTopN (node.PastValueArgs[0])), 'propagated') else node) # inject reshuffling of hypotheses
else if i == 3 then (node => if node.name == 'decoder[1].prevState.c' then TraceState (Previous (PropagateTopN (node.PastValueArgs[0])), 'propagated') else node)
else if i == 4 then (node => if node.name == 'decoder[2].prevState.h' then TraceState (Previous (PropagateTopN (node.PastValueArgs[0])), 'propagated') else node) # inject reshuffling of hypotheses
else if i == 5 then (node => if node.name == 'decoder[2].prevState.c' then TraceState (Previous (PropagateTopN (node.PastValueArgs[0])), 'propagated') else node)
else BS.Network.Editing.ReplaceLinksToNode (modelAsTrained.delayedDecoderFeedback, delayedDecoderFeedback)

# decoderFeedback must be updated to take actual decoder output

Elabel = modelAsTrained.decoderOutputEmbedded.TransposeTimesArgs[0]
decoderFeedback = TraceState (TransposeTimes (Elabel, TraceSparse (topWords, 'topWords')), 'feedback')

delayedDecoderFeedback = Boolean.If (Loop.IsFirst (decoderFeedback), sentenceStartEmbedded, Loop.Previous (decoderFeedback))
decoderFeedback = TraceState (EmbedLabels (TraceSparse (topWords, 'topWords')), 'feedback')
delayedDecoderFeedback = Boolean.If (Loop.IsFirst (labelSentenceStartEmbeddedScattered), labelSentenceStartEmbeddedScattered, Loop.Previous (decoderFeedback))

m2 = BS.Network.Edit (modelAsTrained,
propagationEdits,
@@ -679,7 +665,7 @@ write = [
LOGZERO = -1e30
initialPathScores = FirstAndOther (0, LOGZERO, beamDepth, axis = 2) # row vector: [ 0, -INF, -INF, -INF, ... ]

expandedPathScores = logLLs + PreviousOrDefault (PropagateTopN (pathScores), initialPathScores) # [V x Dprev] un-normalized log (P(w|hist) * P(hist)) for all top D hypotheses
expandedPathScores = logLLs + Boolean.If (Loop.IsFirst (pathScores), initialPathScores, Previous (PropagateTopN (pathScores))) # [V x Dprev] un-normalized log (P(w|hist) * P(hist)) for all top D hypotheses
# ^^ path expansion, [V x 1] + [1 x D] -> [V x D]

tokenSet = TraceSparse (GetTopNTensor (beamDepth, expandedPathScores), 'tokenSet') # [V x Dprev] -> [V x Dprev x Dnew]
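
Aside (not part of the commit): the expansion step broadcasts the per-word log-likelihoods [V x 1] against the surviving path scores [1 x Dprev] and keeps the top beamDepth (word, predecessor) pairs of the resulting [V x Dprev] matrix; GetTopNTensor then encodes the winners as a one-hot [V x Dprev x Dnew] tensor. A numpy sketch of the selection, with made-up shapes and names:

import numpy as np

V, D = 10000, 3
rng = np.random.default_rng(0)
log_lls = rng.standard_normal((V, 1))      # [V x 1]  log P(w | hist)
path_scores = rng.standard_normal((1, D))  # [1 x D]  accumulated log P(hist)

expanded = log_lls + path_scores           # [V x D]  log (P(w|hist) * P(hist))

flat = expanded.ravel()
top = np.argsort(flat)[-D:][::-1]          # indices of the D best expansions
words, prev_hyps = np.unravel_index(top, expanded.shape)
new_scores = flat[top]                     # path scores of the D new hypotheses
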
@@ -746,7 +732,8 @@ write = [
# previous states: multiply wth respective backPointers matrix
# -> hyp index for every time step
# then finally use that to select the actual output TODO: That's a sample-wise matrix product between two sequences!!!
traceback = TraceDense (NextOrDefault (backPointers * traceback, finalHyp), 'traceback') # [D] one-hot, multiplying backPointers from the left will select another one-hot row of backPointers
# TODO: condition must be 1-dim, not 2-dim tensor, so we use labelSentenceStartEmbeddedScattered instead of backPointers
traceback = TraceDense (Boolean.If (Loop.IsLast (labelSentenceStartEmbeddedScattered/*backPointers*/), finalHyp, Loop.Next (backPointers * traceback)), 'traceback') # [D] one-hot, multiplying backPointers from the left will select another one-hot row of backPointers
# +-+
# |0|
# |1| means at this time step, hyp[1] was the best globally
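
Aside (not part of the commit): the traceback is a right-to-left recurrence. It starts from the one-hot finalHyp at the last time step; stepping backwards, multiplying the next step's traceback by that step's backPointers matrix selects the predecessor's one-hot vector. A numpy sketch, assuming backPointers[t] is a [D x D] one-hot matrix mapping hypotheses at step t to their predecessors at step t-1:

import numpy as np

def trace_back(back_pointers, final_hyp):
    # back_pointers: list of T arrays, each [D x D]; final_hyp: [D] one-hot
    T, D = len(back_pointers), final_hyp.shape[0]
    trace = np.zeros((T, D))
    trace[-1] = final_hyp                               # best hypothesis at the last step
    for t in range(T - 2, -1, -1):                      # right-to-left recurrence
        trace[t] = back_pointers[t + 1] @ trace[t + 1]  # select the predecessor's one-hot
    return trace                                        # trace[t]: best hyp at step t
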
@@ -758,7 +745,7 @@ write = [
# This is the one to output (top sentence-level hypothesis after traceback).
decode = [
hyp = Times (tokenSet, traceback, outputRank = 2) # [V x Dprev] 2D one-hot
out = TraceOneHot (hyp * ConstantTensor (1, beamDepth), 'out') # reduces over Dprev -> 1D one-hot
out = TraceOneHot (hyp * ConstantTensor (1, (beamDepth)), 'out') # reduces over Dprev -> 1D one-hot
].out
# traceback : [Dnew]
# tokenSet : [V x Dprev x Dnew]
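
Aside (not part of the commit): decode contracts tokenSet [V x Dprev x Dnew] with the traceback's one-hot [Dnew] to obtain the winning [V x Dprev] one-hot plane, then multiplies by a ones vector to sum out Dprev, leaving a 1D one-hot over the vocabulary. The same two steps in numpy, with toy indices:

import numpy as np

V, D = 10000, 3
token_set = np.zeros((V, D, D))  # [V x Dprev x Dnew]: 2D one-hot per new hypothesis
token_set[42, 0, 1] = 1          # new hyp 1 = word 42 extending previous hyp 0
trace_t = np.zeros(D)
trace_t[1] = 1                   # traceback picked hyp 1 at this time step

hyp = np.tensordot(token_set, trace_t, axes=([2], [0]))  # [V x Dprev] 2D one-hot
out = hyp @ np.ones(D)                                   # sum over Dprev -> out[42] == 1
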
39 changes: 39 additions & 0 deletions Examples/Text/PennTreebank/Config/S2SLib.bs
@@ -0,0 +1,39 @@
# TODO: must sort this out. For now, this is just shared stuff between training and decoding.

# these depend on beamDepth parameter for now, fix this
TraceState (h, what) =
if enableTracing
then Transpose (Trace (Transpose (h), say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, onlyUpToRow=beamDepth*beamDepth, onlyUpToT=3, format=formatDense))
else h
TraceDense (h, what) =
if enableTracing
then Trace (h, say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, onlyUpToRow=beamDepth*beamDepth, onlyUpToT=25, format=formatDense)
else h
TraceOneHot (h, what) =
if enableTracing
then Trace (h, say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, /*onlyUpToRow=beamDepth*beamDepth, onlyUpToT=15,*/ format=formatOneHot)
else h
TraceSparse (h, what) =
if enableTracing
then Trace (h, say=what, logFirst=10, logFrequency=traceFrequency, logGradientToo=false, /*onlyUpToRow=beamDepth*beamDepth, onlyUpToT=3,*/ format=formatSparse)
else h

Trace (node, say='', logFrequency=traceFrequency, logFirst=10, logGradientToo=false, onlyUpToRow=100000000, onlyUpToT=100000000, format=[], tag='') = new ComputationNode [
operation = 'Trace' ; inputs = node
]

formatDense = [
type = "real"
transpose = false
precisionFormat = ".4"
]
formatOneHot = [
type = "category"
transpose = false
labelMappingFile = tracingLabelMappingFile
]
formatSparse = [
type = "sparse"
transpose = false
labelMappingFile = tracingLabelMappingFile
]
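
Aside (not part of this file): all four macros follow the same pass-through pattern: they return h unchanged while attaching a Trace node that logs the first logFirst evaluations and then every traceFrequency-th one. A rough Python analogue of the pattern (illustrative only; CNTK's Trace is a graph node, not a function call):

def make_tracer(enabled, log_first=10, log_frequency=1000):
    state = {"n": 0}
    def trace(value, say=""):
        if not enabled:
            return value                   # tracing disabled: pure pass-through
        state["n"] += 1
        if state["n"] <= log_first or state["n"] % log_frequency == 0:
            print(f"Trace[{say}] #{state['n']}: {value}")
        return value                       # value is passed through unchanged
    return trace

trace_dense = make_tracer(enabled=True)
h = trace_dense([0.1, 0.2], say="h")       # logs, then returns [0.1, 0.2]
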
7 changes: 4 additions & 3 deletions Source/ComputationNetworkLib/ComputationNetworkEvaluation.cpp
@@ -192,8 +192,9 @@ ComputationNetwork::PARTraversalFlowControlNode::PARTraversalFlowControlNode(con
for (auto& node : m_nestedNodes)
{
if (node->GetMBLayout() != GetMBLayout())
LogicError("Evaluate: all nodes inside a recurrent loop must have a layout that is identical; mismatch found for nodes '%ls' vs. '%ls'",
node->NodeName().c_str(), m_nestedNodes[0]->NodeName().c_str());
LogicError("Evaluate: All nodes inside a recurrent loop must have a layout that is identical; mismatch found for nodes '%ls' (%ls) vs. '%ls' (%ls)",
node ->NodeName().c_str(), node ->GetMBLayoutAxisString().c_str(),
m_nestedNodes[0]->NodeName().c_str(), m_nestedNodes[0]->GetMBLayoutAxisString().c_str());
}

// tell all that loop is about to commence
@@ -672,7 +673,7 @@ size_t ComputationNetwork::ValidateNodes(list<ComputationNodeBasePtr> nodes, boo
{
unchanged = !ValidateNode(node, isFinalValidationPass);
string updatedPrototype = node->FormatOperationPrototype("");
#if 1 // print prototype in final validation pass
#if 0 // print prototype in final validation pass. Problematic for tracking down validation errors in loops.
unchanged;
if (isFinalValidationPass)
#else // print prototype upon every change (useful for debugging)
