2 files changed: +11 -2 lines changed
Original file line number | Diff line number | Diff line change
@@ -1702,8 +1702,9 @@ kernel void kernel_rope(
1702
1702
dst_data[1 ] = x0*sin_theta + x1*cos_theta;
1703
1703
}
1704
1704
} else {
1705
- for (int64_t ib = 0 ; ib < ne0/n_dims; ++ib) {
1706
- for (int64_t ic = 2 *tiitg; ic < n_dims; ic += 2 *tptg.x ) {
1705
+ for (int64_t ic = 2 *tiitg; ic < ne0; ic += 2 *tptg.x ) {
1706
+ if (ic < n_dims) {
1707
+ const int64_t ib = 0 ;
1707
1708
1708
1709
// simplified from `(ib * n_dims + ic) * inv_ndims`
1709
1710
const float cur_rot = inv_ndims*ic - ib;
@@ -1722,6 +1723,14 @@ kernel void kernel_rope(
1722
1723
1723
1724
dst_data[0 ] = x0*cos_theta - x1*sin_theta;
1724
1725
dst_data[n_dims/2 ] = x0*sin_theta + x1*cos_theta;
1726
+ } else {
1727
+ const int64_t i0 = ic;
1728
+
1729
+ device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
1730
+ device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
1731
+
1732
+ dst_data[0 ] = src[0 ];
1733
+ dst_data[1 ] = src[1 ];
1725
1734
}
1726
1735
}
1727
1736
}
You can’t perform that action at this time.
0 commit comments