Skip to content

Commit 1644312

Browse files
authoredMar 17, 2018
[src] Add a nnet3 optimization that tries to replace commands ending in Multi with other commands. (kaldi-asr#2229)
1 parent 8ab6e53 commit 1644312

6 files changed

+379
-29
lines changed
 

‎src/nnet3/nnet-computation.h

+9-4
Original file line numberDiff line numberDiff line change
@@ -395,12 +395,17 @@ struct NnetComputation {
395395
// These are owned here.
396396
std::vector<PrecomputedIndexesInfo> component_precomputed_indexes;
397397

398-
// used in kAddRows, kAddToRows, kCopyRows, kCopyToRows. contains row-indexes.
398+
// Used in commands kAddRows, kAddToRows, kCopyRows, which
399+
// contain indexes into this data-member.
400+
// Each vector<int32> is a vector of row-indexes (with -1 usually treated as
401+
// a special case meaning "don't do anything for this row" for add
402+
// commands, or "use zero" for copy commands).
399403
std::vector<std::vector<int32> > indexes;
400404

401-
// used kAddRowsMulti, kAddToRowsMulti, kCopyRowsMulti, kCopyToRowsMulti.
402-
// contains pairs (sub-matrix index, row index)- or (-1,-1) meaning don't
403-
// do anything for this row.
405+
// Used in commands kAddRowsMulti, kAddToRowsMulti, kCopyRowsMulti and
406+
// kCopyToRowsMulti. Contains pairs (sub-matrix index, row index)- or the
407+
// special pair (-1,-1) meaning "don't do anything for this row" for add
408+
// commands, or "use zero" for copy commands.
404409
std::vector<std::vector<std::pair<int32,int32> > > indexes_multi;
405410

406411

‎src/nnet3/nnet-optimize-utils.cc

+322
Original file line numberDiff line numberDiff line change
@@ -2576,6 +2576,328 @@ bool SnipRowOps(NnetComputation *computation) {
25762576

25772577

25782578

2579+
// This class implements the internals of the function SplitRowOps() which is
2580+
// declared in nnet-optimize-utils.h.
2581+
class RowOpsSplitter {
2582+
public:
2583+
RowOpsSplitter(NnetComputation *computation): computation_(computation) { }
2584+
2585+
// Attempts to perform the optimization. Returns true if it made any change
2586+
// to the computation.
2587+
bool Split() {
2588+
return SplitIndexes() && SplitCommands();
2589+
}
2590+
2591+
private:
2592+
2593+
// This function sets up split_info_, which describes how we can split up
2594+
// the vectors that are elements of computation_->indexes_multi.
2595+
// It will return true if it successfully split at least one of those
2596+
// vectors, and false otherwise.
2597+
bool SplitIndexes();
2598+
2599+
// This function modifies the commands in the computation. It returns
2600+
// true if it made any change.
2601+
bool SplitCommands();
2602+
2603+
2604+
// This function attempts to optimize the command in
2605+
// computation_->commands[command_index]. It returns true if it made any
2606+
// change. If we are going to have to insert an extra command into the
2607+
// computation, this function will append an element to new_commands_.
2608+
bool SplitCommand(int32 command_index);
2609+
2610+
// Below, define a multi-index as an element of NnetComputation::indexes_multi,
2611+
// for example,
2612+
// const std::vector<std::pair<int32,int32> > &multi_index = computation_->indexes_multi[1];
2613+
// It is a list of pairs.
2614+
2615+
// This struct appears as an element of the list inside MultiIndexSplitInfo.
2616+
// It helps us describe how we can split up a multi-index (a list of pairs)
2617+
// into a sequence of ranges where the .first value is constant across the
2618+
// range.
2619+
struct SingleSplitInfo {
2620+
// 'offset' is the index into the vector of pairs that forms the
2621+
// start of this range. In the example where we are splitting up
2622+
// ((10,2), (10,3), (10,4), (15,3), (15,5), (15,7))
2623+
// there would be two instances of struct SingleSplitInfo, with
2624+
// offset = 0 and offset = 3.
2625+
int32 offset;
2626+
// 'size' is the number of pairs in this range; in the example
2627+
// above, both 'size' elements would be 3.
2628+
int32 size;
2629+
// first_value is the value of the .first index throughout this range; in
2630+
// the example above, it would be 10 and 15 respectively. It represents a
2631+
// submatrix index.
2632+
int32 first_value;
2633+
2634+
// initial_second_value is the minimum value of .second for any element in
2635+
// this range: it would be 2 and 3 respectively in the example above.
2636+
int32 min_second_value;
2637+
2638+
// second_value_range is the highest value of .second for any element in
2639+
// this range, plus one, minus min_second_value. (It's the number of rows
2640+
// in the other submatrix of the operation).
2641+
int32 second_value_range;
2642+
2643+
// If the .second values in the range are consecutive then
2644+
// 'second_value_offsets' will be empty. Otherwise it will
2645+
// be a vector of size 'size', containing numbers in the
2646+
// range 0 ... second_value_range - 1, such that
2647+
// min_second_value + second_value_offsets[i] gives
2648+
// the .second value at the corresponding position in the range.
2649+
// In the second range of the example above, the range
2650+
// consisting of ((15,3), (15,5), (15,7)), 'second_value_offsets
2651+
// would be the vector (0, 2, 4).
2652+
std::vector<int32> second_value_offsets;
2653+
};
2654+
2655+
// An instance of the struct MultiIndexSplitInfo will be created for each multi-index,
2656+
// i.e. for each element of NnetComputation::indexes_multi.
2657+
struct MultiIndexSplitInfo {
2658+
// If we can split this multi-index into at most two ranges, this
2659+
// vector will be nonempty; otherwise it will be empty.
2660+
std::vector<SingleSplitInfo> splits;
2661+
};
2662+
2663+
// GetSplitInfo() attempts to take a range of a
2664+
// std::vector<std::pair<int32, int32> >, as represented by begin and end
2665+
// iterators, and to extract its information into an object of type
2666+
// SingleSplitInfo. (all except for the .offset member, which will have
2667+
// been set by calling code).
2668+
// It return true if successful, and false otherwise. The only reasons that
2669+
// it might return false are that the range contains -1's or does not contain
2670+
// all-identical .first members).
2671+
bool GetSplitInfo(std::vector<std::pair<int32, int32> >::const_iterator begin,
2672+
std::vector<std::pair<int32, int32> >::const_iterator end,
2673+
SingleSplitInfo *info);
2674+
2675+
// computation_ is the computation that we are modifying.
2676+
NnetComputation *computation_;
2677+
// split_info_ will contain information about how we can split up the members
2678+
// of computation_->indexes_multi into ranges.
2679+
std::vector<MultiIndexSplitInfo> split_info_;
2680+
// The following is a list of additional commands that we are going to insert
2681+
// into computation_, of the form (command-index, command) where command-index
2682+
// is a command index just before which we will insert the new command.
2683+
// (this is the format accepted by the function InsertCommands()).
2684+
std::vector<std::pair<int32, NnetComputation::Command> > new_commands_;
2685+
2686+
};
2687+
2688+
2689+
bool RowOpsSplitter::GetSplitInfo(
2690+
std::vector<std::pair<int32, int32> >::const_iterator begin,
2691+
std::vector<std::pair<int32, int32> >::const_iterator end,
2692+
SingleSplitInfo *info) {
2693+
// max_size_ratio must be > 1.0, and could in principle be a float. It is
2694+
// there to prevent us from making changes to the computation which would end
2695+
// up wastefully launching too many kernels that would do nothing.
2696+
const int32 max_size_ratio = 2;
2697+
2698+
int32 size = end - begin;
2699+
KALDI_ASSERT(size != 0);
2700+
int32 first = begin->first;
2701+
if (first < 0)
2702+
return false;
2703+
info->size = size;
2704+
info->first_value = first;
2705+
int32 initial_second_value = begin->second,
2706+
min_second_value = initial_second_value,
2707+
max_second_value = initial_second_value;
2708+
info->second_value_offsets.resize(size);
2709+
bool is_consecutive = true;
2710+
for (int32 i = 0; i < size; i++) {
2711+
int32 second = begin[i].second;
2712+
if (begin[i].first != first || second < 0) return false;
2713+
info->second_value_offsets[i] = second;
2714+
if (second != initial_second_value + i)
2715+
is_consecutive = false;
2716+
if (second < min_second_value) min_second_value = second;
2717+
if (second > max_second_value) max_second_value = second;
2718+
}
2719+
info->min_second_value = min_second_value;
2720+
info->second_value_range = max_second_value + 1 - min_second_value;
2721+
if (info->second_value_range > size * max_size_ratio)
2722+
return false;
2723+
if (is_consecutive) {
2724+
info->second_value_offsets.clear();
2725+
} else {
2726+
for (int32 i = 0; i < size; i++)
2727+
info->second_value_offsets[i] -= min_second_value;
2728+
}
2729+
return true;
2730+
}
2731+
2732+
2733+
bool RowOpsSplitter::SplitIndexes() {
2734+
bool ans = false;
2735+
int32 num_indexes_multi = computation_->indexes_multi.size();
2736+
split_info_.resize(num_indexes_multi);
2737+
for (int32 i = 0; i < num_indexes_multi; i++) {
2738+
const std::vector<std::pair<int32,int32> > &multi_index =
2739+
computation_->indexes_multi[i];
2740+
MultiIndexSplitInfo &split_info = split_info_[i];
2741+
2742+
int32 num_pairs = multi_index.size();
2743+
KALDI_ASSERT(num_pairs > 0);
2744+
// 'split_point' will be set to the first index j for which
2745+
// multi_index[j-1].first != multi_index[j].first, or -1
2746+
// if no such j exists.
2747+
int32 split_point = -1, initial_first = multi_index[0].first;
2748+
for (int32 j = 1; j < num_pairs; j++) {
2749+
if (multi_index[j].first != initial_first) {
2750+
split_point = j;
2751+
break;
2752+
}
2753+
}
2754+
if (split_point == -1) {
2755+
split_info.splits.resize(1);
2756+
split_info.splits[0].offset = 0;
2757+
if (!GetSplitInfo(multi_index.begin(), multi_index.end(),
2758+
&(split_info.splits[0]))) {
2759+
split_info.splits.clear();
2760+
} else {
2761+
ans = true;
2762+
}
2763+
} else {
2764+
split_info.splits.resize(2);
2765+
split_info.splits[0].offset = 0;
2766+
split_info.splits[1].offset = split_point;
2767+
2768+
std::vector<std::pair<int32,int32> >::const_iterator mid_iter =
2769+
multi_index.begin() + split_point;
2770+
if (!GetSplitInfo(multi_index.begin(), mid_iter,
2771+
&(split_info.splits[0])) ||
2772+
!GetSplitInfo(mid_iter, multi_index.end(),
2773+
&(split_info.splits[1]))) {
2774+
split_info.splits.clear();
2775+
} else {
2776+
ans = true;
2777+
}
2778+
}
2779+
}
2780+
return ans;
2781+
}
2782+
2783+
// Attempts to replace the command computation_->commands[c] with one or two
// simpler commands, using the information in split_info_.
//
// Only commands of type kAddRowsMulti, kCopyRowsMulti, kAddToRowsMulti and
// kCopyToRowsMulti are considered.  Each range of the command's multi-index
// becomes:
//   - kMatrixAdd / kMatrixCopy if the row indexes in the range are
//     consecutive;
//   - kAddRows / kCopyRows (with a new element appended to
//     computation_->indexes) otherwise.
// The first resulting command overwrites commands[c]; a second one, if any,
// is appended to new_commands_ for insertion just after command c.
//
// Returns true if the command was changed, and false (leaving the
// computation untouched) otherwise.
bool RowOpsSplitter::SplitCommand(int32 c) {
  NnetComputation::Command &command = computation_->commands[c];
  CommandType command_type = command.command_type;
  // For commands that are not of the following four types, return false: we
  // won't be changing these commands.
  switch (command_type) {
    case kAddRowsMulti: case kCopyRowsMulti:
    case kAddToRowsMulti: case kCopyToRowsMulti: break;
    default: return false;
  }
  int32 indexes_multi_index = command.arg2;
  KALDI_ASSERT(indexes_multi_index >= 0 &&
               indexes_multi_index <
               static_cast<int32>(split_info_.size()));
  const MultiIndexSplitInfo &split_info = split_info_[indexes_multi_index];
  if (split_info.splits.empty())
    return false;  // these indexes couldn't be split: e.g. they contained more
                   // than two distinct .first elements, or there were other
                   // reasons.

  if (command_type == kCopyToRowsMulti) {
    // A kCopyToRowsMulti command can only be converted if every range has
    // consecutive rows (so that it becomes a plain kMatrixCopy).  For
    // non-consecutive rows we would have to invert the indexes, as is done
    // for kAddToRowsMulti below, and the inverted indexes may contain -1's;
    // in a kCopyRows command those would cause rows of the output to be set
    // to zero, whereas we need them to be left unaffected.
    // We check this before modifying anything, so that returning false really
    // does mean that the computation is unchanged.
    for (size_t i = 0; i < split_info.splits.size(); i++)
      if (!split_info.splits[i].second_value_offsets.empty())
        return false;
  }

  // we'll be splitting the command into either one or two pieces.
  std::vector<NnetComputation::Command> split_commands(
      split_info.splits.size());
  for (size_t i = 0; i < split_info.splits.size(); i++) {
    const SingleSplitInfo &split = split_info.splits[i];
    NnetComputation::Command &command_out = split_commands[i];
    command_out.alpha = command.alpha;
    // arg1: the rows of the original command's sub-matrix that this range
    // covers.
    command_out.arg1 = computation_->NewSubMatrix(
        command.arg1, split.offset, split.size, 0, -1);
    // arg2: the rows of the sub-matrix named in the multi-index that this
    // range refers to.
    command_out.arg2 = computation_->NewSubMatrix(
        split.first_value, split.min_second_value,
        split.second_value_range, 0, -1);

    if (split.second_value_offsets.empty()) {
      // The .second elements are consecutive.
      switch (command_type) {
        case kAddRowsMulti:
          command_out.command_type = kMatrixAdd;
          break;
        case kCopyRowsMulti:
          command_out.command_type = kMatrixCopy;
          break;
        case kAddToRowsMulti:
          // The *To* commands write in the opposite direction, so the roles
          // of the two sub-matrices are exchanged.
          command_out.command_type = kMatrixAdd;
          std::swap(command_out.arg1, command_out.arg2);
          break;
        case kCopyToRowsMulti:
          command_out.command_type = kMatrixCopy;
          std::swap(command_out.arg1, command_out.arg2);
          break;
        default:  // will never be reached.
          break;
      }
    } else {
      // Indexes are not consecutive: it needs to be a kAddRows or kCopyRows
      // command.
      command_out.arg3 = computation_->indexes.size();
      switch (command_type) {
        case kAddRowsMulti: case kCopyRowsMulti: {
          command_out.command_type = (command_type == kAddRowsMulti ?
                                      kAddRows : kCopyRows);
          computation_->indexes.push_back(split.second_value_offsets);
          break;
        }
        case kAddToRowsMulti: {
          command_out.command_type = kAddRows;
          std::swap(command_out.arg1, command_out.arg2);
          // invert the indexes: for each destination row, record which
          // source row is to be added to it (-1 means none).
          std::vector<int32> indexes(split.second_value_range, -1);
          for (int32 i = 0; i < split.size; i++) {
            // The slot must still be unset (-1) at this point; the following
            // assert should always succeed because the AddToRowsMulti and
            // CopyToRowsMulti should never have duplicate destinations in
            // their indexes.
            KALDI_ASSERT(indexes[split.second_value_offsets[i]] < 0);
            indexes[split.second_value_offsets[i]] = i;
          }
          computation_->indexes.push_back(indexes);
          break;
        }
        default:
          // kCopyToRowsMulti with non-consecutive rows was rejected before
          // the loop, so no valid command type can get here.
          KALDI_ERR << "Code error: un-handled case.";
      }
    }
  }
  command = split_commands[0];
  // note: for now, split_commands.size() will be 1 or 2.
  for (size_t i = 1; i < split_commands.size(); i++) {
    new_commands_.resize(new_commands_.size() + 1);
    // we'll want to insert this command right after command c,
    // which is the same as just before command c + 1.
    new_commands_.back().first = c + 1;
    new_commands_.back().second = split_commands[i];
  }
  return true;  // We made a change.
}
2885+
2886+
bool RowOpsSplitter::SplitCommands() {
2887+
bool ans = false;
2888+
int32 num_commands = computation_->commands.size();
2889+
for (int32 c = 0; c < num_commands; c++)
2890+
if (SplitCommand(c))
2891+
ans = true;
2892+
if (!new_commands_.empty())
2893+
InsertCommands(&new_commands_, computation_);
2894+
return ans;
2895+
}
2896+
2897+
// Public entry point for the optimization: hands 'computation' to a
// RowOpsSplitter and returns whether it changed anything.
bool SplitRowOps(NnetComputation *computation) {
  return RowOpsSplitter(computation).Split();
}
25792901

25802902

25812903
/*

‎src/nnet3/nnet-optimize-utils.h

+29-12
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,23 @@ bool ReplaceRowWithMatrixOps(NnetComputation *computation);
455455
/// computation->indexes.
456456
bool SnipRowOps(NnetComputation *computation);
457457

458+
459+
/// This function detects cases where commands of type kAddRowsMulti,
460+
/// kAddToRowsMulti, kCopyRowsMulti, kCopyToRowsMulti use indexes that
461+
/// correspond to at most two submatrices, in two distinct ranges without gaps
462+
/// filled by -1's, and could be converted to at most two commands of type
463+
/// kMatrixAdd, kMatrixCopy, kAddRows or kCopyRows. (Note: it's important that
464+
/// this optimization takes place after SnipRowOps, because it doesn't remove
465+
/// the -1's from the edges of the indexes, it relies on that operation doing
466+
/// so). The "without-gaps" stipulation is just for convenience of
467+
/// implementation, to have fewer cases to worry about.
468+
///
469+
/// This function returns true if it made any changes to the computation; if it
470+
/// returns true, then after calling this you should at some point do
471+
/// RenumberComputation(), which will remove any now-unused members of
472+
/// computation->indexes.
473+
bool SplitRowOps(NnetComputation *computation);
474+
458475
/// This function detects submatrices and matrices that are never used (e.g. due
459476
/// to changes made in other optimization code), and members of indexes,
460477
/// indexes_multi and indexes_ranges that are unused or are duplicates, and memo
@@ -535,18 +552,18 @@ void IdentifyIndexesRangesArgs(std::vector<NnetComputation::Command> *commands,
535552
std::vector<int32*> *indexes_ranges_args);
536553

537554
/// Inserts commands into the computation at the requested places. 'commands'
538-
/// is a list of pairs (command-index, command) that is expected to be sorted
539-
/// on command-index. For each entry (c, command) in 'commands', 'command' is
540-
/// inserted into 'computation' just *before* the command that (at entry) is in
541-
/// computation->commands[c]. If there are multiple pairs with the same index
542-
/// c, they will remain in the same order in which they were present in
543-
/// 'commands'; however, 'commands' does not have to be sorted on 'c'.
544-
/// As a special case, if c == computation->commands.size(), the
545-
/// corresponding commands are inserted at the beginning of the computation.
546-
/// This function will appropriately renumber the argument of the kGotoLabel
547-
/// command of any 'looped' computation. Command indexes c in commands[*].first
548-
/// must be in the range [0, computation->commands.size()].
549-
/// This function may modify 'commands' by sorting it.
555+
/// is a list of pairs (command-index, command) that is expected to be sorted on
556+
/// command-index. For each entry (c, command) in 'commands', 'command' is
557+
/// inserted into 'computation' just *before* the command that (at entry) is in
558+
/// computation->commands[c]. If there are multiple pairs with the same index
559+
/// c, they will remain in the same order in which they were present in
560+
/// 'commands'; however, 'commands' does not have to be sorted on 'c'. As a
561+
/// special case, if c == computation->commands.size(), the corresponding
562+
/// commands are inserted at the end of the computation.  This function
563+
/// will appropriately renumber the argument of the kGotoLabel command of any
564+
/// 'looped' computation. Command indexes c in commands[*].first must be in the
565+
/// range [0, computation->commands.size()]. This function may modify
566+
/// 'commands' by sorting it.
550567
void InsertCommands(
551568
std::vector<std::pair<int32, NnetComputation::Command> > *commands,
552569
NnetComputation *computation);

‎src/nnet3/nnet-optimize.cc

+13-1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ void NnetOptimizeOptions::Read(std::istream &is, bool binary) {
4141
if (tok == "<OptimizeRowOps>") {
4242
ReadBasicType(is, binary, &optimize_row_ops);
4343
ReadToken(is, binary, &tok);
44+
} else {
45+
optimize_row_ops = true;
46+
}
47+
if (tok == "<SplitRowOps>") {
48+
ReadBasicType(is, binary, &split_row_ops);
49+
ReadToken(is, binary, &tok);
50+
} else {
51+
split_row_ops = true;
4452
}
4553
KALDI_ASSERT(tok == "<ConvertAddition>");
4654
ReadBasicType(is, binary, &convert_addition);
@@ -516,12 +524,16 @@ void Optimize(const NnetOptimizeOptions &config,
516524
}
517525

518526

519-
if (config.optimize && (config.snip_row_ops || config.optimize_row_ops)) {
527+
if (config.optimize && (config.snip_row_ops || config.optimize_row_ops ||
528+
config.split_row_ops)) {
520529
bool must_renumber = false;
521530
if (config.snip_row_ops && SnipRowOps(computation))
522531
must_renumber = true;
532+
if (config.split_row_ops && SplitRowOps(computation))
533+
must_renumber = true;
523534
if (config.optimize_row_ops && ReplaceRowWithMatrixOps(computation))
524535
must_renumber = true;
536+
525537
if (must_renumber) {
526538
RenumberComputation(computation);
527539
if (GetVerboseLevel() >= 3)

‎src/nnet3/nnet-optimize.h

+6
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ struct NnetOptimizeOptions {
3939
bool propagate_in_place;
4040
bool backprop_in_place;
4141
bool optimize_row_ops;
42+
bool split_row_ops;
4243
bool extend_matrices;
4344
bool convert_addition;
4445
bool remove_assignments;
@@ -63,6 +64,7 @@ struct NnetOptimizeOptions {
6364
propagate_in_place(true),
6465
backprop_in_place(true),
6566
optimize_row_ops(true),
67+
split_row_ops(true),
6668
extend_matrices(true),
6769
convert_addition(true),
6870
remove_assignments(true),
@@ -95,6 +97,10 @@ struct NnetOptimizeOptions {
9597
opts->Register("optimize-row-ops", &optimize_row_ops, "Set to false to "
9698
"disable certain optimizations that act on operations of "
9799
"type *Row*.");
100+
opts->Register("split-row-ops", &split_row_ops, "Set to false to disable "
101+
"an optimization that may replace some operations of type "
102+
"kCopyRowsMulti or kAddRowsMulti with up to two simpler "
103+
"operations.");
98104
opts->Register("convert-addition", &convert_addition, "Set to false to "
99105
"disable the optimization that converts Add commands into "
100106
"Copy commands wherever possible.");

‎src/nnet3/nnet-utils.cc

-12
Original file line numberDiff line numberDiff line change
@@ -1867,19 +1867,7 @@ class ModelCollapser {
18671867
void CollapseModel(const CollapseModelConfig &config,
18681868
Nnet *nnet) {
18691869
ModelCollapser c(config, nnet);
1870-
std::string info_before_collapse;
1871-
if (GetVerboseLevel() >= 4)
1872-
info_before_collapse = nnet->Info();
18731870
c.Collapse();
1874-
if (GetVerboseLevel() >= 4) {
1875-
std::string info_after_collapse = nnet->Info();
1876-
if (info_after_collapse != info_before_collapse) {
1877-
KALDI_VLOG(4) << "Collapsing model: info before collapse was: "
1878-
<< info_before_collapse
1879-
<< ", info after collapse was:"
1880-
<< info_after_collapse;
1881-
}
1882-
}
18831871
}
18841872

18851873
bool UpdateNnetWithMaxChange(const Nnet &delta_nnet,

0 commit comments

Comments
 (0)
Please sign in to comment.