Skip to content

Commit

Permalink
[Inference]OpenVINO submodule clear (#70313)
Browse files Browse the repository at this point in the history
* test delete openvino thirty party

* optimize openvino submodule

* check
  • Loading branch information
ckl117 authored Dec 19, 2024
1 parent b22c396 commit 0beb71a
Show file tree
Hide file tree
Showing 3 changed files with 161 additions and 106 deletions.
31 changes: 27 additions & 4 deletions cmake/external/openvino.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -115,17 +115,40 @@ ExternalProject_Add(
UPDATE_COMMAND ""
#BUILD_ALWAYS 1
PATCH_COMMAND ${OPENVINO_PATCH_COMMAND}
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${OPENVINO_INSTALL_DIR}
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
-DCMAKE_C_COMPILER=${CMAKE_C_COMPILER}
-DCMAKE_CXX_FLAGS=${ONEDNN_CXXFLAG}
-DCMAKE_CXX_FLAGS_DEBUG=${CMAKE_CXX_FLAGS_DEBUG}
-DCMAKE_CXX_FLAGS_RELEASE=${ONEDNN_CXXFLAG_RELEASE}
-DCMAKE_C_FLAGS=${ONEDNN_CFLAG}
-DCMAKE_C_FLAGS_DEBUG=${CMAKE_C_FLAGS_DEBUG}
-DCMAKE_C_FLAGS_RELEASE=${ONEDNN_CFLAG_RELEASE}
-DCMAKE_INSTALL_PREFIX=${OPENVINO_INSTALL_DIR}
-DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
-DTHREADING=TBB
-DENABLE_INTEL_CPU=ON
-DENABLE_INTEL_GPU=OFF
-DENABLE_OV_JAX_FRONTEND=OFF
-DENABLE_INTEL_NPU=OFF
-DENABLE_HETERO=OFF
-DENABLE_MULTI=OFF
-DENABLE_AUTO=OFF
-DENABLE_TEMPLATE=OFF
-DENABLE_AUTO_BATCH=OFF
-DENABLE_PROXY=OFF
-DENABLE_OV_ONNX_FRONTEND=OFF
-DENABLE_OV_PYTORCH_FRONTEND=OFF
-DENABLE_OV_TF_FRONTEND=OFF
-DENABLE_OV_TF_LITE_FRONTEND=OFF
-DENABLE_INTEL_NPU=OFF
-DENABLE_OV_PYTORCH_FRONTEND=OFF
-DENABLE_OV_JAX_FRONTEND=OFF
-DENABLE_OV_IR_FRONTEND=OFF
-DENABLE_SAMPLES=OFF
-DENABLE_TESTS=OFF
-DENABLE_PYTHON=OFF
-DENABLE_WHEEL=OFF
-DENABLE_DOCS=OFF
-DENABLE_CPPLINT=OFF
-DENABLE_CLANG_FORMAT=OFF
-DENABLE_NCC_STYLE=OFF
CMAKE_CACHE_ARGS -DCMAKE_INSTALL_PREFIX:PATH=${OPENVINO_INSTALL_DIR}
BUILD_BYPRODUCTS ${BUILD_BYPRODUCTS_ARGS})

Expand Down
33 changes: 33 additions & 0 deletions cmake/third_party.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,39 @@ if(NOT WITH_SETUP_INSTALL)
"Check submodules of paddle, and run 'git submodule sync --recursive && git submodule update --init --recursive'"
)

