diff --git a/tests/compute/test_kvstore.py b/tests/compute/test_kvstore.py index c350b45721fe..345d08162465 100644 --- a/tests/compute/test_kvstore.py +++ b/tests/compute/test_kvstore.py @@ -68,6 +68,31 @@ def start_client(): assert torch.equal(new_tensor_2, target_tensor) == True + client.push_all('embed_0', client.pull_all('embed_0')) + client.push_all('embed_1', client.pull_all('embed_1')) + client.push_all('embed_2', client.pull_all('embed_2')) + + # Pull + tensor_id = torch.tensor([0, 1, 2, 6, 7, 8]) + new_tensor_0 = client.pull('embed_0', tensor_id) + new_tensor_1 = client.pull('embed_1', tensor_id) + new_tensor_2 = client.pull('embed_2', tensor_id) + + target_tensor = torch.tensor( + [[ 0., 0., 0.], + [ 10., 10., 10.], + [20., 20., 20.], + [ 0., 0., 0.], + [ 10., 10., 10.], + [20., 20., 20.]]) + + assert torch.equal(new_tensor_0, target_tensor) == True + assert torch.equal(new_tensor_1, target_tensor) == True + + target_tensor = tensor.tensor([20., 20., 20., 30., 30., 30.]) + + assert torch.equal(new_tensor_2, target_tensor) == True + client.shut_down() if __name__ == '__main__':