Skip to content

Commit

Permalink
[fbgemm_gpu] Re-enable cache tests for ROCm
Browse files Browse the repository at this point in the history
- Re-enable UVM cache tests for ROCm
  • Loading branch information
q10 committed Nov 15, 2024
1 parent 2da2b7a commit b1ee2b5
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
16 changes: 8 additions & 8 deletions .github/scripts/fbgemm_gpu_test.bash
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,17 @@ __determine_test_directories () {

if [ "$fbgemm_gpu_variant" != "genai" ]; then
target_directories+=(
fbgemm_gpu/test
fbgemm_gpu/test/tbe/cache
)
fi

if [ "$fbgemm_gpu_variant" == "cuda" ] || [ "$fbgemm_gpu_variant" == "genai" ]; then
target_directories+=(
fbgemm_gpu/experimental/example/test
fbgemm_gpu/experimental/gemm/test
fbgemm_gpu/experimental/gen_ai/test
)
fi
# if [ "$fbgemm_gpu_variant" == "cuda" ] || [ "$fbgemm_gpu_variant" == "genai" ]; then
# target_directories+=(
# fbgemm_gpu/experimental/example/test
# fbgemm_gpu/experimental/gemm/test
# fbgemm_gpu/experimental/gen_ai/test
# )
# fi

echo "[TEST] Determined the testing directories:"
for test_dir in "${target_directories[@]}"; do
Expand Down
2 changes: 2 additions & 0 deletions fbgemm_gpu/test/tbe/cache/cache_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,8 @@ def generate_cache_tbes(
def assert_cache(
tensor_a: torch.Tensor, tensor_b: torch.Tensor, stochastic_rounding: bool
) -> None:
print(f"\n\ntensor_a: {tensor_a.shape}")
print(f"\n\ntensor_b: {tensor_b.shape}")
if stochastic_rounding:
# Stochastic rounding randomly alters the mantissa bits during the
# FP32->FP16 conversion in TBE backward, resulting in non-deterministic
Expand Down
8 changes: 4 additions & 4 deletions fbgemm_gpu/test/tbe/cache/cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _compute_grad_output_shape(

@optests.dontGenerateOpCheckTests("Serial OOM")
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
# @unittest.skipIf(*running_on_rocm)
@given(
T=st.integers(min_value=1, max_value=5),
D=st.integers(min_value=2, max_value=256),
Expand Down Expand Up @@ -443,7 +443,7 @@ def assert_event_not_exist(event_name: str) -> None:

@optests.dontGenerateOpCheckTests("Serial OOM")
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
# @unittest.skipIf(*running_on_rocm)
@given(
T=st.integers(min_value=1, max_value=5),
D=st.integers(min_value=2, max_value=256),
Expand Down Expand Up @@ -471,7 +471,7 @@ def test_cache_prefetch_pipeline(

@optests.dontGenerateOpCheckTests("Serial OOM")
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
# @unittest.skipIf(*running_on_rocm)
@given(
T=st.integers(min_value=1, max_value=5),
D=st.integers(min_value=2, max_value=256),
Expand Down Expand Up @@ -500,7 +500,7 @@ def test_cache_prefetch_pipeline_stream_1(

@optests.dontGenerateOpCheckTests("Serial OOM")
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
# @unittest.skipIf(*running_on_rocm)
@given(
T=st.integers(min_value=1, max_value=5),
D=st.integers(min_value=2, max_value=256),
Expand Down

0 comments on commit b1ee2b5

Please sign in to comment.