Skip to content

Commit

Permalink
[lite] Add util for copying TfLiteTensor.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 407355119
Change-Id: I7c0378f0439608c04f7fd6cb6dbb5a1c5b5f9efb
  • Loading branch information
karimnosseir authored and tensorflower-gardener committed Nov 3, 2021
1 parent cf60b19 commit a527a97
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow/lite/c/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ cc_test(
size = "small",
srcs = ["common_test.cc"],
deps = [
":c_api_types",
":common",
"@com_google_googletest//:gtest_main",
],
Expand Down
18 changes: 18 additions & 0 deletions tensorflow/lite/c/common.c
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,24 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
tensor->quantization.params = NULL;
}

TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst) {
if (!src || !dst)
return kTfLiteOk;
if (src->bytes != dst->bytes)
return kTfLiteError;

dst->type = src->type;
if (dst->dims)
TfLiteIntArrayFree(dst->dims);
dst->dims = TfLiteIntArrayCopy(src->dims);
memcpy(dst->data.raw, src->data.raw, src->bytes);
dst->buffer_handle = src->buffer_handle;
dst->data_is_stale = src->data_is_stale;
dst->delegate = src->delegate;

return kTfLiteOk;
}

void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor) {
if (tensor->allocation_type != kTfLiteDynamic &&
tensor->allocation_type != kTfLitePersistentRo) {
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/lite/c/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,16 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);

// Copies the contents of 'src' in 'dst'.
// Function does nothing if either 'src' or 'dst' is passed as nullptr and
// return kTfLiteOk.
// Returns kTfLiteError if 'src' and 'dst' doesn't have matching data size.
// Note function copies contents, so it won't create new data pointer
// or change allocation type.
// All Tensor related properties will be copied from 'src' to 'dst' like
// quantization, sparsity, ...
TfLiteStatus TfLiteTensorCopy(const TfLiteTensor* src, TfLiteTensor* dst);

// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
Expand Down
59 changes: 59 additions & 0 deletions tensorflow/lite/c/common_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/lite/c/common.h"

#include <gtest/gtest.h>
#include "tensorflow/lite/c/c_api_types.h"

namespace tflite {

Expand Down Expand Up @@ -138,4 +139,62 @@ TEST(Sparsity, TestSparsityFree) {
TfLiteTensorFree(&t);
}

TEST(TensorCopy, TensorCopy_VALID) {
const int kNumElements = 32;
const int kBytes = sizeof(float) * kNumElements;
TfLiteTensor src;
TfLiteTensor dst;
TfLiteDelegate delegate;
memset(&delegate, 0, sizeof(delegate));
memset(&src, 0, sizeof(TfLiteTensor));
memset(&dst, 0, sizeof(TfLiteTensor));
src.data.raw = static_cast<char*>(malloc(kBytes));
for (int i = 0; i < kNumElements; ++i) {
src.data.f[i] = i;
}
dst.data.raw = static_cast<char*>(malloc(kBytes));

src.bytes = dst.bytes = kBytes;
src.delegate = &delegate;
src.data_is_stale = true;
src.allocation_type = kTfLiteDynamic;
src.type = kTfLiteFloat32;
src.dims = TfLiteIntArrayCreate(1);
src.dims->data[0] = 1;
src.dims_signature = TfLiteIntArrayCopy(src.dims);
src.buffer_handle = 5;

EXPECT_EQ(kTfLiteOk, TfLiteTensorCopy(&src, &dst));

EXPECT_EQ(dst.bytes, src.bytes);
EXPECT_EQ(dst.delegate, src.delegate);
EXPECT_EQ(dst.data_is_stale, src.data_is_stale);
EXPECT_EQ(dst.type, src.type);
EXPECT_EQ(1, TfLiteIntArrayEqual(dst.dims, src.dims));
EXPECT_EQ(dst.buffer_handle, src.buffer_handle);
for (int i = 0; i < kNumElements; ++i) {
EXPECT_EQ(dst.data.f[i], src.data.f[i]);
}

TfLiteTensorFree(&src);
// We don't change allocation type, and since the test keeps the dst
// allocation as non dynamic, then we have to delete it manually.
free(dst.data.raw);
TfLiteTensorFree(&dst);
}

TEST(TensorCopy, TensorCopy_INVALID) {
TfLiteTensor src;
TfLiteTensor dst;

// Nullptr passed, should just return.
EXPECT_EQ(kTfLiteOk, TfLiteTensorCopy(&src, nullptr));
EXPECT_EQ(kTfLiteOk, TfLiteTensorCopy(nullptr, &dst));

// Incompatible sizes passed.
src.bytes = 10;
dst.bytes = 12;
EXPECT_EQ(kTfLiteError, TfLiteTensorCopy(&src, &dst));
}

} // namespace tflite

0 comments on commit a527a97

Please sign in to comment.