Skip to content

Commit

Permalink
Bug 1778751 - Implement BFloat16 product instruction. r=jseward
Browse files Browse the repository at this point in the history
  • Loading branch information
yurydelendik committed Oct 11, 2022
1 parent ad795fd commit 92f834f
Show file tree
Hide file tree
Showing 18 changed files with 160 additions and 4 deletions.
1 change: 1 addition & 0 deletions js/src/jit-test/lib/wasm-binary.js
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ const F64x2RelaxedMaxCode = 0x110;
const I16x8RelaxedQ15MulrS = 0x111;
const I16x8DotI8x16I7x16S = 0x112;
const I32x4DotI8x16I7x16AddS = 0x113;
const F32x4RelaxedDotBF16x8AddF32x4 = 0x114;

const FirstInvalidOpcode = 0xc5;
const LastInvalidOpcode = 0xfa;
Expand Down
2 changes: 1 addition & 1 deletion js/src/jit-test/tests/wasm/binary.js
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ if (!wasmSimdEnabled()) {
let reservedSimd = [
0x9a, 0xa2, 0xa5, 0xa6, 0xaf, 0xb0, 0xb2, 0xb3, 0xb4, 0xbb,
0xc2, 0xc5, 0xc6, 0xcf, 0xd0, 0xd2, 0xd3, 0xd4, 0xe2, 0xee,
0x114, 0x115, 0x116, 0x117,
0x115, 0x116, 0x117,
0x118, 0x119, 0x11a, 0x11b, 0x11c, 0x11d, 0x11e, 0x11f,
0x120, 0x121, 0x122, 0x123, 0x124, 0x125, 0x126, 0x127,
0x128, 0x129, 0x12a, 0x12b, 0x12c, 0x12d, 0x12e, 0x12f,
Expand Down
63 changes: 63 additions & 0 deletions js/src/jit-test/tests/wasm/simd/experimental.js
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,69 @@ for ( let [opcode, xs, ys, as, operator] of [[F32x4RelaxedFmaCode, fxs, fys, fas
SimdPrefix, varU32(opcode)])]})])])));
}

// BFloat16 Dot Product, https://github.com/WebAssembly/relaxed-simd/issues/77
// Converts an array of numbers into BFloat16 bit patterns: each value is first
// narrowed to a 32-bit float, then only the upper 16 bits of that float's
// encoding are kept (truncation, no rounding). Returns a Uint16Array with one
// entry per input element.
function createBF16Array(arr) {
  // View the float32 encodings as raw 32-bit words.
  const words = new Uint32Array(new Float32Array(arr).buffer);
  // Sign, exponent and the top mantissa bits live in the high half-word.
  return Uint16Array.from(words, (word) => word >>> 16);
}

// Asserts that two arrays of 32-bit floats are almost the same, taking into
// account that the numbers in the `actual` array were produced by BF16
// operations -- the low 16 bits of each input's float32 encoding were ignored.
//
// Both arguments are converted to Float32Array first, so they may be plain
// arrays or typed arrays. Fails (via assertEq) if the lengths differ or if any
// pair of elements is too far apart.
function assertAlmostSameAsBF16(actual, expected) {
  const a = new Int32Array(new Float32Array(actual).buffer);
  const b = new Int32Array(new Float32Array(expected).buffer);
  // The loop below only walks `a`; without this check surplus `expected`
  // entries would never be compared and an empty result would pass vacuously.
  assertEq(a.length, b.length,
           `length mismatch: actual has ${a.length}, expected has ${b.length}`);
  for (let i = 0; i < a.length; i++) {
    // Comparing IEEE 754 32-bit numbers as integer numbers. Only the
    // low/middle portion of the mantissa may differ. If the `a[i]` and `b[i]`
    // floats are almost the same, their sign, exponent, and high mantissa
    // bits are equal, so the difference of their bit patterns is an integer
    // that measures the error in the low mantissa bits. Choosing 0x40000 as
    // the max acceptable difference -- 16 low bits are discarded in BF16
    // calculations, plus 2 additional bits for rounding errors.
    assertEq(Math.abs(a[i] - b[i]) < 0x40000, true,
             `actual: ${actual[i]}/${a[i]}, expected: ${expected[i]}/${b[i]}` );
  }
}

