Skip to content

Commit

Permalink
Implemented clone() and chunk() in TensorBase, for issues OpenMined#31
Browse files Browse the repository at this point in the history
  • Loading branch information
dipanshunagar committed Sep 6, 2017
1 parent 45e64ea commit 3db72da
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
16 changes: 16 additions & 0 deletions syft/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,22 @@ def clamp_(self, minimum=None, maximum=None):
self.data = np.clip(self.data, a_min=minimum, a_max=maximum)
return self

def clone(self):
"""Returns a copy of the tensor. The copy has the same size and data type as the original tensor."""
if self.encrypted:
return NotImplemented
return TensorBase(np.copy(self.data))

def chunk(self, n, dim=0, same_size=False):
"""Returns a list of tensors by splitting the tensor into a number of chunks along a given dimension.
Raises an exception if same_size is set to True and given tensor can't be split in n same-size chunks along dim."""
if self.encrypted:
return NotImplemented
if same_size:
return [TensorBase(x) for x in np.split(self.data, n, dim)]
else:
return [TensorBase(x) for x in np.array_split(self.data, n, dim)]

def bernoulli(self, p):
"""
Returns a Tensor filled with binary random numbers (0 or 1) from a bernoulli distribution
Expand Down
20 changes: 20 additions & 0 deletions tests/test_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,26 @@ def testClampFloatInPlace(self):
self.assertEqual(t1, expected_tensor)


class cloneTests(unittest.TestCase):
def testClone(self):
t1 = TensorBase(np.random.randint(0, 10, size=(5, 10)))
t2 = t1.clone()
self.assertEqual(t1, t2)
self.assertIsNot(t1, t2)


class chunkTests(unittest.TestCase):
def testChunk(self):
t1 = TensorBase(np.random.randint(0, 10, size=(5, 10)))
t2, t3 = t1.chunk(2, 0)
self.assertNotEqual(t2.shape(), t3.shape())

def testChunkSameSize(self):
t1 = TensorBase(np.random.randint(0, 10, size=(4, 10)))
t2, t3 = t1.chunk(2, 0, same_size=True)
self.assertEqual(t2.shape(), t3.shape())


class bernoulliTests(unittest.TestCase):
def testBernoulli(self):
p = TensorBase(np.random.uniform(size=(3, 2)))
Expand Down

0 comments on commit 3db72da

Please sign in to comment.