Skip to content

Commit

Permalink
Fix multi-function compute (flutter#39603)
Browse files Browse the repository at this point in the history
  • Loading branch information
dnfield authored Feb 14, 2023
1 parent f7dfb2b commit c2e165e
Show file tree
Hide file tree
Showing 5 changed files with 194 additions and 17 deletions.
10 changes: 9 additions & 1 deletion impeller/fixtures/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ impeller_shaders("shader_fixtures") {
"mipmaps.frag",
"mipmaps.vert",
"sample.comp",
"stage1.comp",
"stage2.comp",
"simple.vert",
"test_texture.frag",
"test_texture.vert",
]

if (impeller_enable_opengles) {
gles_exclusions = [ "sample.comp" ]
gles_exclusions = [
"sample.comp",
"stage1.comp",
"stage2.comp",
]
}
}

Expand Down Expand Up @@ -77,6 +83,8 @@ test_fixtures("file_fixtures") {
"sample_with_binding.vert",
"simple.vert.hlsl",
"sa%m#ple.vert",
"stage1.comp",
"stage2.comp",
"struct_def_bug.vert",
"table_mountain_nx.png",
"table_mountain_ny.png",
Expand Down
29 changes: 29 additions & 0 deletions impeller/fixtures/stage1.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Each workgroup runs 128 invocations along the x axis.
layout(local_size_x = 128) in;
// Use std430 as the default layout for every buffer block declared below.
layout(std430) buffer;

// Destination storage buffer at binding 0. The shader only writes to it:
// `count` receives the number of produced entries and `elements` the entries
// themselves.
layout(binding = 0) writeonly buffer Output {
uint count;
uint elements[];
}
output_data;

// Source storage buffer at binding 1. The shader only reads from it:
// `count` bounds how many entries of `elements` are processed.
layout(binding = 1) readonly buffer Input {
uint count;
uint elements[];
}
input_data;

// Expands every input element into two output elements: for input index i
// holding value v, output slot 2*i receives v*2 and slot 2*i + 1 receives v*3.
// The output count is set to twice the input count.
void main() {
  uint index = gl_GlobalInvocationID.x;
  uint total = input_data.count;

  // Invocations beyond the end of the input have nothing to do.
  if (index < total) {
    // Every active invocation stores the same value here.
    output_data.count = total * 2;

    uint value = input_data.elements[index];
    uint base = index * 2;

    output_data.elements[base] = value * 2;
    output_data.elements[base + 1] = value * 3;
  }
}
26 changes: 26 additions & 0 deletions impeller/fixtures/stage2.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Each workgroup runs 128 invocations along the x axis.
layout(local_size_x = 128) in;
// Use std430 as the default layout for every buffer block declared below.
layout(std430) buffer;

// Destination storage buffer at binding 0. The shader only writes to it:
// `count` receives the number of produced entries and `elements` the entries
// themselves.
layout(binding = 0) writeonly buffer Output {
uint count;
uint elements[];
}
output_data;

// Source storage buffer at binding 1. The shader only reads from it:
// `count` bounds how many entries of `elements` are processed.
layout(binding = 1) readonly buffer Input {
uint count;
uint elements[];
}
input_data;

// Copies the input count to the output and stores each input element,
// doubled, at the same index in the output.
void main() {
  uint index = gl_GlobalInvocationID.x;
  uint total = input_data.count;

  // Invocations beyond the end of the input have nothing to do.
  if (index < total) {
    // Every active invocation stores the same value here.
    output_data.count = total;

    output_data.elements[index] = input_data.elements[index] * 2;
  }
}
32 changes: 16 additions & 16 deletions impeller/renderer/backend/metal/compute_pass_mtl.mm
Original file line number Diff line number Diff line change
Expand Up @@ -241,23 +241,23 @@ static bool Bind(ComputePassBindingsCache& pass,
return false;
}
}
// TODO(dnfield): use feature detection to support non-uniform threadgroup
// sizes.
// https://github.com/flutter/flutter/issues/110619

// For now, check that the sizes are uniform.
FML_DCHECK(grid_size == thread_group_size);
auto width = grid_size.width;
auto height = grid_size.height;
while (width * height >
static_cast<int64_t>(
pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup)) {
width /= 2;
height /= 2;
}
auto size = MTLSizeMake(width, height, 1);
[encoder dispatchThreadgroups:size threadsPerThreadgroup:size];
}
// TODO(dnfield): use feature detection to support non-uniform threadgroup
// sizes.
// https://github.com/flutter/flutter/issues/110619

// For now, check that the sizes are uniform.
FML_DCHECK(grid_size == thread_group_size);
auto width = grid_size.width;
auto height = grid_size.height;
while (width * height >
static_cast<int64_t>(
pass_bindings.GetPipeline().maxTotalThreadsPerThreadgroup)) {
width /= 2;
height /= 2;
}
auto size = MTLSizeMake(width, height, 1);
[encoder dispatchThreadgroups:size threadsPerThreadgroup:size];

return true;
}
Expand Down
114 changes: 114 additions & 0 deletions impeller/renderer/compute_unittests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
#include "flutter/fml/synchronization/waitable_event.h"
#include "flutter/fml/time/time_point.h"
#include "flutter/testing/testing.h"
#include "gmock/gmock.h"
#include "impeller/base/strings.h"
#include "impeller/fixtures/sample.comp.h"
#include "impeller/fixtures/stage1.comp.h"
#include "impeller/fixtures/stage2.comp.h"
#include "impeller/playground/compute_playground_test.h"
#include "impeller/renderer/command_buffer.h"
#include "impeller/renderer/compute_command.h"
Expand Down Expand Up @@ -102,5 +105,116 @@ TEST_P(ComputeTest, CanCreateComputePass) {
latch.Wait();
}

