Skip to content

Commit

Permalink
Split folds when multiple domains helper (facebookresearch#3676)
Browse files Browse the repository at this point in the history
* Split folds when multiple domains helper

We've got some use cases where we've got a single dataset that has different sub-components that can be swapped in + out. It's a little subtle how we make sure we're getting the same samples out of every domain (namely, we can't concatinate them all together), so make a helper function for it.

Test Plan:
Print out outputs in a dataset that use this, with a varying # of domains but one fixed domain. Verify that the same lines show up for the fixed domain.

* address feedback

* minor - update total # of samples
  • Loading branch information
moyapchen authored Jun 2, 2021
1 parent d07d2dc commit 4be7a23
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 0 deletions.
35 changes: 35 additions & 0 deletions parlai/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
"""
Utilities related to handling data.
"""
import random
from typing import List


class DatatypeHelper:
Expand Down Expand Up @@ -91,3 +93,36 @@ def is_streaming(cls, datatype: str) -> bool:
bool indicating whether we are streaming
"""
return 'stream' in datatype

@classmethod
def split_domains_by_fold(
cls,
fold: str,
domains: List[List],
train_frac: float,
valid_frac: float,
test_frac: float,
seed: int = 42,
):
"""
Need to be careful about how we setup random to not leak examples between trains
if we're in a scenario where a single dataset has different ways of mixing +
matching subcomponents.
"""
assert train_frac + valid_frac + test_frac == 1
if "train" in fold:
start = 0.0
end = train_frac
elif "valid" in fold:
start = train_frac
end = train_frac + valid_frac
else:
start = train_frac + valid_frac
end = 1.0

result = []
for domain in domains:
random.Random(seed).shuffle(domain)
result.extend(domain[int(start * len(domain)) : int(end * len(domain))])
random.Random(seed).shuffle(result)
return result
48 changes: 48 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from parlai.utils.misc import Timer, round_sigfigs, set_namedtuple_defaults, nice_report
import parlai.utils.strings as string_utils
from copy import deepcopy
import random
import time
import unittest
from parlai.utils.data import DatatypeHelper
Expand Down Expand Up @@ -191,6 +192,53 @@ def test_is_training(self):
assert DatatypeHelper.is_training("test") is False
assert DatatypeHelper.is_training("test:stream") is False

def test_split_domains_by_fold(self):
TOTAL_LEN = random.randint(100, 200)
a_end = random.randrange(1, TOTAL_LEN)
b_end = random.randrange(a_end, TOTAL_LEN)
DOMAIN_A = [i for i in range(0, a_end)]
DOMAIN_B = [i for i in range(a_end, b_end)]
DOMAIN_C = [i for i in range(b_end, TOTAL_LEN)]

DOMAINS_A = [deepcopy(DOMAIN_A)]
DOMAINS_A_B = [deepcopy(DOMAIN_A), deepcopy(DOMAIN_B)]
DOMAINS_C_B_A = [deepcopy(DOMAIN_C), deepcopy(DOMAIN_B), deepcopy(DOMAIN_A)]

train_frac = random.uniform(0, 1)
valid_frac = random.uniform(0, 1 - train_frac)
test_frac = 1 - train_frac - valid_frac

TRAIN_A = DatatypeHelper.split_domains_by_fold(
"train", DOMAINS_A, train_frac, valid_frac, test_frac
)
TRAIN_A_B = DatatypeHelper.split_domains_by_fold(
"train", DOMAINS_A_B, train_frac, valid_frac, test_frac
)
TRAIN_C_B_A = DatatypeHelper.split_domains_by_fold(
"train", deepcopy(DOMAINS_C_B_A), train_frac, valid_frac, test_frac
)

# Check to make sure selected values for a fold within a domain are consistent even if different domains are used, and presented in different orders
for val in DOMAIN_A:
state = bool(val in TRAIN_A)
assert bool(val in TRAIN_A_B) == state
assert bool(val in TRAIN_C_B_A) == state

for val in DOMAIN_B:
state = bool(val in TRAIN_A_B)
assert bool(val in TRAIN_C_B_A) == state

# Check that train + valid + test covers everything
VALID_C_B_A = DatatypeHelper.split_domains_by_fold(
"valid", deepcopy(DOMAINS_C_B_A), train_frac, valid_frac, test_frac
)
TEST_C_B_A = DatatypeHelper.split_domains_by_fold(
"test", deepcopy(DOMAINS_C_B_A), train_frac, valid_frac, test_frac
)

assert len(TRAIN_C_B_A) + len(VALID_C_B_A) + len(TEST_C_B_A) is TOTAL_LEN
assert len(set(TRAIN_C_B_A + VALID_C_B_A + TEST_C_B_A)) is TOTAL_LEN


if __name__ == '__main__':
unittest.main()

0 comments on commit 4be7a23

Please sign in to comment.