From a0b895c08d3a9312ac1f2fde55f9d99a3a1e6b5d Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Fri, 26 Nov 2021 18:02:52 +0800 Subject: [PATCH] [Pten]Support parse kernel key by multi-inputs (#37517) * Support parse kernel key by multi-inputs * optimize code according to reviewer --- paddle/pten/api/lib/data_type_set.h | 85 +++++++++++++++++++++++++++ paddle/pten/api/lib/kernel_dispatch.h | 9 +++ paddle/pten/api/lib/math.cc | 8 +-- 3 files changed, 98 insertions(+), 4 deletions(-) create mode 100644 paddle/pten/api/lib/data_type_set.h diff --git a/paddle/pten/api/lib/data_type_set.h b/paddle/pten/api/lib/data_type_set.h new file mode 100644 index 00000000000000..cbb95e8eed0c78 --- /dev/null +++ b/paddle/pten/api/lib/data_type_set.h @@ -0,0 +1,85 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/pten/api/ext/exception.h" +#include "paddle/pten/common/data_type.h" +namespace paddle { +namespace experimental { + +/* This class is used to store DataType in a bit set*/ +class DataTypeSet final { + public: + constexpr DataTypeSet() : bitset_(0) {} + explicit constexpr DataTypeSet(DataType dtype) + : bitset_(dtype == DataType::UNDEFINED + ? 0 + : 1ULL << (static_cast(dtype) - 1)) {} + + uint64_t bitset() const { return bitset_; } + + bool inline Has(DataType dtype) const { + PD_CHECK(dtype != DataType::UNDEFINED, + "Data type argument can't be UNDEFINED."); + return static_cast(bitset_ & DataTypeSet(dtype).bitset()); + } + bool IsEmpty() const { return bitset_ == 0; } + + DataTypeSet operator|(const DataTypeSet& other) const { + return DataTypeSet(bitset_ | other.bitset()); + } + DataTypeSet operator&(const DataTypeSet& other) const { + return DataTypeSet(bitset_ & other.bitset()); + } + DataTypeSet operator-(const DataTypeSet& other) const { + return DataTypeSet(bitset_ & ~other.bitset()); + } + DataTypeSet operator^(const DataTypeSet& other) const { + return DataTypeSet(bitset_ ^ other.bitset()); + } + + bool operator==(const DataTypeSet& other) const { + return bitset_ == other.bitset(); + } + + private: + constexpr DataTypeSet(uint64_t bitset) : bitset_(bitset) {} + uint64_t bitset_; +}; + +// Now only supports promotion of complex type +inline DataType PromoteTypes(const DataTypeSet& dtype_set) { + constexpr auto f8 = 1ULL << (static_cast(DataType::FLOAT64) - 1); + constexpr auto c4 = 1ULL << (static_cast(DataType::COMPLEX64) - 1); + constexpr auto c8 = 1ULL << (static_cast(DataType::COMPLEX128) - 1); + DataType promote_type = DataType::UNDEFINED; + + // kernel dtype need promote when meet input dtype with more precision + if ((dtype_set.bitset() & c8) == c8) { + promote_type = DataType::COMPLEX128; + } else if ((dtype_set.bitset() & c4) == c4) { + if ((dtype_set.bitset() & f8) == f8) { + promote_type = DataType::COMPLEX128; + } else { + promote_type = DataType::COMPLEX64; + } + } + return promote_type; +} + +} // namespace experimental +} // namespace paddle diff --git a/paddle/pten/api/lib/kernel_dispatch.h b/paddle/pten/api/lib/kernel_dispatch.h index 46aa9ce992939a..2dba88d07eb127 100644 --- a/paddle/pten/api/lib/kernel_dispatch.h +++ b/paddle/pten/api/lib/kernel_dispatch.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/pten/api/include/tensor.h" #include "paddle/pten/api/lib/backend_set.h" +#include "paddle/pten/api/lib/data_type_set.h" #include "paddle/pten/common/data_type.h" #include "paddle/pten/common/layout.h" @@ -88,6 +89,9 @@ struct ArgsIterator { struct KernelKeyParser : ArgsIterator { KernelKeySet key_set; + // this dtype_set is used for cache multi-inputs dtype and used for + // data_promote + DataTypeSet dtype_set{DataType::UNDEFINED}; // TODO(chenweihang): deal with multiple diff input Tensors // TODO(chenweihang): add global device guard method to set backend @@ -96,6 +100,11 @@ struct KernelKeyParser : ArgsIterator { // TODO(chenweihang): selecte multi layout and dtype key_set.layout = x.layout(); key_set.dtype = x.type(); + dtype_set = dtype_set | DataTypeSet(x.dtype()); + auto promote_result = PromoteTypes(dtype_set); + if (promote_result != DataType::UNDEFINED) { + key_set.dtype = promote_result; + } } void operator()(const std::vector& x) { diff --git a/paddle/pten/api/lib/math.cc b/paddle/pten/api/lib/math.cc index 30d6ab2aa8a8a9..b030c60750c873 100644 --- a/paddle/pten/api/lib/math.cc +++ b/paddle/pten/api/lib/math.cc @@ -70,7 +70,7 @@ PD_DLL_DECL Tensor mean(const Tensor& x) { PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) { // 1. Get kernel signature and kernel - auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key_set = ParseKernelKeyByInputArgs(x, y); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( "elementwise_add", kernel_key); @@ -105,7 +105,7 @@ PD_DLL_DECL Tensor add(const Tensor& x, const Tensor& y) { PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) { // 1. Get kernel signature and kernel - auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key_set = ParseKernelKeyByInputArgs(x, y); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( "elementwise_sub", kernel_key); @@ -140,7 +140,7 @@ PD_DLL_DECL Tensor subtract(const Tensor& x, const Tensor& y) { PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) { // 1. Get kernel signature and kernel - auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key_set = ParseKernelKeyByInputArgs(x, y); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( "elementwise_div", kernel_key); @@ -175,7 +175,7 @@ PD_DLL_DECL Tensor divide(const Tensor& x, const Tensor& y) { PD_DLL_DECL Tensor multiply(const Tensor& x, const Tensor& y) { // 1. Get kernel signature and kernel - auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key_set = ParseKernelKeyByInputArgs(x, y); auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( "elementwise_mul", kernel_key);