added node tracing to 'eval' and 'write', and added a sparse option
frankseide committed Mar 29, 2016
1 parent 4a63349 commit 539f6e6
Showing 14 changed files with 165 additions and 67 deletions.
Examples/Text/PennTreebank/Config/S2SAutoEncoder.cntk (13 additions, 2 deletions)
@@ -556,7 +556,7 @@ write = [
SkipFirst (x) = BS.Sequences.Skip (1, x)

# replace every reference of these by PropagateTopN(of these)
EditToPropagateTopN (name) = (node => if node.name == name then PropagateTopN (node) else node)
EditToPropagateTopN (name) = (node => if node.name == name then /*PropagateTopN*/ (node) else node)
propagationEdits[i:0..3] = // TODO: implement and use { } syntax
if i == 0 then EditToPropagateTopN ('decoder[0].prevState.h')
else if i == 1 then EditToPropagateTopN ('decoder[0].prevState.c')
@@ -721,13 +721,24 @@ write = [
outputPath = "-" # "-" will write to stdout; useful for debugging
#outputNodeNames = z1.out:labels1 # when processing one sentence per minibatch, this is the sentence posterior
outputNodeNames = network.beamDecodingModel.z1.out:labels1 # when processing one sentence per minibatch, this is the sentence posterior
#outputNodeNames = labels1:network.beamDecodingModel.expandedPathScores
#outputNodeNames = network.beamDecodingModel.pathScores:network.beamDecodingModel.traceback
# network.beamDecodingModel.tokenSetScores
# network.beamDecodingModel.pathScores
# network.beamDecodingModel.traceback
# network.beamDecodingModel.expandedPathScores

format = [
type = "category"
type = "sparse"
transpose = false
labelMappingFile = "$ModelDir$/vocab.wl"
sequenceEpilogue = "\t// %s\n"
]

traceNodeNamesReal = network.beamDecodingModel.pathScores:network.beamDecodingModel.tokenSetScores:network.beamDecodingModel.expandedPathScores:network.beamDecodingModel.backPointers:network.beamDecodingModel.initialPathScores.out.out.input
#traceNodeNamesCategory = network.beamDecodingModel.tokenSetScores
traceNodeNamesSparse = network.beamDecodingModel.tokenSetScores:network.beamDecodingModel.backPointers:decoderOutputEmbedded.x

minibatchSize = 8192 # choose this to be big enough for the longest sentence
# need to be small since models are updated for each minibatch
traceLevel = 1
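The colon-separated lists above (traceNodeNamesReal, traceNodeNamesSparse) name the computation nodes whose values should be traced; the a:b:c form is CNTK's config syntax for an array of values. A minimal standalone sketch of that splitting, with a hypothetical helper name; CNTK's real parsing goes through ConfigParameters::Array(stringargvector()), as the EvalActions.cpp changes below show:

#include <string>
#include <vector>

// Hypothetical helper: split a colon-separated node-name list such as
// "pathScores:tokenSetScores:backPointers" into individual names.
// Illustration only; not CNTK's actual config parser.
static std::vector<std::wstring> SplitNodeNames(const std::wstring& list)
{
    std::vector<std::wstring> names;
    size_t begin = 0;
    while (begin <= list.size())
    {
        size_t end = list.find(L':', begin);
        if (end == std::wstring::npos)
            end = list.size();
        if (end > begin) // skip empty segments
            names.push_back(list.substr(begin, end - begin));
        begin = end + 1;
    }
    return names;
}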
Source/ActionsLib/EvalActions.cpp (14 additions, 4 deletions)
@@ -69,6 +69,11 @@ static void DoEvalBase(const ConfigParameters& config, IDataReader& reader)

auto net = ComputationNetwork::CreateFromFile<ElemType>(deviceId, modelPath);

// set tracing flags
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));

SimpleEvaluator<ElemType> eval(net, MPIWrapper::GetInstance(), numMBsToShowResult, traceLevel, maxSamplesInRAM, numSubminiBatches);
eval.Evaluate(&reader, evalNodeNamesVector, mbSize[0], epochSize);
}
@@ -225,6 +230,11 @@ void DoWriteOutput(const ConfigParameters& config)

let net = GetModelFromConfig<ConfigParameters, ElemType>(config, outputNodeNamesVector);

// set tracing flags
net->EnableNodeTracing(config(L"traceNodeNamesReal", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesCategory", ConfigParameters::Array(stringargvector())),
config(L"traceNodeNamesSparse", ConfigParameters::Array(stringargvector())));

SimpleOutputWriter<ElemType> writer(net, 1);

if (config.Exists("writer"))
@@ -246,11 +256,11 @@ void DoWriteOutput(const ConfigParameters& config)
if (formatConfig.ExistsCurrent("type")) // do not inherit 'type' from outer block
{
string type = formatConfig(L"type");
if (type == "real") formattingOptions.isCategoryLabel = false;
if (type == "real") ; // default
else if (type == "category") formattingOptions.isCategoryLabel = true;
else InvalidArgument("write: type must be 'real' or 'category'");
if (formattingOptions.isCategoryLabel)
formattingOptions.labelMappingFile = (wstring)formatConfig(L"labelMappingFile", L"");
else if (type == "sparse") formattingOptions.isSparse = true;
else InvalidArgument("write: type must be 'real', 'category', or 'sparse'");
formattingOptions.labelMappingFile = (wstring)formatConfig(L"labelMappingFile", L"");
}
formattingOptions.transpose = formatConfig(L"transpose", formattingOptions.transpose);
formattingOptions.prologue = formatConfig(L"prologue", formattingOptions.prologue);
Source/CNTK/BrainScript/BrainScriptEvaluator.cpp (2 additions, 1 deletion)
@@ -1023,4 +1023,5 @@ static ScriptableObjects::ConfigurableRuntimeTypeRegister::Add<Debug>
// - macro arg expressions get their path assigned when their thunk is created, the thunk remembers it
// - however, really, the thunk should get the expression path from the context it is executed in, not the context it was created in
// - maybe there is some clever scheme of overwriting when a result comes back? E.g. we retrieve a value but its name is not right, can we patch it up? Very tricky to find the right rules/conditions
} } } // namespaces

}}} // namespaces
Source/CNTK/BrainScript/CNTKCoreLib/CNTK.core.bs (6 additions, 2 deletions)
@@ -11,6 +11,7 @@

Print(value, format='') = new PrintAction [ what = value /*; how = format*/ ]
Debug(value, say = '', enabled = true) = new Debug [ /*macro arg values*/ ]
Fail(what) = new FailAction [ /*what*/ ]
Format(value, format) = new StringFunction [ what = 'Format' ; arg = value ; how = format ]
Replace(s, from, to) = new StringFunction [ what = 'Replace' ; arg = s ; replacewhat = from ; withwhat = to ]
Substr(s, begin, num) = new StringFunction [ what = 'Substr' ; arg = s ; pos = begin ; chars = num ]
@@ -43,14 +44,17 @@ Constant(val, rows = 1, cols = 1, tag='') = Parameter(rows, cols, learningRateMu
PastValue(dims, input, timeStep = 1, defaultHiddenActivation = 0.1, tag='') = new ComputationNode [ operation = 'PastValue' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]
FutureValue(dims, input, timeStep = 1, defaultHiddenActivation = 0.1, tag='') = new ComputationNode [ operation = 'FutureValue' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]
Shift(input, fromOffset, boundaryValue, boundaryMode=-1/*context*/, dim=-1, tag='') = new ComputationNode [ operation = 'Shift' ; inputs = (input : boundaryValue) /*plus the function args*/ ]
RowSlice(beginIndex, numRows, input, tag='') = new ComputationNode [ operation = 'Slice' ; endIndex = beginIndex + numRows; axis = 1/*row*/; inputs = input /*plus the function args*/ ]
RowSlice(beginIndex, numRows, input, tag='') = Slice(beginIndex, beginIndex + numRows, input, axis = 1)
RowRepeat(input, numRepeats, tag='') = new ComputationNode [ operation = 'RowRepeat' ; inputs = input /*plus the function args*/ ]
RowStack(inputs, tag='') = new ComputationNode [ operation = 'RowStack' /*plus the function args*/ ]
Reshape(input, numRows, imageWidth = 0, imageHeight = 0, imageChannels = 0, tag='') = new ComputationNode [ operation = 'LegacyReshape' ; inputs = input /*plus the function args*/ ]
NewReshape(input, dims, beginAxis=0, endAxis=0, tag='') = new ComputationNode [ operation = 'Reshape' ; inputs = input ; shape = new TensorShape [ /*dims*/ ] /*plus the function args*/ ]
ReshapeDimension(x, axis, tensorShape) = NewReshape(x, tensorShape, beginAxis=axis, endAxis=axis + 1)
FlattenDimensions(x, axis, num) = NewReshape(x, 0, beginAxis=axis, endAxis=axis + num)
#Slice(beginIndex, endIndex, input, axis=1, tag='') = new ComputationNode [ operation = 'Slice' ; endIndex = beginIndex + numRows; inputs = input /*plus the function args*/ ]
Slice(beginIndex, endIndex, input, axis=1, tag='') =
if axis < 1
then Fail('Slice does not yet implement slicing the time axis.') // TODO: implement using Gather()
else new ComputationNode [ operation = 'Slice' ; inputs = input /*plus the function args*/ ]
SplitDimension(x, axis, N) = ReshapeDimension(x, axis, 0:N)
TransposeDimensions(input, axis1, axis2, tag='') = new ComputationNode [ operation = 'TransposeDimensions' ; inputs = input /*plus the function args*/ ]
Transpose(x) = TransposeDimensions(x, 1, 2)
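The reworked Slice fails fast, via the new Fail() action added above, when asked to slice the time axis (axis < 1), which is not implemented yet. The same guard pattern, sketched standalone in C++ purely for illustration (names hypothetical):

#include <stdexcept>
#include <vector>

// Sketch of the new Slice guard: axis < 1 would be the time axis, which is
// not supported yet, so fail fast with a clear message instead of
// producing a wrong result.
static std::vector<float> SliceAxis1(const std::vector<float>& column,
                                     int beginIndex, int endIndex, int axis)
{
    if (axis < 1)
        throw std::invalid_argument("Slice does not yet implement slicing the time axis.");
    return std::vector<float>(column.begin() + beginIndex, column.begin() + endIndex);
}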
Source/Common/Include/Basics.h (3 additions, 3 deletions)
@@ -63,12 +63,12 @@ __declspec_noreturn static inline void ThrowFormatted(const char* format, ...)
va_list args;
va_start(args, format);

char buffer[1024] = { 0 }; // initialize in case vsnprintf() does a half-assed job such as a failing character conversion
int written = vsnprintf(buffer, _countof(buffer) - 1, format, args); // -1 because vsnprintf() does not always write a 0-terminator, although the MSDN documentation states so
char buffer[1024] = { 0 }; // Note: pre-VS2015 vsnprintf() is not standards-compliant and may not add a terminator
int written = vsnprintf(buffer, _countof(buffer) - 1, format, args); // -1 because pre-VS2015 vsnprintf() does not always write a 0-terminator
// TODO: In case of EILSEQ error, choose between just outputting the raw format itself vs. continuing the half-completed buffer
//if (written < 0) // an invalid wide-string conversion may lead to EILSEQ
// strncpy(buffer, format, _countof(buffer)
UNUSED(written); // vsnprintf() returns -1 in case of overflow, instead of the #characters written as claimed in the MSDN documentation states so
UNUSED(written); // pre-VS2015 vsnprintf() returns -1 in case of overflow, instead of the #characters written
if (strlen(buffer)/*written*/ >= (int)_countof(buffer) - 2)
sprintf(buffer + _countof(buffer) - 4, "...");
#ifdef _DEBUG // print this to log, so we can see what the error is before throwing
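The revised comments concern pre-VS2015 vsnprintf(), which on Windows neither guarantees a 0-terminator nor returns the required length on overflow (it returns -1 instead). A self-contained sketch of the defensive pattern the code above uses; the helper name is hypothetical, the idea is the same:

#include <cstdarg>
#include <cstdio>
#include <cstring>

// Defensive formatting for pre-VS2015 vsnprintf(): zero-fill the buffer so it
// is always terminated, reserve one byte for the terminator, and treat a
// negative return (overflow or EILSEQ) or a near-full buffer as truncation,
// marking it with "..." as ThrowFormatted() does above.
static void FormatDefensively(char (&buffer)[1024], const char* format, ...)
{
    memset(buffer, 0, sizeof(buffer));
    va_list args;
    va_start(args, format);
    int written = vsnprintf(buffer, sizeof(buffer) - 1, format, args);
    va_end(args);
    if (written < 0 || strlen(buffer) >= sizeof(buffer) - 2)
        strcpy(buffer + sizeof(buffer) - 4, "..."); // overwrite tail with truncation marker
}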
Source/ComputationNetworkLib/ComputationNetwork.h (13 additions, 0 deletions)
@@ -722,6 +722,19 @@ class ComputationNetwork :
// diagnostics
// -----------------------------------------------------------------------

// call EnableNodeTracing() on the given nodes for real, category, and sparse printing
void EnableNodeTracing(const std::vector<std::wstring>& traceNodeNamesReal,
const std::vector<std::wstring>& traceNodeNamesCategory,
const std::vector<std::wstring>& traceNodeNamesSparse)
{
for (const auto& name : traceNodeNamesReal)
GetNodeFromName(name)->EnableNodeTracing(/*asReal=*/true, /*asCategoryLabel=*/false, /*asSparse=*/false);
for (const auto& name : traceNodeNamesCategory)
GetNodeFromName(name)->EnableNodeTracing(/*asReal=*/false, /*asCategoryLabel=*/true, /*asSparse=*/false);
for (const auto& name : traceNodeNamesSparse)
GetNodeFromName(name)->EnableNodeTracing(/*asReal=*/false, /*asCategoryLabel=*/false, /*asSparse=*/true);
}

// if node name is not found, dump all nodes
// otherwise dump just that node
// This function is called from MEL, i.e. must be prepared to operate on an uncompiled network (only m_nameToNodeMap is valid).
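For illustration, a hedged sketch of a call site for the new method; the node names are invented, and NetworkPtr stands in for the network pointer type used in EvalActions.cpp above:

#include <string>
#include <vector>

// Hypothetical call site: enable per-minibatch value tracing on selected
// nodes, one vector per print mode (real, category, sparse).
template <typename NetworkPtr>
static void EnableExampleTracing(const NetworkPtr& net)
{
    std::vector<std::wstring> realNodes     = { L"decoder.z" };
    std::vector<std::wstring> categoryNodes;                        // none
    std::vector<std::wstring> sparseNodes   = { L"decoder.tokenSetScores" };
    net->EnableNodeTracing(realNodes, categoryNodes, sparseNodes);
}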
Source/ComputationNetworkLib/ComputationNetworkAnalysis.cpp (2 additions, 0 deletions)
@@ -57,8 +57,10 @@ void ComputationNetwork::FormRecurrentLoops(const ComputationNodeBasePtr& rootNo
// - the loop itself
// - consumers of the loop
// (consumers of the loop can never be inputs, otherwise they would be part of the loop)
// This should be done by the SEQ constructor. I don't dare at present because some other code (e.g. memshare) relies on this ordering.
// - each loop is sorted inside
// - break loop graph into sub-graphs between delay nodes
// This is necessary.

// --- BEGIN reorder process --TODO: eliminate this entire chunk of code; don't update EvalOrder; instead, do it only when constructing the outer PAR node

(7 more changed files not shown)