execute_process(
COMMAND git submodule update --init third_party/openvino
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}
RESULT_VARIABLE result_var)
# List of modules to be deleted
set(delete_module
"thirdparty/zlib/zlib"
"thirdparty/gflags/gflags"
"thirdparty/gtest/gtest"
"thirdparty/ocl/icd_loader"
"thirdparty/ocl/cl_headers"
"thirdparty/ocl/clhpp_headers"
"thirdparty/onnx/onnx"
"src/bindings/python/thirdparty/pybind11"
"thirdparty/ittapi/ittapi"
"cmake/developer_package/ncc_naming_style/ncc"
"src/plugins/intel_gpu/thirdparty/onednn_gpu"
"thirdparty/open_model_zoo"
"thirdparty/json/nlohmann_json"
"thirdparty/flatbuffers/flatbuffers"
"thirdparty/snappy"
"thirdparty/level_zero/level-zero"
"src/plugins/intel_npu/thirdparty/level-zero-ext"
"src/plugins/intel_npu/thirdparty/yaml-cpp")
# Iterate over each module and perform actions
foreach(module IN LISTS delete_module)
# Remove the module from git cache
execute_process(
COMMAND git rm --cached ${module}
WORKING_DIRECTORY ${PADDLE_SOURCE_DIR}/third_party/openvino
RESULT_VARIABLE git_rm_result)
endforeach()

