@@ -2576,6 +2576,328 @@ bool SnipRowOps(NnetComputation *computation) {
2576
2576
2577
2577
2578
2578
2579
+ // This class implements the internals of the function SplitRowOps() which is
2580
+ // declared in nnet-optimize-utils.h.
2581
+ class RowOpsSplitter {
2582
+ public:
2583
+ RowOpsSplitter (NnetComputation *computation): computation_(computation) { }
2584
+
2585
+ // Attempts to perform the optimization. Returns true if it made any change
2586
+ // to the computation.
2587
+ bool Split () {
2588
+ return SplitIndexes () && SplitCommands ();
2589
+ }
2590
+
2591
+ private:
2592
+
2593
+ // This function sets up split_info_, which describes how we can split up
2594
+ // the vectors that are elements of computation_->indexes_multi.
2595
+ // It will return true if it successfully split at least one of those
2596
+ // vectors, and false otherwise.
2597
+ bool SplitIndexes ();
2598
+
2599
+ // This function modifies the commands in the computation. It returns
2600
+ // true if it made any change.
2601
+ bool SplitCommands ();
2602
+
2603
+
2604
+ // This function attempts to optimize the command in
2605
+ // computation_->commands[command_index]. It returns true if it made any
2606
+ // change. If we are going to have to insert an extra command into the
2607
+ // computation, this function will append an element to new_commands_.
2608
+ bool SplitCommand (int32 command_index);
2609
+
2610
+ // Below, define a multi-index as an element of NnetComputation::indexes_multi,
2611
+ // for example,
2612
+ // const std::vector<std::pair<int32,int32> > &multi_index = computation_->indexes_multi[1];
2613
+ // It is a list of pairs.
2614
+
2615
+ // This struct appears as an element of the list inside MultiIndexSplitInfo.
2616
+ // It helps us describe how we can split up a multi-index (a list of pairs)
2617
+ // into a sequence of ranges where the .first value is constant across the
2618
+ // range.
2619
+ struct SingleSplitInfo {
2620
+ // 'offset' is the index into the vector of pairs that forms the
2621
+ // start of this range. In the example where we are splitting up
2622
+ // ((10,2), (10,3), (10,4), (15,3), (15,5), (15,7))
2623
+ // there would be two instances of struct SingleSplitInfo, with
2624
+ // offset = 0 and offset = 3.
2625
+ int32 offset;
2626
+ // 'size' is the number of pairs in this range; in the example
2627
+ // above, both 'size' elements would be 3.
2628
+ int32 size;
2629
+ // first_value is the value of the .first index throughout this range; in
2630
+ // the example above, it would be 10 and 15 respectively. It represents a
2631
+ // submatrix index.
2632
+ int32 first_value;
2633
+
2634
+ // initial_second_value is the minimum value of .second for any element in
2635
+ // this range: it would be 2 and 3 respectively in the example above.
2636
+ int32 min_second_value;
2637
+
2638
+ // second_value_range is the highest value of .second for any element in
2639
+ // this range, plus one, minus min_second_value. (It's the number of rows
2640
+ // in the other submatrix of the operation).
2641
+ int32 second_value_range;
2642
+
2643
+ // If the .second values in the range are consecutive then
2644
+ // 'second_value_offsets' will be empty. Otherwise it will
2645
+ // be a vector of size 'size', containing numbers in the
2646
+ // range 0 ... second_value_range - 1, such that
2647
+ // min_second_value + second_value_offsets[i] gives
2648
+ // the .second value at the corresponding position in the range.
2649
+ // In the second range of the example above, the range
2650
+ // consisting of ((15,3), (15,5), (15,7)), 'second_value_offsets
2651
+ // would be the vector (0, 2, 4).
2652
+ std::vector<int32> second_value_offsets;
2653
+ };
2654
+
2655
+ // An instance of the struct MultiIndexSplitInfo will be created for each multi-index,
2656
+ // i.e. for each element of NnetComputation::indexes_multi.
2657
+ struct MultiIndexSplitInfo {
2658
+ // If we can split this multi-index into at most two ranges, this
2659
+ // vector will be nonempty; otherwise it will be empty.
2660
+ std::vector<SingleSplitInfo> splits;
2661
+ };
2662
+
2663
+ // GetSplitInfo() attempts to take a range of a
2664
+ // std::vector<std::pair<int32, int32> >, as represented by begin and end
2665
+ // iterators, and to extract its information into an object of type
2666
+ // SingleSplitInfo. (all except for the .offset member, which will have
2667
+ // been set by calling code).
2668
+ // It return true if successful, and false otherwise. The only reasons that
2669
+ // it might return false are that the range contains -1's or does not contain
2670
+ // all-identical .first members).
2671
+ bool GetSplitInfo (std::vector<std::pair<int32, int32> >::const_iterator begin,
2672
+ std::vector<std::pair<int32, int32> >::const_iterator end,
2673
+ SingleSplitInfo *info);
2674
+
2675
+ // computation_ is the computation that we are modifying.
2676
+ NnetComputation *computation_;
2677
+ // split_info_ will contain information about how we can split up the members
2678
+ // of computation_->indexes_multi into ranges.
2679
+ std::vector<MultiIndexSplitInfo> split_info_;
2680
+ // The following is a list of additional commands that we are going to insert
2681
+ // into computation_, of the form (command-index, command) where command-index
2682
+ // is a command index just before which we will insert the new command.
2683
+ // (this is the format accepted by the function InsertCommands()).
2684
+ std::vector<std::pair<int32, NnetComputation::Command> > new_commands_;
2685
+
2686
+ };
2687
+
2688
+
2689
+ bool RowOpsSplitter::GetSplitInfo (
2690
+ std::vector<std::pair<int32, int32> >::const_iterator begin,
2691
+ std::vector<std::pair<int32, int32> >::const_iterator end,
2692
+ SingleSplitInfo *info) {
2693
+ // max_size_ratio must be > 1.0, and could in principle be a float. It is
2694
+ // there to prevent us from making changes to the computation which would end
2695
+ // up wastefully launching too many kernels that would do nothing.
2696
+ const int32 max_size_ratio = 2 ;
2697
+
2698
+ int32 size = end - begin;
2699
+ KALDI_ASSERT (size != 0 );
2700
+ int32 first = begin->first ;
2701
+ if (first < 0 )
2702
+ return false ;
2703
+ info->size = size;
2704
+ info->first_value = first;
2705
+ int32 initial_second_value = begin->second ,
2706
+ min_second_value = initial_second_value,
2707
+ max_second_value = initial_second_value;
2708
+ info->second_value_offsets .resize (size);
2709
+ bool is_consecutive = true ;
2710
+ for (int32 i = 0 ; i < size; i++) {
2711
+ int32 second = begin[i].second ;
2712
+ if (begin[i].first != first || second < 0 ) return false ;
2713
+ info->second_value_offsets [i] = second;
2714
+ if (second != initial_second_value + i)
2715
+ is_consecutive = false ;
2716
+ if (second < min_second_value) min_second_value = second;
2717
+ if (second > max_second_value) max_second_value = second;
2718
+ }
2719
+ info->min_second_value = min_second_value;
2720
+ info->second_value_range = max_second_value + 1 - min_second_value;
2721
+ if (info->second_value_range > size * max_size_ratio)
2722
+ return false ;
2723
+ if (is_consecutive) {
2724
+ info->second_value_offsets .clear ();
2725
+ } else {
2726
+ for (int32 i = 0 ; i < size; i++)
2727
+ info->second_value_offsets [i] -= min_second_value;
2728
+ }
2729
+ return true ;
2730
+ }
2731
+
2732
+
2733
+ bool RowOpsSplitter::SplitIndexes () {
2734
+ bool ans = false ;
2735
+ int32 num_indexes_multi = computation_->indexes_multi .size ();
2736
+ split_info_.resize (num_indexes_multi);
2737
+ for (int32 i = 0 ; i < num_indexes_multi; i++) {
2738
+ const std::vector<std::pair<int32,int32> > &multi_index =
2739
+ computation_->indexes_multi [i];
2740
+ MultiIndexSplitInfo &split_info = split_info_[i];
2741
+
2742
+ int32 num_pairs = multi_index.size ();
2743
+ KALDI_ASSERT (num_pairs > 0 );
2744
+ // 'split_point' will be set to the first index j for which
2745
+ // multi_index[j-1].first != multi_index[j].first, or -1
2746
+ // if no such j exists.
2747
+ int32 split_point = -1 , initial_first = multi_index[0 ].first ;
2748
+ for (int32 j = 1 ; j < num_pairs; j++) {
2749
+ if (multi_index[j].first != initial_first) {
2750
+ split_point = j;
2751
+ break ;
2752
+ }
2753
+ }
2754
+ if (split_point == -1 ) {
2755
+ split_info.splits .resize (1 );
2756
+ split_info.splits [0 ].offset = 0 ;
2757
+ if (!GetSplitInfo (multi_index.begin (), multi_index.end (),
2758
+ &(split_info.splits [0 ]))) {
2759
+ split_info.splits .clear ();
2760
+ } else {
2761
+ ans = true ;
2762
+ }
2763
+ } else {
2764
+ split_info.splits .resize (2 );
2765
+ split_info.splits [0 ].offset = 0 ;
2766
+ split_info.splits [1 ].offset = split_point;
2767
+
2768
+ std::vector<std::pair<int32,int32> >::const_iterator mid_iter =
2769
+ multi_index.begin () + split_point;
2770
+ if (!GetSplitInfo (multi_index.begin (), mid_iter,
2771
+ &(split_info.splits [0 ])) ||
2772
+ !GetSplitInfo (mid_iter, multi_index.end (),
2773
+ &(split_info.splits [1 ]))) {
2774
+ split_info.splits .clear ();
2775
+ } else {
2776
+ ans = true ;
2777
+ }
2778
+ }
2779
+ }
2780
+ return ans;
2781
+ }
2782
+
2783
+ bool RowOpsSplitter::SplitCommand (int32 c) {
2784
+ NnetComputation::Command &command = computation_->commands [c];
2785
+ CommandType command_type = command.command_type ;
2786
+ // For commands that are not of the following four types, return false: we
2787
+ // won't be changing these commands.
2788
+ switch (command_type) {
2789
+ case kAddRowsMulti : case kCopyRowsMulti :
2790
+ case kAddToRowsMulti : case kCopyToRowsMulti : break ;
2791
+ default : return false ;
2792
+ }
2793
+ int32 indexes_multi_index = command.arg2 ;
2794
+ KALDI_ASSERT (indexes_multi_index <
2795
+ static_cast <int32>(split_info_.size ()));
2796
+ const MultiIndexSplitInfo &split_info = split_info_[indexes_multi_index];
2797
+ if (split_info.splits .empty ())
2798
+ return false ; // these indexes couldn't be split: e.g. they contained more
2799
+ // than two distinct .first elements, or there were other
2800
+ // reasons.
2801
+
2802
+ // we'll be splitting the command into either one or two pieces.
2803
+ std::vector<NnetComputation::Command> split_commands (
2804
+ split_info.splits .size ());
2805
+ for (size_t i = 0 ; i < split_info.splits .size (); i++) {
2806
+ const SingleSplitInfo &split = split_info.splits [i];
2807
+ NnetComputation::Command &command_out = split_commands[i];
2808
+ command_out.alpha = command.alpha ;
2809
+ command_out.arg1 = computation_->NewSubMatrix (
2810
+ command.arg1 , split.offset , split.size , 0 , -1 );
2811
+ command_out.arg2 = computation_->NewSubMatrix (
2812
+ split.first_value , split.min_second_value ,
2813
+ split.second_value_range , 0 , -1 );
2814
+
2815
+ if (split.second_value_offsets .empty ()) {
2816
+ // The .second elements are consecutive.
2817
+ switch (command_type) {
2818
+ case kAddRowsMulti :
2819
+ command_out.command_type = kMatrixAdd ;
2820
+ break ;
2821
+ case kCopyRowsMulti :
2822
+ command_out.command_type = kMatrixCopy ;
2823
+ break ;
2824
+ case kAddToRowsMulti :
2825
+ command_out.command_type = kMatrixAdd ;
2826
+ std::swap (command_out.arg1 , command_out.arg2 );
2827
+ break ;
2828
+ case kCopyToRowsMulti :
2829
+ command_out.command_type = kMatrixCopy ;
2830
+ std::swap (command_out.arg1 , command_out.arg2 );
2831
+ break ;
2832
+ default : // will never be reached.
2833
+ break ;
2834
+ }
2835
+ } else {
2836
+ // Indexes are not consecutive: it needs to be a kAddRows or kCopyRows
2837
+ // command.
2838
+ command_out.arg3 = computation_->indexes .size ();
2839
+ switch (command_type) {
2840
+ case kAddRowsMulti : case kCopyRowsMulti : {
2841
+ command_out.command_type = (command_type == kAddRowsMulti ?
2842
+ kAddRows : kCopyRows );
2843
+ computation_->indexes .push_back (split.second_value_offsets );
2844
+ break ;
2845
+ }
2846
+ case kCopyToRowsMulti : {
2847
+ // We can't operate on this command because of what would happen
2848
+ // with values of 'indexes' (see the variable in the block for
2849
+ // kAddToRowsMulti) which were -1. Rows of the output would be
2850
+ // set to zero, which is not the behavior we want here; we'd want
2851
+ // them to be unaffected.
2852
+ return false ;
2853
+ }
2854
+ case kAddToRowsMulti : {
2855
+ command_out.command_type = kAddRows ;
2856
+ std::swap (command_out.arg1 , command_out.arg2 );
2857
+ // invert the indexes.
2858
+ std::vector<int32> indexes (split.second_value_range , -1 );
2859
+ for (int32 i = 0 ; i < split.size ; i++) {
2860
+ // the following assert should always succeed because the
2861
+ // AddToRowsMulti and CopyToRowsMulti should never have
2862
+ // duplicate destinations in their indexes.
2863
+ KALDI_ASSERT (indexes[split.second_value_offsets [i]] >= 0 );
2864
+ indexes[split.second_value_offsets [i]] = i;
2865
+ }
2866
+ computation_->indexes .push_back (indexes);
2867
+ break ;
2868
+ }
2869
+ default :
2870
+ KALDI_ERR << " Code error: un-handled case." ;
2871
+ }
2872
+ }
2873
+ }
2874
+ command = split_commands[0 ];
2875
+ // note: for now, split_commands.size() will be 1 or 2.
2876
+ for (size_t i = 1 ; i < split_commands.size (); i++) {
2877
+ new_commands_.resize (new_commands_.size () + 1 );
2878
+ // we'll want to insert this command right after command c,
2879
+ // which is the same as just before command c + 1.
2880
+ new_commands_.back ().first = c + 1 ;
2881
+ new_commands_.back ().second = split_commands[i];
2882
+ }
2883
+ return true ; // We made a change.
2884
+ }
2885
+
2886
+ bool RowOpsSplitter::SplitCommands () {
2887
+ bool ans = false ;
2888
+ int32 num_commands = computation_->commands .size ();
2889
+ for (int32 c = 0 ; c < num_commands; c++)
2890
+ if (SplitCommand (c))
2891
+ ans = true ;
2892
+ if (!new_commands_.empty ())
2893
+ InsertCommands (&new_commands_, computation_);
2894
+ return ans;
2895
+ }
2896
+
2897
+ bool SplitRowOps (NnetComputation *computation) {
2898
+ RowOpsSplitter splitter (computation);
2899
+ return splitter.Split ();
2900
+ }
2579
2901
2580
2902
2581
2903
/*
0 commit comments