diff --git a/kompute/op_rope.comp b/kompute/op_rope.comp index 3fa84f57988516..8c28546369b26a 100644 --- a/kompute/op_rope.comp +++ b/kompute/op_rope.comp @@ -63,6 +63,25 @@ void main() { out_[dst_data+1] = x0*sin_theta + x1*cos_theta; } } else { - // TODO: implement + const float inv_ndims = -1.f/pcs.n_dims; + for (uint ib = 0; ib < pcs.ne0/pcs.n_dims; ++ib) { + for (uint ic = 0; ic < pcs.n_dims; ic += 2) { + const float cos_theta = cos(theta); + const float sin_theta = sin(theta); + + theta *= theta_scale; + + const uint i0 = ib*pcs.n_dims + ic/2; + + const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inOff; // Based from in + const uint dst_data = uint((i3*pcs.nb3 + i2*pcs.nb2 + i1*pcs.nb1 + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_ + + const float x0 = in_[src]; + const float x1 = in_[src+pcs.n_dims/2]; + + out_[dst_data] = x0*cos_theta - x1*sin_theta; + out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta; + } + } } }