Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webgpu: Add non-shared MatMul program #6955

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Apply ping-pong pattern
  • Loading branch information
qjia7 committed Oct 17, 2022
commit 0cd8497b87717bbf4a30de1f0b04ad0092dae3af
39 changes: 29 additions & 10 deletions tfjs-backend-webgpu/src/matmul_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -93,18 +93,37 @@ export class MatMulProgram implements WebGPUProgram {
var value: vec4<f32>;
// Loop over shared dimension.
// Compute value values for a single thread.
for (var k = 0; k < uniforms.dimInner; k = k + 4) {
let BCached0 = getB(${batchB}, k, globalCol);
let BCached1 = getB(${batchB}, k + 1, globalCol);
let BCached2 = getB(${batchB}, k + 2, globalCol);
let BCached3 = getB(${batchB}, k + 3, globalCol);
var k = 0;
var curBCached0 = getB(${batchB}, k, globalCol);
var curBCached1 = getB(${batchB}, k + 1, globalCol);
var curBCached2 = getB(${batchB}, k + 2, globalCol);
var curBCached3 = getB(${batchB}, k + 3, globalCol);
var curACached = getA(${batchA}, globalRow, k);
k = k + 4;
for (; k < uniforms.dimInner; k = k + 4) {
let nextBCached0 = getB(${batchB}, k, globalCol);
let nextBCached1 = getB(${batchB}, k + 1, globalCol);
let nextBCached2 = getB(${batchB}, k + 2, globalCol);
let nextBCached3 = getB(${batchB}, k + 3, globalCol);
let nextACached = getA(${batchA}, globalRow, k);

let ACached = getA(${batchA}, globalRow, k);
value = BCached0 * ACached.x + value;
value = BCached1 * ACached.y + value;
value = BCached2 * ACached.z + value;
value = BCached3 * ACached.w + value;
// Process data
value = curBCached0 * curACached.x + value;
value = curBCached1 * curACached.y + value;
value = curBCached2 * curACached.z + value;
value = curBCached3 * curACached.w + value;

curBCached0 = nextBCached0;
curBCached1 = nextBCached1;
curBCached2 = nextBCached2;
curBCached3 = nextBCached3;
curACached = nextACached;
}
// Process data for last iteration.
value = curBCached0 * curACached.x + value;
value = curBCached1 * curACached.y + value;
value = curBCached2 * curACached.z + value;
value = curBCached3 * curACached.w + value;
${biasActivationSnippet(this.addBias, this.activation)}
setOutputAtCoords(coords[0], coords[1], coords[2], value);
}
Expand Down