Skip to content

Commit

Permalink
Don't call avx512 implementation if we didn't build it (pytorch#1930)
Browse files Browse the repository at this point in the history
Summary:
You can't use `__AVX512F__` to guard the call because EmbeddingSpMDM
isn't actually built with -mavx512f.  The guard macro is deliberately the
negative form, NO_AVX512, so that builds which don't define it keep the
AVX512 path and I don't have to modify the fbcode build (where we assume
we have AVX512 compiler support).

Signed-off-by: Edward Z. Yang <[email protected]>

Pull Request resolved: pytorch#1930

Reviewed By: brad-mengchi, excelle08, r-barnes, shintaro-iwasaki

Differential Revision: D48243887

Pulled By: ezyang

fbshipit-source-id: d3b6ca16f7eb44966ce810ab35b10164ef8bebb5
  • Loading branch information
ezyang authored and facebook-github-bot committed Aug 11, 2023
1 parent 7e31f39 commit 1b2746f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
2 changes: 2 additions & 0 deletions fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ if(NOT USE_ROCM AND CXX_AVX512_FOUND)
${fbgemm_sources}
${fbgemm_sources_avx2}
${fbgemm_sources_avx512})
else()
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNO_AVX512=1")
endif()

set(fbgemm_sources_include_directories
Expand Down
29 changes: 16 additions & 13 deletions src/EmbeddingSpMDM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,7 @@ void compressed_indices_remap(
}

#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
#ifndef NO_AVX512
const inst_set_t isa = fbgemmInstructionSet();
if (isZmm(isa)) {
#ifndef __HIP_PLATFORM_HCC__
Expand All @@ -1604,6 +1605,7 @@ void compressed_indices_remap(
out_indices,
out_offsets,
out_weights);
return;
} else {
internal::compressed_indices_remap_avx512<IndexType, true>(
offsets_len,
Expand All @@ -1614,22 +1616,23 @@ void compressed_indices_remap(
out_indices,
out_offsets,
out_weights);
return;
}
#endif // __HIP_PLATFORM_HCC__
} else {
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
compressed_indices_remap_ref<IndexType>(
offsets_len,
indices,
compressed_indices_mapping,
offsets,
weights,
out_indices,
out_offsets,
out_weights);
#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64
}
#endif
#endif // NO_AVX512
#endif // CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64

// Non-vectorized fallback implementation
compressed_indices_remap_ref<IndexType>(
offsets_len,
indices,
compressed_indices_mapping,
offsets,
weights,
out_indices,
out_offsets,
out_weights);
}

#define INSTANTIATE_REMAP_BASE(INDEX_TYPE) \
Expand Down

0 comments on commit 1b2746f

Please sign in to comment.