// Module under test: exports a one-page memory as "mem" and a function "run"
// (signature v2vSig). The body of "run" pushes three v128 loads (arguments 16,
// 32 and 48), applies F32x4RelaxedDotBF16x8AddF32x4, and hands the result to
// V128StoreExpr with argument 0.
// NOTE(review): the arguments look like byte offsets into "mem", judging by
// how the memory is populated and read back right after -- confirm against
// the V128Load / V128StoreExpr helpers.
var ins = wasmValidateAndEval(moduleWithSections([
sigSection([v2vSig]),
declSection([0]),
memorySection(1),
exportSection([{funcIndex: 0, name: "run"},
{memIndex: 0, name: "mem"}]),
bodySection([
funcBody({locals:[],
body: [...V128StoreExpr(0, [...V128Load(16),
...V128Load(32),
...V128Load(48),
SimdPrefix, varU32(F32x4RelaxedDotBF16x8AddF32x4)])]})])]));

// Populate the three operands and check the result.
//   - 16-bit view, element 16/2: first BF16x8 operand
//   - 16-bit view, element 32/2: second BF16x8 operand
//   - 32-bit view, element 48/4: f32x4 addend
// (presumably `set` writes starting at the given element index, which puts the
// operands at byte offsets 16, 32 and 48 -- confirm against the helper).
var mem16 = new Uint16Array(ins.exports.mem.buffer);
var mem32 = new Float32Array(ins.exports.mem.buffer);
set(mem16, 16/2, createBF16Array([1.0, 0.5, -2.0, 100.0, -1e+30, 0, 0, 4e+20]));
set(mem16, 32/2, createBF16Array([0.0, 1.0, 2.0, 3.0, 1e-29, 0, 0, 3e+10]));
set(mem32, 48/4, [0.0, -1.0, 2.0, -3.0]);
ins.exports.run();
var result = get(mem32, 0, 4);
// Expected lane i is x[2i]*y[2i] + x[2i+1]*y[2i+1] + z[i]:
//   1.0*0.0    + 0.5*1.0     + 0.0  =  0.5
//   -2.0*2.0   + 100.0*3.0   + -1.0 =  295.0
//   -1e30*1e-29 + 0*0        + 2.0  = -8.0
//   0*0        + 4e20*3e10   + -3.0 ~= 1.2e31 (11.9e30 after BF16 truncation)
// assertAlmostSameAsBF16 takes exactly two arguments and has a built-in
// tolerance, so no extra tolerance argument is passed.
assertAlmostSameAsBF16(result, [0.5, 295.0, -8.0, 11.9e+30]);

// Negative test: the same module shape, but only two v128 values are pushed
// before the opcode. Validation is expected to reject it (the assertion
// requires WebAssembly.validate to return false), i.e. the instruction needs
// all three operands.
assertEq(false, WebAssembly.validate(moduleWithSections([
sigSection([v2vSig]),
declSection([0]),
memorySection(1),
exportSection([{funcIndex: 0, name: "run"},
{memIndex: 0, name: "mem"}]),
bodySection([
funcBody({locals:[],
body: [...V128StoreExpr(0, [...V128Load(0),
...V128Load(0),
SimdPrefix, varU32(F32x4RelaxedDotBF16x8AddF32x4)])]})])])));


// Relaxed swizzle, https://github.com/WebAssembly/relaxed-simd/issues/22

Expand Down
4 changes: 4 additions & 0 deletions js/src/jit/MacroAssembler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3435,6 +3435,10 @@ class MacroAssembler : public MacroAssemblerSpecific {
FloatRegister dest, FloatRegister temp)
DEFINED_ON(arm64);

// BFloat16 dot product with accumulate: treats `lhs` and `rhs` as eight
// 16-bit BFloat16 values each, multiplies them pairwise, and adds the two
// products of every 32-bit lane to the matching Float32 lane of `dest`.
// `dest` is both the addend and the result; `temp` is clobbered. Both
// platform implementations assert that `dest` differs from `lhs` and `rhs`.
inline void dotBFloat16x8ThenAdd(FloatRegister lhs, FloatRegister rhs,
FloatRegister dest, FloatRegister temp)
DEFINED_ON(x86_shared, arm64);

// Floating point rounding