// Encodes two compute commands into a single pass, where the second command
// reads the buffer that the first command wrote. Stage 1 expands 5 inputs into
// 10 outputs (v*2, v*3 per input); stage 2 doubles each of those 10 values.
// Both output buffers are host visible so the results can be checked directly
// once the command buffer completes.
TEST_P(ComputeTest, MultiStageInputAndOutput) {
  using CS1 = Stage1ComputeShader;
  using Stage1PipelineBuilder = ComputePipelineBuilder<CS1>;
  using CS2 = Stage2ComputeShader;
  using Stage2PipelineBuilder = ComputePipelineBuilder<CS2>;

  auto context = GetContext();
  ASSERT_TRUE(context);

  auto pipeline_desc_1 =
      Stage1PipelineBuilder::MakeDefaultPipelineDescriptor(*context);
  ASSERT_TRUE(pipeline_desc_1.has_value());
  auto compute_pipeline_1 =
      context->GetPipelineLibrary()->GetPipeline(pipeline_desc_1).Get();
  ASSERT_TRUE(compute_pipeline_1);

  auto pipeline_desc_2 =
      Stage2PipelineBuilder::MakeDefaultPipelineDescriptor(*context);
  ASSERT_TRUE(pipeline_desc_2.has_value());
  auto compute_pipeline_2 =
      context->GetPipelineLibrary()->GetPipeline(pipeline_desc_2).Get();
  ASSERT_TRUE(compute_pipeline_2);

  auto cmd_buffer = context->CreateCommandBuffer();
  // Check the command buffer before dereferencing it to create the pass.
  ASSERT_TRUE(cmd_buffer);
  auto pass = cmd_buffer->CreateComputePass();
  ASSERT_TRUE(pass && pass->IsValid());

  // Stage 1 consumes kCount1 elements and produces kCount2; stage 2 consumes
  // and produces kCount2.
  static constexpr size_t kCount1 = 5;
  static constexpr size_t kCount2 = kCount1 * 2;

  pass->SetGridSize(ISize(512, 1));
  pass->SetThreadGroupSize(ISize(512, 1));

  CS1::Input<kCount1> input_1;
  input_1.count = kCount1;
  for (size_t i = 0; i < kCount1; i++) {
    input_1.elements[i] = i;
  }

  DeviceBufferDescriptor output_desc_1;
  output_desc_1.storage_mode = StorageMode::kHostVisible;
  output_desc_1.size = sizeof(CS1::Output<kCount2>);

  auto output_buffer_1 =
      context->GetResourceAllocator()->CreateBuffer(output_desc_1);
  ASSERT_TRUE(output_buffer_1);
  output_buffer_1->SetLabel("Output Buffer Stage 1");

  DeviceBufferDescriptor output_desc_2;
  output_desc_2.storage_mode = StorageMode::kHostVisible;
  output_desc_2.size = sizeof(CS2::Output<kCount2>);

  auto output_buffer_2 =
      context->GetResourceAllocator()->CreateBuffer(output_desc_2);
  ASSERT_TRUE(output_buffer_2);
  output_buffer_2->SetLabel("Output Buffer Stage 2");

  {
    ComputeCommand cmd;
    cmd.label = "Compute1";
    cmd.pipeline = compute_pipeline_1;

    CS1::BindInput(cmd,
                   pass->GetTransientsBuffer().EmplaceStorageBuffer(input_1));
    CS1::BindOutput(cmd, output_buffer_1->AsBufferView());

    ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
  }

  {
    ComputeCommand cmd;
    cmd.label = "Compute2";
    cmd.pipeline = compute_pipeline_2;

    // The stage 1 output is the stage 2 input. Bind it through the stage 2
    // shader's own binding helper so the slot matches this command's pipeline.
    CS2::BindInput(cmd, output_buffer_1->AsBufferView());
    CS2::BindOutput(cmd, output_buffer_2->AsBufferView());

    ASSERT_TRUE(pass->AddCommand(std::move(cmd)));
  }

  ASSERT_TRUE(pass->EncodeCommands());

  fml::AutoResetWaitableEvent latch;
  ASSERT_TRUE(cmd_buffer->SubmitCommands([&latch, &output_buffer_1,
                                          &output_buffer_2](
                                             CommandBuffer::Status status) {
    // Only non-fatal EXPECT_* checks are used in this callback so that the
    // latch is always signaled and the test body never blocks forever.
    EXPECT_EQ(status, CommandBuffer::Status::kCompleted);

    CS1::Output<kCount2>* output_1 = reinterpret_cast<CS1::Output<kCount2>*>(
        output_buffer_1->AsBufferView().contents);
    EXPECT_TRUE(output_1);
    EXPECT_EQ(output_1->count, 10u);
    EXPECT_THAT(output_1->elements,
                ::testing::ElementsAre(0, 0, 2, 3, 4, 6, 6, 9, 8, 12));

    CS2::Output<kCount2>* output_2 = reinterpret_cast<CS2::Output<kCount2>*>(
        output_buffer_2->AsBufferView().contents);
    EXPECT_TRUE(output_2);
    EXPECT_EQ(output_2->count, 10u);
    EXPECT_THAT(output_2->elements,
                ::testing::ElementsAre(0, 0, 4, 6, 8, 12, 12, 18, 16, 24));

    latch.Signal();
  }));

  latch.Wait();
}

} // namespace testing
} // namespace impeller

0 comments on commit c2e165e

Please sign in to comment.