Skip to content

Commit

Permalink
Use the new asmjit to fix //deeplearning/fbgemm (pytorch#1077)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1077

Update asmjit dependencies within fbgemm

Reviewed By: jianyuh

Differential Revision: D35292923

fbshipit-source-id: d854a7d7772105085dad62d5e773dbe88816d78a
  • Loading branch information
r-barnes authored and facebook-github-bot committed Jul 8, 2022
1 parent cd9dceb commit 64a5c4a
Show file tree
Hide file tree
Showing 18 changed files with 72 additions and 56 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ if(NOT TARGET asmjit)

add_subdirectory("${ASMJIT_SRC_DIR}" "${FBGEMM_BINARY_DIR}/asmjit")
set_property(TARGET asmjit PROPERTY POSITION_INDEPENDENT_CODE ON)
# add a flag required for mac build
if(NOT MSVC)
target_compile_options(asmjit PRIVATE "-Wno-sign-conversion")
endif()
endif()

if(NOT TARGET cpuinfo)
Expand Down
2 changes: 1 addition & 1 deletion fbgemm_gpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ set_source_files_properties(
${gen_cpu_source_files}
PROPERTIES
INCLUDE_DIRECTORIES
"${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include"
"${CMAKE_CURRENT_SOURCE_DIR};${CMAKE_CURRENT_SOURCE_DIR}/include;${CMAKE_CURRENT_SOURCE_DIR}/../include;${THIRDPARTY}/asmjit/src"
)

set_source_files_properties(
Expand Down
24 changes: 14 additions & 10 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#include "./FbgemmBuild.h"
#include "./UtilsAvx2.h"

#include <algorithm>
#include <array>
#include <cmath>
#include <string>
#include <type_traits>
#include "./FbgemmBuild.h"
#include "./UtilsAvx2.h"

// forward declarations to asmjit
namespace asmjit {
namespace x86 {
class Xmm;
class Ymm;
class Zmm;
} // namespace x86
} // namespace asmjit
#include <asmjit/asmjit.h>

// // forward declarations to asmjit
// namespace asmjit {
// namespace x86 {
// class Xmm;
// class Ymm;
// class Zmm;
// } // namespace x86
// } // namespace asmjit

namespace fbgemm {

Expand Down
10 changes: 5 additions & 5 deletions src/EmbeddingSpMDM.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ GenEmbeddingSpMDMLookup<
const float*, // weights
outType*, // out
const int32_t*, // compressed_indices_table and then mask
const int*>(asmjit::CallConv::kIdHost),
const int*>(asmjit::CallConvId::kHost),
a->environment());
} else {
func.init(
Expand All @@ -327,7 +327,7 @@ GenEmbeddingSpMDMLookup<
const offsetType*, // offsets or lengths
const float*, // weights
outType*, // out and then mask
const int*>(asmjit::CallConv::kIdHost),
const int*>(asmjit::CallConvId::kHost),
a->environment());
}

Expand All @@ -336,20 +336,20 @@ GenEmbeddingSpMDMLookup<

if (instSet == inst_set_t::avx2) {
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
} else {
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
}

frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
reg_id == 15
? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)
: asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
Expand Down
8 changes: 4 additions & 4 deletions src/EmbeddingSpMDMNBit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ GenEmbeddingSpMDMNBitLookup<
const float*, // weights
float*, // out
const int32_t* /* compressed_indices_table */,
const int* /* mask */>(asmjit::CallConv::kIdHost),
const int* /* mask */>(asmjit::CallConvId::kHost),
a->environment());
} else {
func.init(
Expand All @@ -313,22 +313,22 @@ GenEmbeddingSpMDMNBitLookup<
const offsetType*, // offsets or lengths
const float*, // weights
float*, // out
const int* /* mask */>(asmjit::CallConv::kIdHost),
const int* /* mask */>(asmjit::CallConvId::kHost),
a->environment());
}

asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));

frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
reg_id == 15
? asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)
: asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
Expand Down
6 changes: 3 additions & 3 deletions src/FbgemmI64.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,20 @@ CodeGenBase<int64_t, int64_t, int64_t, int64_t>::getOrCreate(
int64_t*,
int64_t*,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
6 changes: 3 additions & 3 deletions src/GenerateI8Depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,18 +271,18 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
int,
const int*,
int,
const std::int32_t*>(asmjit::CallConv::kIdHost),
const std::int32_t*>(asmjit::CallConvId::kHost),
e->environment());

asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
12 changes: 6 additions & 6 deletions src/GenerateKernelDirectConvU8S8S32ACC32.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreateDirectConv(
int8_t*,
int32_t*,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
Expand All @@ -246,9 +246,9 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreateDirectConv(
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
}

frame.setDirtyRegs(x86::Reg::kGroupVec, dirtyVecRegs);
frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down Expand Up @@ -660,7 +660,7 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
int,
int,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
Expand All @@ -673,9 +673,9 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
}

frame.setDirtyRegs(x86::Reg::kGroupVec, dirtyVecRegs);
frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
7 changes: 4 additions & 3 deletions src/GenerateKernelU8S8S32ACC16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -189,17 +189,18 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
int8_t*,
int32_t*,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
frame.init(func);
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
frame.setDirtyRegs(
x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

asmjit::FuncArgsAssignment args(&func);
args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg);
Expand Down
6 changes: 3 additions & 3 deletions src/GenerateKernelU8S8S32ACC16Avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,20 +151,20 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate(
int8_t*,
int32_t*,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
6 changes: 3 additions & 3 deletions src/GenerateKernelU8S8S32ACC32.cc
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
int8_t*,
int32_t*,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
Expand All @@ -201,9 +201,9 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31);
}

frame.setDirtyRegs(x86::Reg::kGroupVec, dirtyVecRegs);
frame.setDirtyRegs(asmjit::RegGroup::kVec, dirtyVecRegs);
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
6 changes: 3 additions & 3 deletions src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,20 +142,20 @@ CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreate(
int8_t*,
int32_t*,
int,
int>(asmjit::CallConv::kIdHost),
int>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
frame.init(func);

frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
6 changes: 3 additions & 3 deletions src/GroupwiseConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -224,17 +224,17 @@ jit_conv_kernel_fp GenConvKernel<SPATIAL_DIM, INST_SET>::getOrCreate() {
int32_t,
int32_t,
int32_t,
int32_t*>(asmjit::CallConv::kIdHost),
int32_t*>(asmjit::CallConvId::kHost),
a->environment());

frame_.init(func_);

frame_.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
frame_.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func_);
Expand Down
8 changes: 4 additions & 4 deletions src/RowWiseSparseAdagradFused.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,28 +177,28 @@ typename ReturnFunctionSignature<indxType, offsetType, dataType>::
const int*, // lengths
float, // epsilon
float, // lr then rand_buffer
uint32_t*>(asmjit::CallConv::kIdHost),
uint32_t*>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
frame.init(func);

if (instSet == inst_set_t::avx2) {
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
} else {
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
}

frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14));

asmjit::FuncArgsAssignment args(&func);
Expand Down
8 changes: 4 additions & 4 deletions src/SparseAdagrad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,28 +510,28 @@ GenSparseAdagrad<indxType, instSet>::getOrCreate(
const int*, // mask_avx2
float, // weight_decay
const double*, // counter then counter_halflife
std::int64_t>(asmjit::CallConv::kIdHost),
std::int64_t>(asmjit::CallConvId::kHost),
a->environment());

asmjit::FuncFrame frame;
frame.init(func);

if (instSet == inst_set_t::avx2) {
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));
} else {
frame.setDirtyRegs(
x86::Reg::kGroupVec,
asmjit::RegGroup::kVec,
asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) |
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15) |
asmjit::Support::bitMask(16, 17, 18, 19, 20, 21, 22, 23) |
asmjit::Support::bitMask(24, 25, 26, 27, 28, 29, 30, 31));
}

frame.setDirtyRegs(
x86::Reg::kGroupGp,
asmjit::RegGroup::kGp,
asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15));

asmjit::FuncArgsAssignment args(&func);
Expand Down
Loading

0 comments on commit 64a5c4a

Please sign in to comment.