
Commit 30c6954

Improved data getters for situations where you just want one batch of data

Summary: random binary data no longer needs to be squeezed when you only want one batch (a single update) or a single task.
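For illustration (not part of the original commit message), a minimal sketch of the new behavior, with shapes taken from the updated docstrings and tests below:

from torcheval.utils import get_rand_data_binary

# Before this change the single-update/single-task case returned a
# (1, 1, 10) tensor that callers had to squeeze; the singleton update
# and task dimensions are now dropped.
input, targets = get_rand_data_binary(num_updates=1, num_tasks=1, batch_size=10)
print(input.shape)    # torch.Size([10])
print(targets.shape)  # torch.Size([10])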

Reviewed By: ananthsub

Differential Revision: D45033694

fbshipit-source-id: 991bab03bd9392a1ad3b4cd4b549a0ec7b80769e
bobakfb authored and facebook-github-bot committed Apr 20, 2023
1 parent 59e0cad commit 30c6954
Showing 4 changed files with 151 additions and 27 deletions.
@@ -69,7 +69,7 @@ def test_with_randomized_data_getter(self) -> None:
num_updates = 1

for _ in range(100):
input, target, threshold = rd.get_rand_inputs_binned_binary(
input, target, threshold = rd.get_rand_data_binned_binary(
num_updates, num_tasks, batch_size, num_bins
)
input = input.reshape(shape=(num_tasks, batch_size))
100 changes: 90 additions & 10 deletions tests/utils/test_random_data.py
@@ -9,22 +9,36 @@

import torch

from torcheval.utils import get_rand_data_binary, get_rand_data_multiclass
from torcheval.utils import (
get_rand_data_binary,
get_rand_data_binned_binary,
get_rand_data_multiclass,
)


class RandomDataTest(unittest.TestCase):
class BinaryRandomDataTest(unittest.TestCase):
cuda_avail: bool = torch.cuda.is_available()

def test_get_rand_data_binary(self) -> None:
def test_get_rand_data_binary_shapes(self) -> None:
# multi update/multi-task
input, targets = get_rand_data_binary(num_updates=2, num_tasks=5, batch_size=10)
self.assertEqual(input.size(), targets.size())
self.assertEqual(input.size(), torch.Size([2, 5, 10]))
self.assertEqual(targets.size(), torch.Size([2, 5, 10]))

def test_get_rand_data_multiclass(self) -> None:
input, targets = get_rand_data_multiclass(
num_updates=2, num_classes=5, batch_size=10
)
self.assertEqual(input.size(), torch.Size([2, 10, 5]))
self.assertTrue(torch.all(torch.lt(targets, 5)))
# single update/multi-task
input, targets = get_rand_data_binary(num_updates=1, num_tasks=5, batch_size=10)
self.assertEqual(input.size(), torch.Size([5, 10]))
self.assertEqual(targets.size(), torch.Size([5, 10]))

# single update/single-task
input, targets = get_rand_data_binary(num_updates=1, num_tasks=1, batch_size=10)
self.assertEqual(input.size(), torch.Size([10]))
self.assertEqual(targets.size(), torch.Size([10]))

# multi update/single-task
input, targets = get_rand_data_binary(num_updates=3, num_tasks=1, batch_size=10)
self.assertEqual(input.size(), torch.Size([3, 10]))
self.assertEqual(targets.size(), torch.Size([3, 10]))

@unittest.skipUnless(
condition=cuda_avail, reason="This test needs a GPU host to run."
@@ -38,6 +52,72 @@ def test_get_rand_data_binary_GPU(self) -> None:
self.assertTrue(input.is_cuda)
self.assertTrue(targets.is_cuda)

def test_get_rand_data_binned_binary_shapes(self) -> None:
# multi update/multi-task
input, targets, thresholds = get_rand_data_binned_binary(
num_updates=2, num_tasks=5, batch_size=10, num_bins=20
)
self.assertEqual(input.size(), torch.Size([2, 5, 10]))
self.assertEqual(targets.size(), torch.Size([2, 5, 10]))
self.assertEqual(thresholds.size(), torch.Size([20]))

# single update/multi-task
input, targets, thresholds = get_rand_data_binned_binary(
num_updates=1, num_tasks=5, batch_size=10, num_bins=20
)
self.assertEqual(input.size(), torch.Size([5, 10]))
self.assertEqual(targets.size(), torch.Size([5, 10]))
self.assertEqual(thresholds.size(), torch.Size([20]))

# single update/single-task
input, targets, thresholds = get_rand_data_binned_binary(
num_updates=1, num_tasks=1, batch_size=10, num_bins=20
)
self.assertEqual(input.size(), torch.Size([10]))
self.assertEqual(targets.size(), torch.Size([10]))
self.assertEqual(thresholds.size(), torch.Size([20]))

# multi update/single-task
input, targets, thresholds = get_rand_data_binned_binary(
num_updates=3, num_tasks=1, batch_size=10, num_bins=20
)
self.assertEqual(input.size(), torch.Size([3, 10]))
self.assertEqual(targets.size(), torch.Size([3, 10]))
self.assertEqual(thresholds.size(), torch.Size([20]))

@unittest.skipUnless(
condition=cuda_avail, reason="This test needs a GPU host to run."
)
def test_get_rand_data_binned_binary_GPU(self) -> None:
device = torch.device("cuda")
input, targets, thresholds = get_rand_data_binned_binary(
num_updates=2, num_tasks=5, batch_size=10, num_bins=20, device=device
)
self.assertTrue(input.is_cuda)
self.assertTrue(targets.is_cuda)
self.assertTrue(thresholds.is_cuda)


class MulticlassRandomDataTest(unittest.TestCase):
cuda_avail: bool = torch.cuda.is_available()

def test_get_rand_data_multiclass_shapes(self) -> None:
# multi update
input, targets = get_rand_data_multiclass(
num_updates=2, num_classes=5, batch_size=10
)
self.assertEqual(input.size(), torch.Size([2, 10, 5]))
self.assertEqual(targets.size(), torch.Size([2, 10]))
self.assertTrue(torch.all(torch.lt(targets, 5)))

# single update
input, targets = get_rand_data_multiclass(
num_updates=1, num_classes=5, batch_size=10
)
self.assertEqual(input.size(), torch.Size([10, 5]))
self.assertEqual(targets.size(), torch.Size([10]))
self.assertTrue(torch.all(torch.lt(targets, 5)))

@unittest.skipUnless(
condition=cuda_avail, reason="This test needs a GPU host to run."
)
7 changes: 6 additions & 1 deletion torcheval/utils/__init__.py
@@ -4,9 +4,14 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torcheval.utils.random_data import get_rand_data_binary, get_rand_data_multiclass
from torcheval.utils.random_data import (
get_rand_data_binary,
get_rand_data_binned_binary,
get_rand_data_multiclass,
)

__all__ = [
"get_rand_data_binary",
"get_rand_data_binned_binary",
"get_rand_data_multiclass",
]
69 changes: 54 additions & 15 deletions torcheval/utils/random_data.py
@@ -16,24 +16,37 @@ def get_rand_data_binary(
device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Generates a random binary dataset.
Generates a random binary dataset. For the returned tensors, shape[0] will correspond to the update, shape[1] will correspond to the task, and shape[2] will correspond to the sample.
Notes:
- If num_tasks is 1 the task dimension will be omitted; tensors will have shape (num_updates, batch_size) or (batch_size, ) depending on whether num_updates=1.
- If num_updates is 1, the update dimension will be omitted; tensors will have shape (num_tasks, batch_size) or (batch_size, ) depending on whether num_tasks=1.
- If both num_updates and num_tasks are not 1, the returned tensors will have shape (num_updates, num_tasks, batch_size).
Args:
num_updates: the number of calls to update on each rank.
num_tasks: the number of tasks for the metric.
batch_size: batch size of the dataset.
device: device for the returned Tensors
Returns:
torch.Tensor: random feature data
torch.Tensor: random targets
"""
if device is None:
device = torch.device("cpu")
input = torch.rand(size=[num_updates, num_tasks, batch_size]).to(device)
targets = torch.randint(
low=0, high=2, size=[num_updates, num_tasks, batch_size]
).to(device)
return input, targets

shape = [num_updates, num_tasks, batch_size]
if num_tasks == 1 and num_updates == 1:
shape = [batch_size]
elif num_updates == 1:
shape = [num_tasks, batch_size]
elif num_tasks == 1:
shape = [num_updates, batch_size]

input = torch.rand(size=shape)
targets = torch.randint(low=0, high=2, size=shape)
return input.to(device), targets.to(device)


def get_rand_data_multiclass(
@@ -45,6 +58,9 @@ def get_rand_data_multiclass(
"""
Generates a random multiclass dataset.
Notes:
- If num_updates is 1, the update dimension will be omitted; input tensors will have shape (batch_size, num_classes) and target tensor will have shape (batch_size, ).
Args:
num_updates: the number of calls to update on each rank.
num_classes: the number of classes for the dataset.
@@ -56,31 +72,54 @@
"""
if device is None:
device = torch.device("cpu")
input = torch.rand(size=[num_updates, batch_size, num_classes]).to(device)
targets = torch.randint(low=0, high=num_classes, size=[num_updates, batch_size]).to(
device
)
return input, targets

input_shape = [num_updates, batch_size, num_classes]
targets_shape = [num_updates, batch_size]
if num_updates == 1:
input_shape = [batch_size, num_classes]
targets_shape = [batch_size]

def get_rand_inputs_binned_binary(
num_updates: int, num_tasks: int, batch_size: int, num_bins: int
input = torch.rand(size=input_shape)
targets = torch.randint(low=0, high=num_classes, size=targets_shape)
return input.to(device), targets.to(device)


def get_rand_data_binned_binary(
num_updates: int,
num_tasks: int,
batch_size: int,
num_bins: int,
device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Get random binary dataset, along with a threshold for binned data.
Notes:
- If num_tasks is 1 the task dimension will be omitted; input and target tensors will have shape (num_updates, batch_size) or (batch_size, ) depending on whether num_updates=1.
- If num_updates is 1, the update dimension will be omitted; input and target tensors will have shape (num_tasks, batch_size) or (batch_size, ) depending on whether num_tasks=1.
- If both num_updates and num_tasks are not 1, the returned input and target tensors will have shape (num_updates, num_tasks, batch_size).
- thresholds tensor always has shape (num_bins, ).
Args:
num_updates: the number of calls to update on each rank.
num_tasks: the number of tasks for the metric.
batch_size: batch size of the dataset.
num_bins: The number of bins.
device: device of the returned Tensors
Returns:
torch.Tensor: random feature data
torch.Tensor: random targets
torch.Tensor: thresholds
"""
input, target = get_rand_data_binary(num_updates, num_tasks, batch_size)
if device is None:
device = torch.device("cpu")

input, target = get_rand_data_binary(
num_updates, num_tasks, batch_size, device=device
)

threshold = torch.cat([torch.tensor([0, 1]), torch.rand(num_bins - 2)])
threshold, _ = torch.sort(threshold)
threshold = torch.unique(threshold)
return input, target, threshold
return input, target, threshold.to(device)
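
For reference, a usage sketch of the renamed get_rand_data_binned_binary (illustrative only, not part of the diff; shapes follow the docstring and tests above):

from torcheval.utils import get_rand_data_binned_binary

# Multi-update / multi-task: full (num_updates, num_tasks, batch_size) shapes.
input, target, thresholds = get_rand_data_binned_binary(
    num_updates=2, num_tasks=5, batch_size=10, num_bins=20
)
print(input.shape)       # torch.Size([2, 5, 10])
print(thresholds.shape)  # torch.Size([20])

# Single update / single task: singleton dimensions are omitted, so no squeeze is needed.
input, target, thresholds = get_rand_data_binned_binary(
    num_updates=1, num_tasks=1, batch_size=10, num_bins=20
)
print(input.shape)       # torch.Size([10])
print(target.shape)      # torch.Size([10])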