inline void ceilFloat32x4(FloatRegister src, FloatRegister dest)
Expand Down
5 changes: 5 additions & 0 deletions js/src/jit/arm64/CodeGenerator-arm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3041,6 +3041,11 @@ void CodeGenerator::visitWasmTernarySimd128(LWasmTernarySimd128* ins) {
ToFloatRegister(ins->v0()), ToFloatRegister(ins->v1()),
ToFloatRegister(ins->v2()), ToFloatRegister(ins->temp()));
break;
case wasm::SimdOp::F32x4RelaxedDotBF16x8AddF32x4:
masm.dotBFloat16x8ThenAdd(
ToFloatRegister(ins->v0()), ToFloatRegister(ins->v1()),
ToFloatRegister(ins->v2()), ToFloatRegister(ins->temp()));
break;
default:
MOZ_CRASH("NYI");
}
Expand Down
7 changes: 7 additions & 0 deletions js/src/jit/arm64/Lowering-arm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,13 @@ void LIRGenerator::visitWasmTernarySimd128(MWasmTernarySimd128* ins) {
defineReuseInput(lir, ins, LWasmTernarySimd128::V2);
break;
}
case wasm::SimdOp::F32x4RelaxedDotBF16x8AddF32x4: {
auto* lir = new (alloc()) LWasmTernarySimd128(
ins->simdOp(), useRegister(ins->v0()), useRegister(ins->v1()),
useRegisterAtStart(ins->v2()), tempSimd128());
defineReuseInput(lir, ins, LWasmTernarySimd128::V2);
break;
}
default:
MOZ_CRASH("NYI");
}
Expand Down
14 changes: 14 additions & 0 deletions js/src/jit/arm64/MacroAssembler-arm64-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3748,6 +3748,20 @@ void MacroAssembler::dotInt8x16Int7x16ThenAdd(FloatRegister lhs,
Sadalp(Simd4S(dest), Simd8H(temp));
}

// ARM64 implementation of the BFloat16x8 dot product with Float32x4
// accumulate. For every 32-bit lane, the low and the high 16-bit halves of
// `lhs` and `rhs` are each widened to a Float32 bit pattern (the 16 bits
// become the upper half of the lane, the lower half is zero), multiplied
// pairwise, and accumulated into the corresponding lane of `dest`.
//
// `lhs` and `rhs` are only read; `dest` is read and written; `temp` and the
// scratch SIMD register are clobbered.
// NOTE(review): `temp` is overwritten before `lhs` and `rhs` are read for the
// second product, so it must not alias either of them -- this is not asserted
// here; confirm that callers guarantee it.
void MacroAssembler::dotBFloat16x8ThenAdd(FloatRegister lhs, FloatRegister rhs,
FloatRegister dest,
FloatRegister temp) {
// `dest` is accumulated into while `lhs`/`rhs` are still needed below.
MOZ_ASSERT(lhs != dest && rhs != dest);
ScratchSimd128Scope scratch(*this);
// First product: shift each 32-bit lane left by 16 so the low half-word
// becomes the upper half of the lane, with zeros below it.
Shl(Simd4S(scratch), Simd4S(lhs), 16);
Shl(Simd4S(temp), Simd4S(rhs), 16);
Fmla(Simd4S(dest), Simd4S(scratch), Simd4S(temp));
// Second product: keep only the high half-word of each lane by masking with
// 0xFFFF0000. `temp` holds the mask first and then the masked `rhs`.
loadConstantSimd128(SimdConstant::SplatX4(int32_t(0xFFFF0000)), temp);
And(Simd16B(scratch), Simd16B(lhs), Simd16B(temp));
And(Simd16B(temp), Simd16B(rhs), Simd16B(temp));
Fmla(Simd4S(dest), Simd4S(scratch), Simd4S(temp));
}

// Floating point rounding (experimental as of August, 2020)
// https://github.com/WebAssembly/simd/pull/232

Expand Down
5 changes: 5 additions & 0 deletions js/src/jit/x86-shared/CodeGenerator-x86-shared.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2282,6 +2282,11 @@ void CodeGenerator::visitWasmTernarySimd128(LWasmTernarySimd128* ins) {
ToFloatRegister(ins->v1()),
ToFloatRegister(ins->v2()));
break;
case wasm::SimdOp::F32x4RelaxedDotBF16x8AddF32x4:
masm.dotBFloat16x8ThenAdd(
ToFloatRegister(ins->v0()), ToFloatRegister(ins->v1()),
ToFloatRegister(ins->v2()), ToFloatRegister(ins->temp()));
break;
default:
MOZ_CRASH("NYI");
}
Expand Down
7 changes: 7 additions & 0 deletions js/src/jit/x86-shared/Lowering-x86-shared.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,13 @@ void LIRGenerator::visitWasmTernarySimd128(MWasmTernarySimd128* ins) {
}
break;
}
case wasm::SimdOp::F32x4RelaxedDotBF16x8AddF32x4: {
auto* lir = new (alloc()) LWasmTernarySimd128(
ins->simdOp(), useRegister(ins->v0()), useRegister(ins->v1()),
useRegisterAtStart(ins->v2()), tempSimd128());
defineReuseInput(lir, ins, LWasmTernarySimd128::V2);
break;
}
default:
MOZ_CRASH("NYI");
}
Expand Down
19 changes: 19 additions & 0 deletions js/src/jit/x86-shared/MacroAssembler-x86-shared-SIMD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1480,3 +1480,22 @@ void MacroAssemblerX86Shared::popcntInt8x16(FloatRegister src,
vpshufb(scratch, temp, temp);
vpaddb(Operand(temp), output, output);
}

