diff --git a/dali/test/dali_operator_test.h b/dali/test/dali_operator_test.h index 01f96c7bff7..f0b22c240b9 100644 --- a/dali/test/dali_operator_test.h +++ b/dali/test/dali_operator_test.h @@ -1,4 +1,4 @@ -// Copyright (c) 2018-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -123,27 +123,20 @@ inline void AddOperatorToPipeline(Pipeline &pipeline, const OpSpec &op_spec) { pipeline.AddOperator(op_spec); } - -inline Workspace CreateWorkspace() { - Workspace ws; - return ws; -} - - inline void RunPipeline(Pipeline &pipeline) { pipeline.Run(); } inline std::vector -GetOutputsFromPipeline(Pipeline &pipeline, const std::string &output_backend) { +GetOutputsFromPipeline(Workspace &ws, Pipeline &pipeline, const std::string &output_backend) { std::vector ret; - auto workspace = CreateWorkspace(); - pipeline.Outputs(&workspace); - for (int output_idx = 0; output_idx < workspace.NumOutput(); output_idx++) { - if (workspace.OutputIsType(output_idx)) { - ret.emplace_back(&workspace.Output(output_idx)); + ws = {}; + pipeline.Outputs(&ws); + for (int output_idx = 0; output_idx < ws.NumOutput(); output_idx++) { + if (ws.OutputIsType(output_idx)) { + ret.emplace_back(&ws.Output(output_idx)); } else { - ret.emplace_back(&workspace.Output(output_idx)); + ret.emplace_back(&ws.Output(output_idx)); } } return ret; @@ -284,10 +277,11 @@ class DaliOperatorTest : public ::testing::Test, public ::testing::WithParamInte BuildPipeline(pipeline, op_spec); SetInputInPipeline(pipeline, input); RunPipeline(pipeline); - return GetOutputsFromPipeline(pipeline, output_backend); + return GetOutputsFromPipeline(ws_, pipeline, output_backend); } size_t num_threads_ = 1; + Workspace ws_; }; } // namespace testing