# execute_process does not support sequential commands, so we execute echo command separately
execute_process(
COMMAND git submodule sync --recursive
Expand Down
203 changes: 101 additions & 102 deletions patches/openvino/convert.patch
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,61 @@ index 0000000000..c51a2af6f9
+} // namespace paddle
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/paddle/src/op/expand_as_v2.cpp b/src/frontends/paddle/src/op/expand_as_v2.cpp
new file mode 100644
index 0000000000..19cf05758b
--- /dev/null
+++ b/src/frontends/paddle/src/op/expand_as_v2.cpp
@@ -0,0 +1,49 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "default_opset.hpp"
+#include "openvino/frontend/paddle/node_context.hpp"
+
+namespace ov {
+namespace frontend {
+namespace paddle {
+namespace op {
+NamedOutputs expand_as_v2(const NodeContext& node) {
+ using namespace default_opset;
+ auto x = node.get_input("X");
+ Output<Node> shape_expected_node;
+ if (node.has_input("Y")) {
+ shape_expected_node = std::make_shared<ShapeOf>(node.get_input("Y"), element::i32);
+ } else {
+ std::vector<int32_t> shape_expected;
+ if (node.has_attribute("target_shape")) {
+ shape_expected = node.get_attribute<std::vector<int32_t>>("target_shape");
+ } else {
+ throw std::runtime_error("expand: has no target_shape attribute");
+ }
+ shape_expected_node = Constant::create(element::i32, {shape_expected.size()}, shape_expected);
+ }
+ // expected shape rank
+ const auto shape_expected_node_rank = std::make_shared<ShapeOf>(shape_expected_node, element::i32);
+ // input shape rank
+ const auto input_shape_node_shape = std::make_shared<ShapeOf>(x, element::i32);
+ const auto input_shape_node_rank = std::make_shared<ShapeOf>(input_shape_node_shape, element::i32);
+ // rank difference
+ const auto rank_diff = std::make_shared<Subtract>(shape_expected_node_rank, input_shape_node_rank);
+ // axis index needed to add
+ const auto rank_idx = std::make_shared<Broadcast>(Constant::create(element::i32, {1}, {1}), rank_diff);
+ // add axis
+ const auto fixed_input_shape_node = std::make_shared<Concat>(NodeVector{rank_idx, input_shape_node_shape}, 0);
+
+ // if -1 in shape we will copy the orginal value from input
+ auto zero_node = Constant::create(ov::element::i32, {1}, {0});
+ auto mask_node = std::make_shared<Greater>(shape_expected_node, zero_node);
+ auto fixed_shape_node = std::make_shared<Select>(mask_node, shape_expected_node, fixed_input_shape_node);
+ return node.default_single_output_mapping({std::make_shared<Broadcast>(x, fixed_shape_node)}, {"Out"});
+}
+
+} // namespace op
+} // namespace paddle
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/paddle/src/op/expand_v2.cpp b/src/frontends/paddle/src/op/expand_v2.cpp
index d79e49db28..ea174efa3a 100644
--- a/src/frontends/paddle/src/op/expand_v2.cpp
Expand Down Expand Up @@ -270,6 +325,29 @@ index e7b317f288..5ab551dc3b 100644
}

NamedOutputs linear_interp_v2(const NodeContext& node) {
diff --git a/src/frontends/paddle/src/op/less_equal.cpp b/src/frontends/paddle/src/op/less_equal.cpp
new file mode 100644
index 0000000000..89c626c820
--- /dev/null
+++ b/src/frontends/paddle/src/op/less_equal.cpp
@@ -0,0 +1,17 @@
+// Copyright (C) 2018-2024 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "elementwise_ops.hpp"
+
+namespace ov {
+namespace frontend {
+namespace paddle {
+namespace op {
+NamedOutputs less_equal(const NodeContext& node) {
+ return elementwise_ops<default_opset::LessEqual>(node);
+}
+} // namespace op
+} // namespace paddle
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/paddle/src/op/reduce_ops.hpp b/src/frontends/paddle/src/op/reduce_ops.hpp
index 2b59516042..954d1de425 100644
--- a/src/frontends/paddle/src/op/reduce_ops.hpp
Expand Down Expand Up @@ -402,22 +480,31 @@ index 0000000000..64bc75c66a
+} // namespace frontend
+} // namespace ov
diff --git a/src/frontends/paddle/src/op_table.cpp b/src/frontends/paddle/src/op_table.cpp
index 769492eb13..3a6b5dda9f 100644
index 769492eb13..d7de21a481 100644
--- a/src/frontends/paddle/src/op_table.cpp
+++ b/src/frontends/paddle/src/op_table.cpp
@@ -39,9 +39,11 @@ OP_CONVERTER(elementwise_sub);
@@ -39,9 +39,12 @@ OP_CONVERTER(elementwise_sub);
OP_CONVERTER(equal);
OP_CONVERTER(greater_equal);
OP_CONVERTER(not_equal);
+OP_CONVERTER(elu);
OP_CONVERTER(embedding);
OP_CONVERTER(exp);
OP_CONVERTER(expand_v2);
+OP_CONVERTER(expand_as_v2);
+OP_CONVERTER(eye);
OP_CONVERTER(flip);
OP_CONVERTER(flatten_contiguous_range);
OP_CONVERTER(floor);
@@ -138,6 +140,12 @@ OP_CONVERTER(write_to_array);
@@ -60,6 +63,7 @@ OP_CONVERTER(index_select);
OP_CONVERTER(layer_norm);
OP_CONVERTER(leaky_relu);
OP_CONVERTER(less_than);
+OP_CONVERTER(less_equal);
OP_CONVERTER(linear_interp_v2);
OP_CONVERTER(linspace);
OP_CONVERTER(lod_array_length);
@@ -138,6 +142,12 @@ OP_CONVERTER(write_to_array);
OP_CONVERTER(where_index);
OP_CONVERTER(yolo_box);
OP_CONVERTER(generate_proposals_v2);
Expand All @@ -430,19 +517,28 @@ index 769492eb13..3a6b5dda9f 100644
} // namespace op
std::map<std::string, CreatorFunction> get_supported_ops() {
return {{"arg_max", op::argmax},
@@ -173,9 +181,11 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
@@ -173,9 +183,12 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"elementwise_sub", op::elementwise_sub},
{"dropout", op::dropout},
{"elementwise_pow", op::elementwise_pow},
+ {"elu", op::elu},
{"equal", op::equal},
{"exp", op::exp},
{"expand_v2", op::expand_v2},
+ {"expand_as_v2", op::expand_as_v2},
+ {"eye", op::eye},
{"fill_any_like", op::fill_any_like},
{"fill_constant", op::fill_constant},
{"fill_constant_batch_size_like", op::fill_constant_batch_size_like},
@@ -277,7 +287,13 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
@@ -196,6 +209,7 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"layer_norm", op::layer_norm},
{"leaky_relu", op::leaky_relu},
{"less_than", op::less_than},
+ {"less_equal", op::less_equal},
{"linear_interp_v2", op::linear_interp_v2},
{"linspace", op::linspace},
{"lod_array_length", op::lod_array_length},
@@ -277,7 +291,13 @@ std::map<std::string, CreatorFunction> get_supported_ops() {
{"while", op::while_},
{"write_to_array", op::write_to_array},
{"where_index", op::where_index},
Expand Down Expand Up @@ -479,100 +575,3 @@ index 99357a3a33..53ea785260 100644
std::string("flip_1/flip_1.pdmodel"),
std::string("flip_2/flip_2.pdmodel"),
std::string("flip_3/flip_3.pdmodel"),
diff --git a/src/frontends/paddle/tests/test_models/gen_scripts/generate_elu.py b/src/frontends/paddle/tests/test_models/gen_scripts/generate_elu.py
new file mode 100644
index 0000000000..4dc67b2051
--- /dev/null
+++ b/src/frontends/paddle/tests/test_models/gen_scripts/generate_elu.py
@@ -0,0 +1,44 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+#
+# relu6 paddle model generator
+#
+import numpy as np
+from save_model import saveModel
+import paddle
+import sys
+
+
+def elu(name: str, x, alpha=None, data_type='float32'):
+ paddle.enable_static()
+
+ with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
+ node_x = paddle.static.data(name='x', shape=x.shape, dtype=data_type)
+
+ if paddle.__version__ >= '2.0.0':
+ out = paddle.nn.functional.elu(node_x, alpha, name='elu')
+ else:
+ out = paddle.fluid.layers.elu(node_x, alpha, name='elu')
+ cpu = paddle.static.cpu_places(1)
+ exe = paddle.static.Executor(cpu[0])
+ # startup program will call initializer to initialize the parameters.
+ exe.run(paddle.static.default_startup_program())
+
+ outs = exe.run(
+ feed={'x': x},
+ fetch_list=[out])
+
+ saveModel(name, exe, feed_vars=[node_x], fetchlist=[out],
+ inputs=[x], outputs=[outs[0]], target_dir=sys.argv[1])
+
+ return outs[0]
+
+
+def main():
+ data_type = 'float32'
+ data = np.random.randn(2, 3, 4).astype('float32')
+ elu("elu", data)
+
+if __name__ == "__main__":
+ main()
diff --git a/src/frontends/paddle/tests/test_models/gen_scripts/generate_eye.py b/src/frontends/paddle/tests/test_models/gen_scripts/generate_eye.py
new file mode 100644
index 0000000000..9b1a4f668c
--- /dev/null
+++ b/src/frontends/paddle/tests/test_models/gen_scripts/generate_eye.py
@@ -0,0 +1,41 @@
+# Copyright (C) 2018-2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+#
+# fill_const paddle model generator
+#
+import numpy as np
+from save_model import saveModel
+import paddle
+import sys
+
+
+def eye(name : str, rows, cols = None, dtype = None):
+ paddle.enable_static()
+ with paddle.static.program_guard(paddle.static.Program(), paddle.static.Program()):
+ if paddle.__version__ >= '2.0.0':
+ x1 = paddle.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill')
+ x2 = paddle.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill')
+ else:
+ x1 = paddle.fluid.layers.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill_constant')
+ x2 = paddle.fluid.layers.eye(num_rows=rows, num_columns=cols, dtype=dtype, name='fill_constant')
+ out = paddle.add(x1, x2)
+ cpu = paddle.static.cpu_places(1)
+ exe = paddle.static.Executor(cpu[0])
+ # startup program will call initializer to initialize the parameters.
+ exe.run(paddle.static.default_startup_program())
+
+ outs = exe.run(
+ fetch_list=[out])
+
+ saveModel(name, exe, feed_vars=[], fetchlist=[out], inputs=[], outputs=[outs[0]], target_dir=sys.argv[1])
+
+ return outs[0]
+
+def main():
+ eye("eye", 3)
+ eye("eye_int32", 2, 3, "int32")
+ eye("eye_int64", 2, 3, "int64")
+
+if __name__ == "__main__":
+ main()

0 comments on commit 0beb71a

Please sign in to comment.