Skip to content

Commit

Permalink
Fix CopyVectorToNDArray in src/c_api_common.h (dmlc#3597)
Browse files Browse the repository at this point in the history
* fix CopyVectorToNDArray

* Fix lint

Co-authored-by: Jinjing Zhou <[email protected]>
  • Loading branch information
hirayaku and VoVAllen authored Dec 19, 2021
1 parent a62f2c1 commit 25538ba
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/c_api_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,18 @@ dgl::runtime::PackedFunc ConvertNDArrayVectorToPackedFunc(
const std::vector<dgl::runtime::NDArray>& vec);

/*!
* \brief Copy a vector to an int64_t NDArray.
* \brief Copy a vector to an NDArray.
*
* The element type of the vector must be convertible to int64_t.
* The data type of the NDArray will be IdType, which must be an integer type.
* The element type (DType) of the vector must be convertible to IdType.
*/
template<typename IdType, typename DType>
dgl::runtime::NDArray CopyVectorToNDArray(
const std::vector<DType>& vec) {
using dgl::runtime::NDArray;
const int64_t len = vec.size();
NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, sizeof(IdType), 1}, DLContext{kDLCPU, 0});
NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, sizeof(IdType) * 8, 1},
DLContext{kDLCPU, 0});
std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));
return a;
}
Expand Down

0 comments on commit 25538ba

Please sign in to comment.