
Commit 84318e2

[ET-VK] Adding all tensor packing support to split op.
This diff updates the ExecuTorch Vulkan backend's `split` operation to support width-, height-, and channel-packed tensors. It also updates op_registry.py to indicate that `split` supports all packed dimensions, and adds new test cases to cases.py to exercise the operation.

Differential Revision: [D71345589](https://our.internmc.facebook.com/intern/diff/D71345589/)

ghstack-source-id: 272306677
Pull Request resolved: #9345
1 parent 85b341e commit 84318e2
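To make the scope concrete, here is a minimal eager-mode sketch of the two ATen ops this diff touches (plain PyTorch, not backend code; the shapes are illustrative only):

```python
import torch

# (N, C, H, W) input; the Vulkan backend may store it width-, height-, or
# channel-packed, which is the packing support this diff adds.
x = torch.randn(2, 13, 4, 8)

# aten.split.Tensor: equal-sized chunks (the last chunk may be smaller).
w_parts = torch.split(x, 3, dim=3)
assert [p.size(3) for p in w_parts] == [3, 3, 2]

# aten.split_with_sizes_copy.default: caller-specified chunk sizes.
c_parts = torch.split_with_sizes(x, [3, 5, 2, 3], dim=1)
assert [p.size(1) for p in c_parts] == [3, 5, 2, 3]
```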

File tree: 3 files changed (+111, -98)

backends/vulkan/op_registry.py (+2, -2)

@@ -528,8 +528,6 @@ def register_view_op(features: OpFeatures):
         exir_ops.edge.aten.index_select.default,
         exir_ops.edge.aten.select_copy.int,
         # Tensor combination
-        exir_ops.edge.aten.split_with_sizes_copy.default,
-        exir_ops.edge.aten.split.Tensor,
         exir_ops.edge.aten.repeat.default,
         # Tensor creation
         exir_ops.edge.aten.arange.start_step,
@@ -563,6 +561,8 @@ def register_ported_op(features: OpFeatures):
         exir_ops.edge.aten.permute_copy.default,
         # Tensor combination
         exir_ops.edge.aten.cat.default,
+        exir_ops.edge.aten.split_with_sizes_copy.default,
+        exir_ops.edge.aten.split.Tensor,
     ]
 )
 def register_ported_op_all_packed_dims(features: OpFeatures):
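The move matters because, in op_registry.py, every op listed in an `@update_features([...])` block receives the feature set configured by the function below it. A self-contained toy sketch of that pattern (simplified names, not the real ExecuTorch API):

```python
from dataclasses import dataclass, field

OP_REGISTRY: dict = {}

@dataclass
class OpFeatures:
    supported_packed_dims: set = field(default_factory=set)

def update_features(op_names):
    # Toy stand-in: run the decorated configurator once per listed op name.
    def decorator(configure):
        for name in op_names:
            features = OpFeatures()
            configure(features)
            OP_REGISTRY[name] = features
        return configure
    return decorator

@update_features(
    [
        "aten.cat.default",
        "aten.split_with_sizes_copy.default",  # moved into this list by the diff
        "aten.split.Tensor",
    ]
)
def register_ported_op_all_packed_dims(features: OpFeatures):
    features.supported_packed_dims = {"width", "height", "channels"}

assert OP_REGISTRY["aten.split.Tensor"].supported_packed_dims == {
    "width", "height", "channels"
}
```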

backends/vulkan/runtime/graph/ops/impl/Split.cpp (+43, -47)

@@ -25,8 +25,6 @@ void add_split_with_sizes_default_node(
     ValueRef out_list_ref) {
   vTensorPtr t_in = graph.get_tensor(in);
 
-  VK_CHECK_COND(check_packed_dim_is(*t_in, WHCN::kChannelsDim));
-
   ValueListPtr out_list = graph.get_value_list(out_list_ref);
 
   DimIndex dim_index = normalize_to_dim_index(*t_in, dim);
@@ -38,62 +36,60 @@ void add_split_with_sizes_default_node(
     ValueRef out_ref = (*out_list)[split_idx];
 
     vTensorPtr t_out = graph.get_tensor(out_ref);
-    VK_CHECK_COND(check_packed_dim_is(*t_out, WHCN::kChannelsDim));
     VK_CHECK_COND(dim_at(*t_out, dim_index) == split_size);
   }
 
-  if (dim_index == kWidth4D) {
-    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
-    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  const auto packed_dim = t_in->packed_dim();
+  const auto packed_dim_index = static_cast<DimIndex>(kWidth4D - packed_dim);
 
-    for (ValueRef out_ref : *out_list) {
-      // Doesn't need to use split_size since we have already verified that the
-      // output tensor's size matches with the split_size.
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      utils::ivec3 range = t_out->logical_limits();
-      add_copy_offset_node(
-          graph, in, range, src_offset, dst_offset, out_ref, false, true);
+  // Index of dimension to be concatenated in (w, h, c * b) coordinate system
+  const auto dim_xyz_index = std::min(2, -dim_index - 1);
 
-      src_offset[0] += range[0];
-    }
-  } else if (dim_index == kHeight4D) {
-    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
-    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
 
-    for (ValueRef out_ref : *out_list) {
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      utils::ivec3 range = t_out->logical_limits();
-      add_copy_offset_node(
-          graph, in, range, src_offset, dst_offset, out_ref, false, true);
+  const bool is_splitting_channel = (dim_index == kChannel4D);
 
-      src_offset[1] += range[1];
-    }
-  } else if (dim_index == kBatch4D) {
-    utils::ivec4 src_offset = utils::make_ivec4({0, 0, 0, 0}, false);
-    utils::ivec4 dst_offset = utils::make_ivec4({0, 0, 0, 0}, false);
+  // if splitting channels
+  if (is_splitting_channel) {
+    // set source offset w as channel size of the input tensor
+    src_offset[3] = dim_at(t_in->sizes(), kChannel4D);
+  }
 
-    for (ValueRef out_ref : *out_list) {
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      utils::ivec3 range = t_out->logical_limits();
+  for (ValueRef out_ref : *out_list) {
+    // Doesn't need to use split_size since we have already verified that the
+    // output tensor's size matches with the split_size.
+    vTensorPtr t_out = graph.get_tensor(out_ref);
+    const auto out_channel_size = dim_at(t_out->sizes(), kChannel4D);
+    utils::ivec3 range = t_out->logical_limits();
+
+    if (dim_index == packed_dim_index) {
+      // if splitting channels, use add_copy_channel_offset_node function as
+      // add_copy_packed_dim_offset_node does not support channel packing
+      if (is_splitting_channel) {
+        add_copy_channel_offset_node(
+            graph, in, out_channel_size, src_offset[2], dst_offset[2], out_ref);
+        src_offset[dim_xyz_index] += out_channel_size;
+      } else {
+        // dst_offset[3] is not used now but will be used in the future when
+        // add_copy_packed_dim_offset_node will support channel packing
+        //
+        // set destination offset w as channel size of the output tensor if
+        // splitting channel
+        dst_offset[3] = is_splitting_channel ? out_channel_size : 0;
+        add_copy_packed_dim_offset_node(
+            graph, in, range, src_offset, dst_offset, out_ref);
+        src_offset[dim_xyz_index] += dim_at(t_out->sizes(), packed_dim_index);
+      }
+    } else {
+      // set destination offset w as channel size of the output tensor if
+      // splitting channels
+      dst_offset[3] = is_splitting_channel ? out_channel_size : 0;
       add_copy_offset_node(
           graph, in, range, src_offset, dst_offset, out_ref, false, true);
-
-      src_offset[2] += range[2];
-    }
-  } else if (dim_index == kChannel4D) {
-    int32_t src_offset = 0;
-    int32_t dst_offset = 0;
-
-    for (ValueRef out_ref : *out_list) {
-      vTensorPtr t_out = graph.get_tensor(out_ref);
-      int32_t range = dim_at<kChannel4D>(t_out->sizes());
-      add_copy_channel_offset_node(
-          graph, in, range, src_offset, dst_offset, out_ref);
-      src_offset += range;
+      src_offset[dim_xyz_index] +=
+          is_splitting_channel ? out_channel_size : range[dim_xyz_index];
     }
-
-  } else {
-    VK_THROW("not ipmlemented");
   }
 }

backends/vulkan/test/op_tests/cases.py (+66, -49)

@@ -922,86 +922,103 @@ def get_split_with_sizes_inputs():
     Test = namedtuple("VkSliceTest", ["self", "sizes", "dim"])
     test_cases = [
         # Split on Width
+        Test(self=(S1, 7, 10, 11), sizes=[1, 3, 3, 5], dim=3),
         Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=3),
+        Test(self=(7, 10, 11), sizes=[1, 3, 3, 5], dim=2),
         Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
+        Test(self=(7, 10, 11), sizes=[3, 8], dim=2),
         Test(self=(7, 10, 10), sizes=[1, 9], dim=2),
         Test(self=(10, 10), sizes=[1, 9], dim=1),
         Test(self=(10,), sizes=[1, 9], dim=0),
         # Split on Height
+        Test(self=(S1, 7, 11, 10), sizes=[1, 3, 3, 5], dim=2),
         Test(self=(S1, 7, 10, 10), sizes=[1, 2, 3, 4], dim=2),
+        Test(self=(7, 11, 10), sizes=[1, 3, 3, 5], dim=1),
         Test(self=(7, 10, 10), sizes=[1, 2, 3, 4], dim=1),
+        Test(self=(7, 11, 11), sizes=[3, 8], dim=1),
         Test(self=(7, 10, 10), sizes=[10], dim=1),
         Test(self=(7, 6, 10), sizes=[1, 1, 1, 1, 1, 1], dim=1),
         Test(self=(10, 10), sizes=[1, 2, 3, 4], dim=0),
         # Split on Batch
         Test(self=(10, 7, 10, 10), sizes=[3, 6, 1], dim=0),
         Test(self=(10, 7, 10, 10), sizes=[10], dim=0),
         # Split on Channel
+        Test(self=(7, 13, 4, 8), sizes=[3, 5, 2, 3], dim=1),
         Test(self=(7, 13, 4, 8), sizes=[3, 6, 1, 3], dim=1),
+        Test(self=(7, 13, 4, 8), sizes=[3, 3, 2, 5, 1], dim=1),
         Test(self=(7, 13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=1),
+        Test(self=(13, 4, 8), sizes=[3, 5, 2, 1, 2], dim=0),
         Test(self=(13, 4, 8), sizes=[3, 3, 3, 3, 1], dim=0),
         Test(self=(13, 4, 8), sizes=[2, 9, 2], dim=0),
         Test(self=(13, 4, 8), sizes=[13], dim=0),
     ]
     test_suite = VkTestSuite([tuple(tc) for tc in test_cases])
 
     test_suite.layouts = [
+        "utils::kWidthPacked",
+        "utils::kHeightPacked",
         "utils::kChannelsPacked",
     ]
     test_suite.data_gen = "make_seq_tensor"
     test_suite.dtypes = ["at::kFloat"]
     return test_suite
 
 
-@register_test_suite("aten.split.Tensor")
-def get_split_tensor_inputs():
-    test_suite = VkTestSuite(
-        [
-            # Split on Width
-            ((S1, 7, 10, 12), 12, 3),
-            ((S1, 7, 10, 12), 3, 3),
-            ((S1, 7, 10, 12), 1, 3),
-            ((7, 10, 12), 12, 2),
-            ((7, 10, 12), 3, 2),
-            ((7, 10, 12), 1, 2),
-            ((10, 12), 12, 1),
-            ((10, 12), 3, 1),
-            ((10, 12), 1, 1),
-            ((12,), 12, 0),
-            ((12,), 3, 0),
-            ((12,), 1, 0),
-            # Split on Height
-            ((S1, 7, 12, 8), 12, 2),
-            ((S1, 7, 12, 8), 3, 2),
-            ((S1, 7, 12, 8), 1, 2),
-            ((7, 12, 8), 12, 1),
-            ((7, 12, 8), 3, 1),
-            ((7, 12, 8), 1, 1),
-            ((12, 8), 12, 0),
-            ((12, 8), 3, 0),
-            ((12, 8), 1, 0),
-            # Split on Batch
-            ((12, 7, 10, 10), 12, 0),
-            ((12, 7, 10, 10), 3, 0),
-            ((12, 7, 10, 10), 1, 0),
-            # Split on Channel
-            ((7, 15, 10, 10), 15, 1),
-            ((7, 15, 10, 10), 5, 1),
-            ((7, 15, 10, 10), 3, 1),
-            ((7, 15, 10, 10), 1, 1),
-            ((15, 10, 10), 15, 0),
-            ((15, 10, 10), 5, 0),
-            ((15, 10, 10), 3, 0),
-            ((15, 10, 10), 1, 0),
-        ]
-    )
-
-    test_suite.layouts = [
-        "utils::kChannelsPacked",
-    ]
-    test_suite.data_gen = "make_seq_tensor"
-    test_suite.dtypes = ["at::kFloat"]
-    return test_suite
+# @register_test_suite("aten.split.Tensor")
+# def get_split_tensor_inputs():
+#     test_suite = VkTestSuite(
+#         [
+#             # Split on Width
+#             ((M1, 7, 10, 12), 12, 3),
+#             ((S1, 7, 10, 12), 12, 3),
+#             ((M1, 7, 10, 12), 3, 3),
+#             ((S1, 7, 10, 12), 3, 3),
+#             ((M1, 7, 10, 12), 1, 3),
+#             ((S1, 7, 10, 12), 1, 3),
+#             ((7, 10, 12), 12, 2),
+#             ((7, 10, 12), 3, 2),
+#             ((7, 10, 12), 1, 2),
+#             ((2, 3, 4), 1, 2),
+#             ((10, 12), 12, 1),
+#             ((10, 12), 3, 1),
+#             ((10, 12), 1, 1),
+#             ((12,), 12, 0),
+#             ((12,), 3, 0),
+#             ((12,), 1, 0),
+#             # Split on Height
+#             ((S1, 7, 12, 8), 12, 2),
+#             ((S1, 7, 12, 8), 3, 2),
+#             ((S1, 7, 12, 8), 1, 2),
+#             ((7, 12, 8), 12, 1),
+#             ((7, 12, 8), 3, 1),
+#             ((7, 12, 8), 1, 1),
+#             ((12, 8), 12, 0),
+#             ((12, 8), 3, 0),
+#             ((12, 8), 1, 0),
+#             # Split on Batch
+#             ((12, 7, 10, 10), 12, 0),
+#             ((12, 7, 10, 10), 3, 0),
+#             ((12, 7, 10, 10), 1, 0),
+#             # Split on Channel
+#             ((7, 15, 10, 10), 15, 1),
+#             ((7, 15, 10, 10), 5, 1),
+#             ((7, 15, 10, 10), 3, 1),
+#             ((7, 15, 10, 10), 1, 1),
+#             ((15, 10, 10), 15, 0),
+#             ((15, 10, 10), 5, 0),
+#             ((15, 10, 10), 3, 0),
+#             ((15, 10, 10), 1, 0),
+#         ]
+#     )
+
+#     test_suite.layouts = [
+#         "utils::kWidthPacked",
+#         "utils::kHeightPacked",
+#         "utils::kChannelsPacked",
+#     ]
+#     test_suite.data_gen = "make_seq_tensor"
+#     test_suite.dtypes = ["at::kFloat"]
+#     return test_suite
 
 
 def get_reduce_inputs(is_softmax: bool = False):
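Worth noting about the new cases: extents such as 11 and uneven split sizes such as [1, 3, 3, 5] place split boundaries off 4-element texel boundaries, which is presumably what stresses the new packed-dim offset copy path. In eager terms, the first new width case reads:

```python
import torch

# Test(self=(7, 10, 11), sizes=[1, 3, 3, 5], dim=2) from cases.py, as plain
# PyTorch; make_seq_tensor is approximated here with arange.
x = torch.arange(7 * 10 * 11, dtype=torch.float32).reshape(7, 10, 11)
outs = torch.split_with_sizes(x, [1, 3, 3, 5], dim=2)  # split on width
assert [o.size(2) for o in outs] == [1, 3, 3, 5]
```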