// x86-shared implementation of the BFloat16x8 dot product with Float32x4
// accumulate. For every 32-bit lane, the low and the high 16-bit halves of
// `lhs` and `rhs` are each turned into a Float32 bit pattern (the 16 bits in
// the upper half of the lane, zeros below), multiplied pairwise, and the two
// products are added to the corresponding lane of `dest` using separate
// multiply and add steps.
//
// `dest` is read and written; `temp` and the scratch SIMD register are
// clobbered.
// NOTE(review): `lhs` and `rhs` are read again after `temp` has been written,
// so `temp` must not alias either of them -- not asserted here; confirm that
// callers guarantee it.
void MacroAssemblerX86Shared::dotBFloat16x8ThenAdd(FloatRegister lhs,
FloatRegister rhs,
FloatRegister dest,
FloatRegister temp) {
ScratchSimd128Scope scratch(asMasm());
// Presumably copies the source into the given register when AVX is not
// available and returns the register that holds the value -- verify against
// moveSimd128IntIfNotAVX.
FloatRegister lhsCopy = asMasm().moveSimd128IntIfNotAVX(lhs, scratch);
FloatRegister rhsCopy = asMasm().moveSimd128IntIfNotAVX(rhs, temp);
// First product: shift each 32-bit lane left by 16 so the low half-word
// becomes the upper half of the lane.
vpslld(Imm32(16), lhsCopy, scratch);
vpslld(Imm32(16), rhsCopy, temp);
asMasm().mulFloat32x4(scratch, temp, scratch);
asMasm().addFloat32x4(dest, scratch, dest);
// Second product: isolate the high half-word of each lane.
// The temp has 0 in low half-word. Use pblendw instead of `& 0xFFFF0000`.
FloatRegister tempCopy = asMasm().moveSimd128IntIfNotAVX(temp, scratch);
vpblendw(0xAA, lhs, tempCopy, scratch);
vpblendw(0xAA, rhs, temp, temp);
asMasm().mulFloat32x4(scratch, temp, scratch);
asMasm().addFloat32x4(dest, scratch, dest);
}
7 changes: 7 additions & 0 deletions js/src/jit/x86-shared/MacroAssembler-x86-shared-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2842,6 +2842,13 @@ void MacroAssembler::dotInt8x16Int7x16ThenAdd(FloatRegister lhs,
vpaddd(Operand(scratch), dest, dest);
}

// Platform-independent entry point on x86-shared: checks the register
// constraint and forwards to the MacroAssemblerX86Shared implementation.
void MacroAssembler::dotBFloat16x8ThenAdd(FloatRegister lhs, FloatRegister rhs,
FloatRegister dest,
FloatRegister temp) {
// `dest` is accumulated into, so it must not alias either multiplicand.
MOZ_ASSERT(lhs != dest && rhs != dest);
MacroAssemblerX86Shared::dotBFloat16x8ThenAdd(lhs, rhs, dest, temp);
}

// Rounding

