Skip to content

Commit

Permalink
completed ShiftNode::BackpropTo()
Browse files Browse the repository at this point in the history
  • Loading branch information
frankseide committed Jan 16, 2016
1 parent c186d49 commit 267be9e
Showing 2 changed files with 95 additions and 68 deletions.
159 changes: 93 additions & 66 deletions Source/ComputationNetworkLib/RecurrentNodes.h
Original file line number Diff line number Diff line change
@@ -243,11 +243,9 @@ namespace Microsoft { namespace MSR { namespace CNTK {
return TensorView<ElemType>(data, shape);
}

// determine FrameRange objects that describe the boundary frames of the sequence
// This version is for the case of iterating over time.
void DetermineBoundaryFrameRanges(const FrameRange & fr, const MBLayout::SequenceInfo & toSeqInfo, // range we operate on and current sequence under consideration
const ComputationNodeBasePtr & fromNode, FrameRange & frFrom, // boundary node
size_t T, FrameRange & frTo) const // ourselves (output)
// determine FrameRange objects that describe the boundary frames of the sequence for the output, for the case of iterating over time.
void DetermineBoundaryToFrameRange(const FrameRange & fr, const MBLayout::SequenceInfo & toSeqInfo, // range we operate on and current sequence under consideration
size_t T, FrameRange & frTo) const // ourselves (output)
{
// get FrameRange to write to in our output
frTo = fr.Sequence(toSeqInfo.s); // clip to this one sequence only
@@ -258,6 +256,17 @@ namespace Microsoft { namespace MSR { namespace CNTK {
LogicError("This code path has never been tested."); // remove this once we have
}
// frTo now describes the frame range that needs to be filled from the boundary node
}

// determine FrameRange objects that describe the boundary frames of the sequence
// This version is for the case of iterating over time.
void DetermineBoundaryFrameRanges(const FrameRange & fr, const MBLayout::SequenceInfo & toSeqInfo, // range we operate on and current sequence under consideration
const ComputationNodeBasePtr & fromNode, FrameRange & frFrom, // boundary node
size_t T, FrameRange & frTo) const // ourselves (output)
{
// get FrameRange to write to in our output
DetermineBoundaryToFrameRange(fr, toSeqInfo, T, frTo);
// frTo now describes the frame range that needs to be filled from the boundary node

// create a FrameRange for the boundary node to read from
// Boundary data is always a single frame.
@@ -299,8 +308,41 @@ namespace Microsoft { namespace MSR { namespace CNTK {
frFrom = frFrom.WithTimeStep(m_fromOffset > 0 ? 0 : fromT - 1).WithLayout(fromNode->GetMBLayout());
}

// perform op on all sequences that get boundary frames filled in a range that intersects with our output range
template<class OpFn>
void ForAllBoundaryIntersectingSequences(const FrameRange & fr, const SliceBounds & outSlice, size_t T, const OpFn & opFn)
{
if (fr.IsAllFrames() || GetMBLayout()->IsBeyondStartOrEnd(fr.WithTimeOffset(m_fromOffset))) // short-cut test whether there is anything to do
{
auto ts = outSlice.first[m_shiftDim];
auto te = outSlice.second[m_shiftDim];
// iterate over all sequences in this batch and handle all that overlap with the target region
for (auto toSeqInfo : GetMBLayout()->GetAllSequences())
{
// reduce to boundary frames
if (m_fromOffset < 0)
toSeqInfo.tEnd = min(toSeqInfo.tEnd, (size_t)max(toSeqInfo.tBegin - m_fromOffset, 0));
else
toSeqInfo.tBegin = max(toSeqInfo.tBegin, (int) (toSeqInfo.tEnd - m_fromOffset));

// if no overlap then skip
if (toSeqInfo.tEnd <= ts || toSeqInfo.tBegin >= te)
continue;

// clip sequence to [ts,te)
if (toSeqInfo.tBegin < ts)
toSeqInfo.tBegin = ts;
if (toSeqInfo.tEnd > te)
toSeqInfo.tEnd = te;

// action to perform
opFn(toSeqInfo);
}
}
}

// perform the copy (forward) or add (backprop) operation
void Propagate(const ComputationNodePtr & fromNode, TensorShape fromShape, const FrameRange & frFrom, TensorShape toShape, const FrameRange & frTo, bool isForward)
void Propagate(const ComputationNodePtr & fromNode, TensorShape fromShape, const FrameRange & frFrom, TensorShape toShape, const FrameRange & frTo, bool isForward, ElemType backwardSign)
{
auto fromSlice = TensorSliceWithMBLayoutFor(ToIntDims(fromShape), frFrom, fromNode->GetMBLayout());
auto toSlice = TensorSliceWithMBLayoutFor(ToIntDims(toShape), frTo, GetMBLayout());
@@ -318,7 +360,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
{
auto from = TensorView<ElemType>(fromNode->Gradient(), fromShape);
auto to = TensorView<ElemType>( Gradient(), toShape);
from.AddCopyOf(to);
from.AddCopyOf(to, backwardSign); // sign = -1 to subtract
}
}

@@ -341,29 +383,15 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// if iterating in time, we must pay attention to sequence boundaries inside the batch
if (isTimeIteration)
{
if (fr.IsAllFrames() || GetMBLayout()->IsBeyondStartOrEnd(fr.WithTimeOffset(m_fromOffset))) // short-cut test whether there is anything to do
ForAllBoundaryIntersectingSequences(fr, outSliceLogical, T, [&](const MBLayout::SequenceInfo & toSeqInfo)
{
auto ts = outSliceLogical.first[m_shiftDim];
auto te = outSliceLogical.second[m_shiftDim];
// iterate over all sequences in this batch and handle all that overlap with the target region
for (auto toSeqInfo : GetMBLayout()->GetAllSequences())
{
// clip sequence to [ts,te)
if (toSeqInfo.tEnd <= ts || toSeqInfo.tBegin >= te) // no overlap--skip
continue;

// get bounds
if (toSeqInfo.tBegin < ts)
toSeqInfo.tBegin = ts;
if (toSeqInfo.tEnd > te)
toSeqInfo.tEnd = te;
FrameRange frFrom, frTo;
DetermineBoundaryFrameRanges(fr, toSeqInfo, fromNode, frFrom, T, frTo);

// copy/backprop
Propagate(fromNode, fromShape, frFrom, outShape, frTo, isForward);
}
}
// determine FrameRanges for from and to
FrameRange frFrom, frTo;
DetermineBoundaryFrameRanges(fr, toSeqInfo, fromNode, frFrom, T, frTo);

// copy/backprop
Propagate(fromNode, fromShape, frFrom, outShape, frTo, isForward, +1);
});
}
// iterating over fixed sample-shape dimensions
else if (!isTimeIteration && (inSliceLogical.first[m_shiftDim] < 0 || inSliceLogical.second[m_shiftDim] >= T))
@@ -374,7 +402,7 @@ namespace Microsoft { namespace MSR { namespace CNTK {
DetermineBoundaryFrameRanges(fr, fromNode, fromT, frFrom, T, frTo);

// copy/backprop
Propagate(fromNode, fromShape, frFrom, outShape, frTo, isForward);
Propagate(fromNode, fromShape, frFrom, outShape, frTo, isForward, +1);
LogicError("This code path has never been tested."); // remove this once we have
}
}
@@ -383,6 +411,8 @@ namespace Microsoft { namespace MSR { namespace CNTK {

virtual void ForwardProp(const FrameRange & fr) override
{
//for (size_t xx = 0; xx < 3; xx++) // for testing the strange slow-down
{
if (fr.GetIterationDimension() != m_shiftDimParam)
LogicError("ShiftNode::ForwardProp(): FrameRange not iterating over user-specified dimension.");

@@ -431,20 +461,13 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// The above may already have written (wrong) values in there, or not written anything at all yet.

PropagateBoundaryFrames(fr, rank, inSliceLogical, outShape, outSliceLogical, /*isForward=*/true);
}
}

virtual void /*ComputationNode::*/BackpropTo(const size_t inputIndex, const FrameRange & fr) override
{
// Note: To allow for bulk gradient computation, this function will clear out any gradient that should not be propagated.
// We do that directly to our incoming output gradient, similar to gap masking but in the non-gap areas.
// This is OK because we own this, and it is never read except by ourselves after we get in here.
// NAW, NOT WORKING! We need to propagate those gaps into the boundary node:
// - Backprop will be called as a bulk op (since outside the recurrent loop), but only few frames must be propagated, affecting one column only.
// - Backprop into input is still called frame by frame.
// - If outside recurrent loop, backprop into input is done in bulk, and it is unlikely that we also have a boundary node to propagate into.
// - This will be fixed later; for now, we simply don't support boundary nodes.
// - Or we first bulk-add, and then subtract it out again...

// if (!fr.IsAllFrames()) // for measuring speed
// return;
TensorShape inShape, outShape; // expanded tensor shapes of input and output
SliceBounds inSliceLogical, outSliceLogical; // the logical ranges to shift
size_t rank = DetermineElementwiseTensorRank();
@@ -461,43 +484,47 @@ namespace Microsoft { namespace MSR { namespace CNTK {
// propagate into input
else if (inputIndex == 0)
{
// STEP 1a: backprop all we got, including invalid ones. We later subtract them again. Not so nice actually, but saves memory.
// get the logical ranges we want to shift
// now copy the two stripes--one that is main-to-main, and one that pulls in data from previous state (truncated BPTT only)
// This correctly handles if input is a tensor with strides. This is currently not the case, but may be if we support in-place.

// STEP 1a: backprop all we got, including invalid ones. Inner boundary frames that we shouldn't have propagated, we later subtract again.
SliceBounds inSliceMain, outSliceMain; // main-to-main
SliceBounds inSliceState, outSliceState; // from state
SliceBounds inSliceState, outSliceState; // from state --dummy
auto T = outShape[m_shiftDim]; // upper bound of iteration dimension
PartitionSlices(inSliceLogical, outSliceLogical, T, inSliceMain, outSliceMain, inSliceState, outSliceState);

if (inSliceMain.second[m_shiftDim] > inSliceMain.first[m_shiftDim])
{
Input(0)->MaskMissingGradientColumnsToZero(fr); // zero out gaps, which will leak (note: we really only need to zero out gaps close enough to boundaries)
auto from = DataTensorFor(Input(0)->Gradient(), inShape, inSliceMain);
auto from = DataTensorFor(Input(0)->Gradient(), inShape, inSliceMain);
auto to = DataTensorFor( Gradient(), outShape, outSliceMain);
from.AddCopyOf(to);
}
// We have now propagated anything from within the logical bounds.
// In the case of packing we will have propagated incorrectly propagated across boundaries.
// Maybe a better way would be to copy around the frames that we should not copy.

// STEP 1b: fix up the frames that we incorrectly propagated
// Only happens for time iterations, only at inner boundaries.
bool isTimeIteration = m_shiftDim >= rank;
if (isTimeIteration)
{
if (fr.IsAllFrames() || GetMBLayout()->IsBeyondStartOrEnd(fr.WithTimeOffset(m_fromOffset))) // short-cut test whether there is anything to do

// We have now propagated anything from within the logical bounds.
// In the case of packing we will have propagated incorrectly propagated across boundaries.
// We will now subtract the incorrectly leaked gradient frames out again.
// (We also propagated from gaps, but those have already been reset to 0, so those require no correction.)
// E.g. shifting by -1
// |X X X X X|Y Y Y|G G G output gradient
// |X X X X|Y Y Y|G G G Input(0) gradient
// ^ incorrect leak: must subtract out
// ^ ^ no need to correct since already 0
// |<----------------->| output gradient range we must consider = outSliceMain
// (Maybe a better way would be to copy around the frames that we should not copy.)

// STEP 1b: fix up the frames that we incorrectly propagated
// Only happens for time iterations, only at inner boundaries.
bool isTimeIteration = m_shiftDim >= rank;
if (isTimeIteration)
{
auto ts = outSliceLogical.first[m_shiftDim];
auto te = outSliceLogical.second[m_shiftDim];
// iterate over all sequences in this batch and handle all that overlap with the target region
for (auto toSeqInfo : GetMBLayout()->GetAllSequences())
ForAllBoundaryIntersectingSequences(fr, outSliceMain/*already clipped*/, T, [&](const MBLayout::SequenceInfo & toSeqInfo)
{
if (toSeqInfo.tEnd <= ts || toSeqInfo.tBegin >= te) // no overlap--skip
continue;
sin(1);
}
// determine FrameRanges for from and to
FrameRange frTo;
DetermineBoundaryToFrameRange(fr, toSeqInfo, T, frTo);
FrameRange frFrom = frTo.WithTimeOffset(m_fromOffset);
assert((int)frFrom.timeIdxInSeq + frFrom.m_timeOffset >= 0 && (int)frFrom.timeIdxInSeq + frFrom.m_timeOffset + (int)frFrom.m_timeRange <= (int)T);

// copy/backprop
Propagate(shared_from_this(), inShape, frFrom, outShape, frTo, /*isForward=*/false, -1/*subtract*/);
});
}
}
}
4 changes: 2 additions & 2 deletions Tests/EndToEndTests/Speech/LSTM/cntk.config
Original file line number Diff line number Diff line change
@@ -68,8 +68,8 @@ speechTrain = [
// LSTM cell
# TODO: This is temporary test code for the new ShiftNode (until we switch PastValue() itself over)
PastValueShift(dimDummy, input) = Shift(input, /*fromOffsets=*/-1, /*boundaryValue=*/Constant(0.1), dim=-1)
#PastValue1 = PastValue
PastValue1 = PastValueShift
PastValue1 = PastValue
#PastValue1 = PastValueShift
dh = PastValue1(outputDim, output); // hidden state(t-1)
dc = PastValue1(cellDim, ct); // cell(t-1)

0 comments on commit 267be9e

Please sign in to comment.