Skip to content

Commit

Permalink
Add Tensor::isLocked (flashlight#729)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: flashlight#729

See title - returns whether or not a tensor is memory-locked (a pointer to its device memory is "active"/unlocked via calling `unlock()`).

Reviewed By: benoitsteiner

Differential Revision: D30524152

fbshipit-source-id: 865d42e0261abcaf0d95185752d91adb97a18c91
  • Loading branch information
jacobkahn authored and facebook-github-bot committed Aug 26, 2021
1 parent 66a50de commit 91c8c90
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 0 deletions.
8 changes: 8 additions & 0 deletions flashlight/fl/tensor/TensorAdapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,14 @@ class TensorAdapterBase {
*/
virtual void unlock() = 0;

/**
* Returns true if the tensor has been memory-locked per a call to
* Tensor::device<T>().
*
* @return true if the tensor is locked and a device pointer is active.
*/
virtual bool isLocked() = 0;

/**
* Returns a bool based on Tensor contiguousness in memory.
*/
Expand Down
4 changes: 4 additions & 0 deletions flashlight/fl/tensor/TensorBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ void Tensor::unlock() const {
impl_->unlock();
}

bool Tensor::isLocked() const {
return impl_->isLocked();
}

bool Tensor::isContiguous() const {
return impl_->isContiguous();
}
Expand Down
9 changes: 9 additions & 0 deletions flashlight/fl/tensor/TensorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,15 @@ class Tensor {
*/
void unlock() const;

/**
* Returns true if the tensor has been memory-locked per a call to
* Tensor::device<T>(). After unlocking via Tensor::unlock(), the tensor is no
* longer locked.
*
* @return true if the tensor is locked and a device pointer is active.
*/
bool isLocked() const;

/**
* Returns if the Tensor is contiguous in its memory-based representation.
*
Expand Down
11 changes: 11 additions & 0 deletions flashlight/fl/tensor/backend/af/ArrayFireTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,17 @@ void ArrayFireTensor::unlock() {
AF_CHECK(af_unlock_array(getHandle().get()));
}

bool ArrayFireTensor::isLocked() {
bool res;
auto err = af_is_locked_array(&res, getHandle().get());
if (err != AF_SUCCESS) {
throw std::runtime_error(
"ArrayFireTensor::isLocked - af_is_locked_array returned error: " +
std::to_string(err));
}
return res;
}

bool ArrayFireTensor::isContiguous() {
return af::isLinear(getHandle());
}
Expand Down
1 change: 1 addition & 0 deletions flashlight/fl/tensor/backend/af/ArrayFireTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ class ArrayFireTensor : public TensorAdapterBase {
void device(void** out) override;
void host(void** out) override;
void unlock() override;
bool isLocked() override;
bool isContiguous() override;
Shape strides() override;
Tensor astype(const dtype type) override;
Expand Down
1 change: 1 addition & 0 deletions flashlight/fl/test/tensor/TensorBaseTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ TEST(TensorBaseTest, Metadata) {
Tensor e;
ASSERT_TRUE(e.isEmpty());
ASSERT_FALSE(e.isSparse());
ASSERT_FALSE(e.isLocked());
}

TEST(TensorBaseTest, ostream) {
Expand Down

0 comments on commit 91c8c90

Please sign in to comment.