void MacroAssembler::ceilFloat32x4(FloatRegister src, FloatRegister dest) {
Expand Down
2 changes: 2 additions & 0 deletions js/src/jit/x86-shared/MacroAssembler-x86-shared.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,8 @@ class MacroAssemblerX86Shared : public Assembler {
FloatRegister output);
void popcntInt8x16(FloatRegister src, FloatRegister temp,
FloatRegister output);
// BFloat16x8 dot product of `lhs` and `rhs`, added to the Float32x4 lanes
// of `dest`. `temp` is clobbered.
void dotBFloat16x8ThenAdd(FloatRegister lhs, FloatRegister rhs,
FloatRegister dest, FloatRegister temp);

// SIMD inline methods private to the implementation, that appear to be used.

Expand Down
1 change: 1 addition & 0 deletions js/src/wasm/WasmBCClass.h
Original file line number Diff line number Diff line change
Expand Up @@ -1669,6 +1669,7 @@ struct BaseCompiler final {
void emitVectorAndNot();
# ifdef ENABLE_WASM_RELAXED_SIMD
void emitDotI8x16I7x16AddS();
void emitDotBF16x8AddF32x4();
# endif

void loadSplat(MemoryAccessDesc* access);
Expand Down
17 changes: 17 additions & 0 deletions js/src/wasm/WasmBaselineCompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8310,6 +8310,18 @@ void BaseCompiler::emitDotI8x16I7x16AddS() {
freeV128(rs0);
pushV128(rsd);
}

void BaseCompiler::emitDotBF16x8AddF32x4() {
RegV128 rsd = popV128();
RegV128 rs0, rs1;
pop2xV128(&rs0, &rs1);
RegV128 temp = needV128();
masm.dotBFloat16x8ThenAdd(rs0, rs1, rsd, temp);
freeV128(temp);
freeV128(rs1);
freeV128(rs0);
pushV128(rsd);
}
# endif // ENABLE_WASM_RELAXED_SIMD

void BaseCompiler::emitVectorAndNot() {
Expand Down Expand Up @@ -10028,6 +10040,11 @@ bool BaseCompiler::emitBody() {
return iter_.unrecognizedOpcode(&op);
}
CHECK_NEXT(dispatchTernary0(emitDotI8x16I7x16AddS, ValType::V128));
case uint32_t(SimdOp::F32x4RelaxedDotBF16x8AddF32x4):
if (!moduleEnv_.v128RelaxedEnabled()) {
return iter_.unrecognizedOpcode(&op);
}
CHECK_NEXT(dispatchTernary0(emitDotBF16x8AddF32x4, ValType::V128));
# endif
default:
break;
Expand Down
3 changes: 2 additions & 1 deletion js/src/wasm/WasmConstants.h
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,8 @@ enum class SimdOp {
I16x8RelaxedQ15MulrS = 0x111,
I16x8DotI8x16I7x16S = 0x112,
I32x4DotI8x16I7x16AddS = 0x113,
// bfloat16 dot product = 0x114
F32x4RelaxedDotBF16x8AddF32x4 = 0x114,

// Reserved for Relaxed SIMD = 0x115-0x12f

// Unused = 0x130 and up
Expand Down
3 changes: 2 additions & 1 deletion js/src/wasm/WasmIonCompile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6655,7 +6655,8 @@ static bool EmitBodyExprs(FunctionCompiler& f) {
case uint32_t(SimdOp::I16x8RelaxedLaneSelect):
case uint32_t(SimdOp::I32x4RelaxedLaneSelect):
case uint32_t(SimdOp::I64x2RelaxedLaneSelect):
case uint32_t(SimdOp::I32x4DotI8x16I7x16AddS): {
case uint32_t(SimdOp::I32x4DotI8x16I7x16AddS):
case uint32_t(SimdOp::F32x4RelaxedDotBF16x8AddF32x4): {
if (!f.moduleEnv().v128RelaxedEnabled()) {
return f.iter().unrecognizedOpcode(&op);
}
Expand Down
1 change: 1 addition & 0 deletions js/src/wasm/WasmOpIter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,7 @@ OpKind wasm::Classify(OpBytes op) {
case SimdOp::I32x4RelaxedLaneSelect:
case SimdOp::I64x2RelaxedLaneSelect:
case SimdOp::I32x4DotI8x16I7x16AddS:
case SimdOp::F32x4RelaxedDotBF16x8AddF32x4:
WASM_SIMD_OP(OpKind::Ternary);
}
break;
Expand Down
3 changes: 2 additions & 1 deletion js/src/wasm/WasmValidate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1026,7 +1026,8 @@ static bool DecodeFunctionBodyExprs(const ModuleEnvironment& env,
case uint32_t(SimdOp::I16x8RelaxedLaneSelect):
case uint32_t(SimdOp::I32x4RelaxedLaneSelect):
case uint32_t(SimdOp::I64x2RelaxedLaneSelect):
case uint32_t(SimdOp::I32x4DotI8x16I7x16AddS): {
case uint32_t(SimdOp::I32x4DotI8x16I7x16AddS):
case uint32_t(SimdOp::F32x4RelaxedDotBF16x8AddF32x4): {
if (!env.v128RelaxedEnabled()) {
return iter.unrecognizedOpcode(&op);
}
Expand Down

0 comments on commit 92f834f

Please sign in to comment.