Skip to content

Commit

Permalink
vulkan: implement neox mode for rope
Browse files Browse the repository at this point in the history
  • Loading branch information
apage43 authored and cebtenzzre committed Oct 5, 2023
1 parent f9d41c7 commit 7d4ecef
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion kompute/op_rope.comp
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ void main() {
out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
}
} else {
// TODO: implement
const float inv_ndims = -1.f/pcs.n_dims;
for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) {
for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
const float cos_theta = cos(theta);
const float sin_theta = sin(theta);

theta *= theta_scale;

const uint i0 = ib*pcs.n_dims + ic/2;

const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inOff; // Based from in
const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_

const float x0 = in_[src];
const float x1 = in_[src+pcs.n_dims/2];

out_[dst_data] = x0*cos_theta - x1*sin_theta;
out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
}
}
}
}

0 comments on commit 7d4ecef

Please sign in to comment.