diff --git a/.gitignore b/.gitignore index 20b47c63..b5993a7f 100644 --- a/.gitignore +++ b/.gitignore @@ -90,3 +90,5 @@ fastlane/test_output iOSInjectionProject/ .swiftpm + +.vscode/ diff --git a/Package.swift b/Package.swift index ce7c6d73..154b25e0 100644 --- a/Package.swift +++ b/Package.swift @@ -73,8 +73,10 @@ let package = Package( "mlx/tests", // opt-out of these backends (using metal) - "mlx/mlx/backend/no_metal", "mlx/mlx/backend/no_cpu", + "mlx/mlx/backend/no_gpu", + "mlx/mlx/backend/cuda", + "mlx/mlx/backend/metal/no_metal.cpp", // build variants (we are opting _out_ of these) "mlx/mlx/io/no_safetensors.cpp", @@ -111,7 +113,7 @@ let package = Package( .define("_METAL_"), .define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""), .define("METAL_PATH", to: "\"default.metallib\""), - .define("MLX_VERSION", to: "\"0.24.2\""), + .define("MLX_VERSION", to: "\"0.26.2\""), ], linkerSettings: [ .linkedFramework("Foundation"), diff --git a/Source/Cmlx/mlx b/Source/Cmlx/mlx index eaf709b8..58f38603 160000 --- a/Source/Cmlx/mlx +++ b/Source/Cmlx/mlx @@ -1 +1 @@ -Subproject commit eaf709b83e559079e212699bfc9dd2f939d25c9a +Subproject commit 58f38603066b589341429efa96fc77619b82979e diff --git a/Source/Cmlx/mlx-generated/binary.cpp b/Source/Cmlx/mlx-generated/binary.cpp index fa430492..43b99fcc 100644 --- a/Source/Cmlx/mlx-generated/binary.cpp +++ b/Source/Cmlx/mlx-generated/binary.cpp @@ -10,59 +10,116 @@ template uint index [[thread_position_in_grid]]) { c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } } template [[kernel]] void binary_g_nd1( diff --git a/Source/Cmlx/mlx-generated/binary_ops.cpp b/Source/Cmlx/mlx-generated/binary_ops.cpp index d23ce5f3..2e997f89 100644 --- a/Source/Cmlx/mlx-generated/binary_ops.cpp +++ b/Source/Cmlx/mlx-generated/binary_ops.cpp @@ -210,6 +210,13 @@ struct Power { } template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/Source/Cmlx/mlx-generated/binary_two.cpp b/Source/Cmlx/mlx-generated/binary_two.cpp index 07a8138f..57778976 100644 --- a/Source/Cmlx/mlx-generated/binary_two.cpp +++ b/Source/Cmlx/mlx-generated/binary_two.cpp @@ -13,77 +13,146 @@ template c[index] = out[0]; d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } template [[kernel]] void binary_g_nd1( diff --git a/Source/Cmlx/mlx-generated/conv.cpp b/Source/Cmlx/mlx-generated/conv.cpp index 3e1f1d60..8fc3f20f 100644 --- a/Source/Cmlx/mlx-generated/conv.cpp +++ b/Source/Cmlx/mlx-generated/conv.cpp @@ -353,6 +353,7 @@ struct Conv2DWeightBlockLoader { const device T* src; const constant MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; METAL_FUNC Conv2DWeightBlockLoader( @@ -371,6 +372,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} METAL_FUNC void load_unsafe() const { @@ -400,11 +402,11 @@ struct Conv2DWeightBlockLoader { } METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } @@ -604,7 +606,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { } return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { #pragma clang loop unroll(full) for (short i = 0; i < BROWS; i += TROWS) { diff --git a/Source/Cmlx/mlx-generated/copy.cpp b/Source/Cmlx/mlx-generated/copy.cpp index 9ac729f1..260f6789 100644 --- a/Source/Cmlx/mlx-generated/copy.cpp +++ b/Source/Cmlx/mlx-generated/copy.cpp @@ -2,37 +2,75 @@ namespace mlx::core::metal { const char* copy() { return R"preamble( -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } } template [[kernel]] void copy_g_nd1( diff --git a/Source/Cmlx/mlx-generated/fft.cpp b/Source/Cmlx/mlx-generated/fft.cpp index aaac34cb..065a4f41 100644 --- a/Source/Cmlx/mlx-generated/fft.cpp +++ b/Source/Cmlx/mlx-generated/fft.cpp @@ -314,7 +314,7 @@ struct ReadWriter { return grid_index >= batch_size; } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; constexpr int read_width = 2; @@ -333,7 +333,7 @@ struct ReadWriter { } } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; constexpr int read_width = 2; @@ -352,7 +352,7 @@ struct ReadWriter { } } METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; threadgroup float2* seq_buf = buf + elem.y * n; @@ -367,7 +367,7 @@ struct ReadWriter { } } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -437,7 +437,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { } template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = @@ -453,7 +453,8 @@ METAL_FUNC void ReadWriter::load() const { template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_out = @@ -480,7 +481,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = @@ -503,8 +504,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; short next_out = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 @@ -540,7 +541,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = @@ -588,8 +590,8 @@ METAL_FUNC void ReadWriter::load_padded( const device float2* w_k) const { int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; short next_in = batch_size % 2 == 1 && grid_index * 2 == batch_size - 1 @@ -627,7 +629,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; short next_out = diff --git a/Source/Cmlx/mlx-generated/hadamard.cpp b/Source/Cmlx/mlx-generated/hadamard.cpp index b0839df9..e2450ba5 100644 --- a/Source/Cmlx/mlx-generated/hadamard.cpp +++ b/Source/Cmlx/mlx-generated/hadamard.cpp @@ -22,7 +22,7 @@ METAL_FUNC void radix_func(thread float* x) { h <<= 1; } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -35,15 +35,22 @@ template constexpr short num_steps = logN / logR; constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; + if (stride == 1) { #pragma clang loop unroll(full) - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; #pragma clang loop unroll(full) - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { +#pragma clang loop unroll(full) + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -83,12 +90,20 @@ template } threadgroup_barrier(mem_flags::mem_threadgroup); } + if (stride == 1) { #pragma clang loop unroll(full) - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; #pragma clang loop unroll(full) - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { +#pragma clang loop unroll(full) + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/Source/Cmlx/mlx-generated/logsumexp.cpp b/Source/Cmlx/mlx-generated/logsumexp.cpp index 9c092cb2..d3d4cf3d 100644 --- a/Source/Cmlx/mlx-generated/logsumexp.cpp +++ b/Source/Cmlx/mlx-generated/logsumexp.cpp @@ -92,8 +92,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; @@ -121,11 +121,8 @@ template } threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } )preamble"; diff --git a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal index 8c904de6..3cd95c52 100644 --- a/Source/Cmlx/mlx-generated/metal/arg_reduce.metal +++ b/Source/Cmlx/mlx-generated/metal/arg_reduce.metal @@ -80,9 +80,10 @@ template const constant size_t& ndim [[buffer(5)]], const constant int64_t& axis_stride [[buffer(6)]], const constant size_t& axis_size [[buffer(7)]], - uint gid [[thread_position_in_grid]], - uint lid [[thread_position_in_threadgroup]], - uint lsize [[threads_per_threadgroup]], + uint3 gid [[thread_position_in_grid]], + uint3 gsize [[threads_per_grid]], + uint3 lid [[thread_position_in_threadgroup]], + uint3 lsize [[threads_per_threadgroup]], uint simd_size [[threads_per_simdgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { @@ -104,17 +105,18 @@ template // Compute the input/output index. There is one beginning and one output for // the whole threadgroup. - auto in_idx = elem_to_loc(gid / lsize, shape, in_strides, ndim); - auto out_idx = elem_to_loc(gid / lsize, shape, out_strides, ndim); + int64_t row_idx = gid.y + static_cast(gsize.y) * gid.z; + auto in_idx = elem_to_loc(row_idx, shape, in_strides, ndim); + auto out_idx = elem_to_loc(row_idx, shape, out_strides, ndim); IndexValPair best{0, Op::init}; threadgroup IndexValPair local_data[32]; // Loop over the reduction axis in lsize*N_READS buckets - for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize); r++) { + for (uint r = 0; r < ceildiv(axis_size, N_READS * lsize.x); r++) { // Read the current value - uint32_t current_index = r * lsize * N_READS + lid * N_READS; + uint32_t current_index = r * lsize.x * N_READS + lid.x * N_READS; uint32_t offset = current_index; const device T* current_in = in + in_idx + current_index * axis_stride; T vals[N_READS]; @@ -144,7 +146,7 @@ template } // Read the appropriate value from local data and perform one simd reduction - uint simd_groups = ceildiv(lsize, simd_size); + uint simd_groups = ceildiv(lsize.x, simd_size); if (simd_lane_id < simd_groups) { best = local_data[simd_lane_id]; } @@ -154,7 +156,7 @@ template } // Finally write the output - if (lid == 0) { + if (lid.x == 0) { out[out_idx] = best.index; } } diff --git a/Source/Cmlx/mlx-generated/metal/binary.h b/Source/Cmlx/mlx-generated/metal/binary.h index 91a02c81..f1df8853 100644 --- a/Source/Cmlx/mlx-generated/metal/binary.h +++ b/Source/Cmlx/mlx-generated/metal/binary.h @@ -9,64 +9,121 @@ template c[index] = Op()(a[0], b[0]); } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[0], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[0], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, + constant uint& size, uint index [[thread_position_in_grid]]) { - c[index] = Op()(a[index], b[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[index + i] = Op()(a[index + i], b[index + i]); + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[0], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[0], b[offset + i]); + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[0]); + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - int64_t offset = index.x + grid_dim.x * int64_t(index.y); - c[offset] = Op()(a[offset], b[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + c[offset + i] = Op()(a[offset + i], b[offset + i]); + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/binary_ops.h b/Source/Cmlx/mlx-generated/metal/binary_ops.h index 4aaf2b4d..f4deb860 100644 --- a/Source/Cmlx/mlx-generated/metal/binary_ops.h +++ b/Source/Cmlx/mlx-generated/metal/binary_ops.h @@ -235,6 +235,13 @@ struct Power { template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/Source/Cmlx/mlx-generated/metal/binary_two.h b/Source/Cmlx/mlx-generated/metal/binary_two.h index 8f6b3392..4455e4ca 100644 --- a/Source/Cmlx/mlx-generated/metal/binary_two.h +++ b/Source/Cmlx/mlx-generated/metal/binary_two.h @@ -12,82 +12,151 @@ template d[index] = out[1]; } -template +template ::n> [[kernel]] void binary_sv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[0], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[0]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[0]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv( device const T* a, device const T* b, device U* c, device U* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - auto out = Op()(a[index], b[index]); - c[index] = out[0]; - d[index] = out[1]; + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[index + i], b[index + i]); + c[index + i] = out[0]; + d[index + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_sv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[0], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[0], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vs2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[0]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[0]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } -template +template ::n> [[kernel]] void binary_vv2( device const T* a, device const T* b, device U* c, device U* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - auto out = Op()(a[offset], b[offset]); - c[offset] = out[0]; - d[offset] = out[1]; + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } else { + for (int i = 0; i < N; ++i) { + auto out = Op()(a[offset + i], b[offset + i]); + c[offset + i] = out[0]; + d[offset + i] = out[1]; + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/copy.h b/Source/Cmlx/mlx-generated/metal/copy.h index b1367cf4..cf22347e 100644 --- a/Source/Cmlx/mlx-generated/metal/copy.h +++ b/Source/Cmlx/mlx-generated/metal/copy.h @@ -1,39 +1,77 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void copy_s( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[0]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant uint& size, uint index [[thread_position_in_grid]]) { - dst[index] = static_cast(src[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[index + i] = static_cast(src[index + i]); + } + } } -template +template ::n> [[kernel]] void copy_s2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[0]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[0]); + } + } } -template +template ::n> [[kernel]] void copy_v2( device const T* src [[buffer(0)]], device U* dst [[buffer(1)]], + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - dst[offset] = static_cast(src[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + dst[offset + i] = static_cast(src[offset + i]); + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h index 23231946..4459d36f 100644 --- a/Source/Cmlx/mlx-generated/metal/fft/readwrite.h +++ b/Source/Cmlx/mlx-generated/metal/fft/readwrite.h @@ -10,7 +10,7 @@ For many sizes, GPU FFTs are memory bandwidth bound so read/write performance is important. Where possible, we read 128 bits sequentially in each thread, -coalesced with accesses from adajcent threads for optimal performance. +coalesced with accesses from adjacent threads for optimal performance. We implement specialized reading/writing for: - FFT @@ -98,7 +98,7 @@ struct ReadWriter { } METAL_FUNC void load() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -121,7 +121,7 @@ struct ReadWriter { } METAL_FUNC void write() const { - int batch_idx = elem.x * grid.y * n; + size_t batch_idx = size_t(elem.x * grid.y) * n; short tg_idx = elem.y * grid.z + elem.z; short max_index = grid.y * n - 2; @@ -144,7 +144,7 @@ struct ReadWriter { // Padded IO for Bluestein's algorithm METAL_FUNC void load_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; @@ -161,7 +161,7 @@ struct ReadWriter { } METAL_FUNC void write_padded(int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length + elem.y * length; + size_t batch_idx = size_t(elem.x * grid.y) * length + elem.y * length; int fft_idx = elem.z; int m = grid.z; float2 inv_factor = {1.0f / n, -1.0f / n}; @@ -261,7 +261,7 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { - int batch_idx = elem.x * grid.y * n * 2 + elem.y * n * 2; + size_t batch_idx = size_t(elem.x * grid.y) * n * 2 + elem.y * n * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -283,7 +283,8 @@ template <> METAL_FUNC void ReadWriter::write() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; int grid_index = elem.x * grid.y + elem.y; @@ -317,7 +318,7 @@ template <> METAL_FUNC void ReadWriter::load_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -345,8 +346,8 @@ METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; @@ -397,7 +398,8 @@ METAL_FUNC bool ReadWriter::out_of_bounds() const { template <> METAL_FUNC void ReadWriter::load() const { short n_over_2 = (n / 2) + 1; - int batch_idx = elem.x * grid.y * n_over_2 * 2 + elem.y * n_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * n_over_2 * 2 + elem.y * n_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -458,8 +460,8 @@ METAL_FUNC void ReadWriter::load_padded( int n_over_2 = (n / 2) + 1; int length_over_2 = (length / 2) + 1; - int batch_idx = - elem.x * grid.y * length_over_2 * 2 + elem.y * length_over_2 * 2; + size_t batch_idx = + size_t(elem.x * grid.y) * length_over_2 * 2 + elem.y * length_over_2 * 2; threadgroup float2* seq_buf = buf + elem.y * n; // No out of bounds accesses on odd batch sizes @@ -503,7 +505,7 @@ template <> METAL_FUNC void ReadWriter::write_padded( int length, const device float2* w_k) const { - int batch_idx = elem.x * grid.y * length * 2 + elem.y * length * 2; + size_t batch_idx = size_t(elem.x * grid.y) * length * 2 + elem.y * length * 2; threadgroup float2* seq_buf = buf + elem.y * n + length - 1; int grid_index = elem.x * grid.y + elem.y; diff --git a/Source/Cmlx/mlx-generated/metal/hadamard.h b/Source/Cmlx/mlx-generated/metal/hadamard.h index 8f2d8cc1..d6c08f17 100644 --- a/Source/Cmlx/mlx-generated/metal/hadamard.h +++ b/Source/Cmlx/mlx-generated/metal/hadamard.h @@ -26,7 +26,7 @@ METAL_FUNC void radix_func(thread float* x) { } } -template +template [[kernel]] void hadamard_n( const device T* in [[buffer(0)]], device T* out [[buffer(1)]], @@ -46,18 +46,25 @@ template constexpr short logFinal = logN % logR; constexpr short final_radix = 1 << (logFinal); - int batch_idx = elem.x * N; - short i = elem.y; + int batch_idx = elem.y * N * stride + elem.z; + short i = elem.x; threadgroup T buf[N]; // Read values from device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - buf[index + r] = in[batch_idx + index + r]; + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + buf[index + r] = in[batch_idx + index + r]; + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + buf[j * num_threads + i] = in[batch_idx + (j * num_threads + i) * stride]; } } @@ -113,12 +120,20 @@ template } // Write values to device - STEEL_PRAGMA_UNROLL - for (short j = 0; j < max_radix / read_width; j++) { - short index = j * read_width * num_threads + i * read_width; + if (stride == 1) { STEEL_PRAGMA_UNROLL - for (short r = 0; r < read_width; r++) { - out[batch_idx + index + r] = T(buf[index + r] * scale); + for (short j = 0; j < max_radix / read_width; j++) { + short index = j * read_width * num_threads + i * read_width; + STEEL_PRAGMA_UNROLL + for (short r = 0; r < read_width; r++) { + out[batch_idx + index + r] = T(buf[index + r] * scale); + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < max_radix; j++) { + out[batch_idx + (j * num_threads + i) * stride] = + buf[j * num_threads + i]; } } } diff --git a/Source/Cmlx/mlx-generated/metal/layer_norm.metal b/Source/Cmlx/mlx-generated/metal/layer_norm.metal index 2a628d11..9a5c5d3d 100644 --- a/Source/Cmlx/mlx-generated/metal/layer_norm.metal +++ b/Source/Cmlx/mlx-generated/metal/layer_norm.metal @@ -9,7 +9,41 @@ using namespace metal; constant bool has_w [[function_constant(20)]]; -template +template +inline void initialize_buffer( + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + if (simd_group_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_lane_id + i] = 0; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); +} + +template +inline void threadgroup_sum( + thread float* x, + threadgroup float* xs, + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + for (int i = 0; i < N; i++) { + x[i] = simd_sum(x[i]); + } + if (simd_lane_id == 0) { + for (int i = 0; i < N; i++) { + xs[N * simd_group_id + i] = x[i]; + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int i = 0; i < N; i++) { + x[i] = xs[N * simd_lane_id + i]; + x[i] = simd_sum(x[i]); + } +} + +template [[kernel]] void layer_norm_single_row( const device T* x, const device T* w, @@ -23,90 +57,71 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - float thread_x[N_READS]; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + threadgroup float local_buffer[SIMD_SIZE] = {0}; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); + // Advance the pointers x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; + + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumx += thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; + + // Compute the normalizer + float normalizer = 0; + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + normalizer += thread_x[i] * thread_x[i]; + } + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs - out += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - out[i] = - w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + out[i] = w[w_stride * i] * static_cast(thread_x[i]) + b[b_stride * i]; } } } -template +template [[kernel]] void layer_norm_looped( const device T* x, const device T* w, @@ -121,71 +136,52 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { - float sumx = 0; - float sumx2 = 0; - constexpr int SIMD_SIZE = 32; - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; + threadgroup float local_buffer[SIMD_SIZE]; + initialize_buffer(local_buffer, simd_lane_id, simd_group_id); x += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; b += b_stride * lid * N_READS; + // Compute the mean + float mean = 0; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; - sumx2 += xi * xi; - sumx += xi; + mean += x[i + r]; } } } } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); + // Compute the normalizer + float normalizer = 0; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + float t = x[i + r] - mean; + normalizer += t * t; + } + } } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; + threadgroup_sum(&normalizer, local_buffer, simd_lane_id, simd_group_id); + normalizer = metal::precise::rsqrt(normalizer / axis_size + eps); // Write the outputs out += gid * size_t(axis_size) + lid * N_READS; @@ -208,7 +204,7 @@ template } } -template +template [[kernel]] void vjp_layer_norm_single_row( const device T* x, const device T* w, @@ -222,133 +218,96 @@ template uint lid [[thread_position_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the computation and accumulators - float thread_x[N_READS]; - float thread_w[N_READS]; - float thread_g[N_READS]; - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; - - constexpr int SIMD_SIZE = 32; + // Initialize the registers and threadgroup memory + float thread_x[N_READS] = {0}; + float thread_w[N_READS] = {0}; + float thread_g[N_READS] = {0}; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + // Compute some variables for reading writing etc + const bool safe = lid * N_READS + N_READS <= axis_size; + const int n = axis_size - lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + // Read the inputs + if (safe) { for (int i = 0; i < N_READS; i++) { thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; + thread_w[i] = w[i * w_stride]; } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = x[i]; - thread_w[i] = w[i * w_stride]; - thread_g[i] = g[i]; - float wg = thread_w[i] * thread_g[i]; - sumx += thread_x[i]; - sumx2 += thread_x[i] * thread_x[i]; - sumwg += wg; - sumwgx += wg * thread_x[i]; - } + for (int i = 0; i < n; i++) { + thread_x[i] = x[i]; + thread_g[i] = g[i]; + thread_w[i] = w[i * w_stride]; } } - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; + // Compute the mean + float mean = 0; + for (int i = 0; i < N_READS; i++) { + mean += thread_x[i]; } - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; + // Compute the neccesary scaling factors using the mean + if (!safe) { + for (int i = n; i < N_READS; i++) { + thread_x[i] = mean; } } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; + for (int i = 0; i < N_READS; i++) { + thread_x[i] -= mean; + factors[meanwg] += thread_w[i] * thread_g[i]; + factors[meanwgxc] += thread_w[i] * thread_g[i] * thread_x[i]; + factors[normalizer2] += thread_x[i] * thread_x[i]; + } + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; gw += gid * size_t(axis_size) + lid * N_READS; - if (lid * N_READS + N_READS <= axis_size) { + if (safe) { for (int i = 0; i < N_READS; i++) { - thread_x[i] = (thread_x[i] - mean) * normalizer; + thread_x[i] *= normalizer; gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } else { - for (int i = 0; i < N_READS; i++) { - if ((lid * N_READS + i) < axis_size) { - thread_x[i] = (thread_x[i] - mean) * normalizer; - gx[i] = static_cast( - normalizer * (thread_w[i] * thread_g[i] - meanwg) - - thread_x[i] * meanwgxc * normalizer2); - if (has_w) { - gw[i] = static_cast(thread_g[i] * thread_x[i]); - } + for (int i = 0; i < n; i++) { + thread_x[i] *= normalizer; + gx[i] = static_cast( + normalizer * (thread_w[i] * thread_g[i] - factors[meanwg]) - + thread_x[i] * factors[meanwgxc] * factors[normalizer2]); + if (has_w) { + gw[i] = static_cast(thread_g[i] * thread_x[i]); } } } } -template +template [[kernel]] void vjp_layer_norm_looped( const device T* x, const device T* w, @@ -363,102 +322,69 @@ template uint lsize [[threads_per_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]]) { + constexpr int SIMD_SIZE = 32; + // Advance the input pointers x += gid * size_t(axis_size) + lid * N_READS; g += gid * size_t(axis_size) + lid * N_READS; w += w_stride * lid * N_READS; - // Allocate registers for the accumulators - float sumx = 0; - float sumx2 = 0; - float sumwg = 0; - float sumwgx = 0; - - constexpr int SIMD_SIZE = 32; - - threadgroup float local_sumx[SIMD_SIZE]; - threadgroup float local_sumx2[SIMD_SIZE]; - threadgroup float local_sumwg[SIMD_SIZE]; - threadgroup float local_sumwgx[SIMD_SIZE]; - threadgroup float local_mean[1]; - threadgroup float local_normalizer[1]; - threadgroup float local_meanwg[1]; - threadgroup float local_meanwgx[1]; + threadgroup float local_buffer[3 * SIMD_SIZE]; + initialize_buffer<3>(local_buffer, simd_lane_id, simd_group_id); + // Compute the mean + float mean = 0; + for (uint r = 0; r < axis_size; r += lsize * N_READS) { + if (r + lid * N_READS + N_READS <= axis_size) { + for (int i = 0; i < N_READS; i++) { + mean += x[i + r]; + } + } else { + for (int i = 0; i < N_READS; i++) { + if ((r + lid * N_READS + i) < axis_size) { + mean += x[i + r]; + } + } + } + } + threadgroup_sum(&mean, local_buffer, simd_lane_id, simd_group_id); + mean /= axis_size; + + // Compute the neccesary scaling factors using the mean + float factors[3] = {0}; + constexpr int meanwg = 0; + constexpr int meanwgxc = 1; + constexpr int normalizer2 = 2; for (uint r = 0; r < axis_size; r += lsize * N_READS) { if (r + lid * N_READS + N_READS <= axis_size) { for (int i = 0; i < N_READS; i++) { - float xi = x[i + r]; + float t = x[i + r] - mean; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; } } else { for (int i = 0; i < N_READS; i++) { if ((r + lid * N_READS + i) < axis_size) { - float xi = x[i + r]; + float t = x[i + r] - mean; float wi = w[(i + r) * w_stride]; float gi = g[i + r]; float wg = wi * gi; - sumx += xi; - sumx2 += xi * xi; - sumwg += wg; - sumwgx += wg * xi; + factors[meanwg] += wg; + factors[meanwgxc] += wg * t; + factors[normalizer2] += t * t; } } } } - - sumx = simd_sum(sumx); - sumx2 = simd_sum(sumx2); - sumwg = simd_sum(sumwg); - sumwgx = simd_sum(sumwgx); - - // Initialize shared memory - if (simd_group_id == 0) { - local_sumx[simd_lane_id] = 0; - local_sumx2[simd_lane_id] = 0; - local_sumwg[simd_lane_id] = 0; - local_sumwgx[simd_lane_id] = 0; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Write simd accumulations into shared memory - if (simd_lane_id == 0) { - local_sumx[simd_group_id] = sumx; - local_sumx2[simd_group_id] = sumx2; - local_sumwg[simd_group_id] = sumwg; - local_sumwgx[simd_group_id] = sumwgx; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - // Accumulate over simd groups - if (simd_group_id == 0) { - sumx = simd_sum(local_sumx[simd_lane_id]); - sumx2 = simd_sum(local_sumx2[simd_lane_id]); - sumwg = simd_sum(local_sumwg[simd_lane_id]); - sumwgx = simd_sum(local_sumwgx[simd_lane_id]); - if (simd_lane_id == 0) { - float mean = sumx / axis_size; - float variance = sumx2 / axis_size - mean * mean; - - local_mean[0] = mean; - local_normalizer[0] = metal::precise::rsqrt(variance + eps); - local_meanwg[0] = sumwg / axis_size; - local_meanwgx[0] = sumwgx / axis_size; - } - } - threadgroup_barrier(mem_flags::mem_threadgroup); - - float mean = local_mean[0]; - float normalizer = local_normalizer[0]; - float meanwg = local_meanwg[0]; - float meanwgxc = local_meanwgx[0] - meanwg * mean; - float normalizer2 = normalizer * normalizer; + threadgroup_sum<3>(factors, local_buffer, simd_lane_id, simd_group_id); + factors[meanwg] /= axis_size; + factors[meanwgxc] /= axis_size; + factors[normalizer2] = 1 / (factors[normalizer2] / axis_size + eps); + float normalizer = metal::precise::sqrt(factors[normalizer2]); // Write the outputs gx += gid * size_t(axis_size) + lid * N_READS; @@ -470,7 +396,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } @@ -482,7 +409,8 @@ template float wi = w[(i + r) * w_stride]; float gi = g[i + r]; gx[i + r] = static_cast( - normalizer * (wi * gi - meanwg) - xi * meanwgxc * normalizer2); + normalizer * (wi * gi - factors[meanwg]) - + xi * factors[meanwgxc] * factors[normalizer2]); if (has_w) { gw[i + r] = static_cast(gi * xi); } diff --git a/Source/Cmlx/mlx-generated/metal/logsumexp.h b/Source/Cmlx/mlx-generated/metal/logsumexp.h index b6898e31..c746050b 100644 --- a/Source/Cmlx/mlx-generated/metal/logsumexp.h +++ b/Source/Cmlx/mlx-generated/metal/logsumexp.h @@ -103,8 +103,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; @@ -134,10 +134,7 @@ template threadgroup_barrier(mem_flags::mem_threadgroup); normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_group_id == 0) { - normalizer = simd_sum(local_normalizer[simd_lane_id]); - if (simd_lane_id == 0) { - out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); - } + if (lid == 0) { + out[gid] = isinf(maxval) ? T(maxval) : T(log(normalizer) + maxval); } } diff --git a/Source/Cmlx/mlx-generated/metal/quantized.h b/Source/Cmlx/mlx-generated/metal/quantized.h index b2b0d8d8..fea6f146 100644 --- a/Source/Cmlx/mlx-generated/metal/quantized.h +++ b/Source/Cmlx/mlx-generated/metal/quantized.h @@ -14,11 +14,23 @@ using namespace metal; MLX_MTL_CONST int SIMD_SIZE = 32; MLX_MTL_CONST int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} + +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} + template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -57,6 +69,21 @@ inline U load_vector(const device T* x, thread U* x_thread) { } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -80,8 +107,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; @@ -121,6 +149,21 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } + else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -153,8 +196,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -199,6 +243,26 @@ inline U qdot( } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -234,8 +298,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; @@ -280,6 +345,26 @@ inline U qdot_safe( } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -310,8 +395,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; @@ -348,8 +434,31 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } + } - } else if (bits == 6) { + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -375,8 +484,9 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { @@ -416,11 +526,26 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } + else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; w += 3 * i; - w_local[0] = (w[0] & 0x3f) * scale + bias; w_local[1] = (((w[0] >> 6) & 0x03) + ((w[1] & 0x0f) << 2)) * scale + bias; w_local[2] = (((w[1] >> 4) & 0x0f) + ((w[2] & 0x03) << 4)) * scale + bias; @@ -452,11 +577,12 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); - MLX_MTL_CONST short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - MLX_MTL_CONST short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + MLX_MTL_CONST short pack_factor = get_pack_factor(); + MLX_MTL_CONST short bytes_per_pack = get_bytes_per_pack(); MLX_MTL_CONST short BCOLS_PACKED = BCOLS / pack_factor; MLX_MTL_CONST short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -632,12 +758,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -700,12 +825,12 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -857,8 +982,9 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; @@ -981,9 +1107,10 @@ METAL_FUNC void qmm_t_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1008,11 +1135,11 @@ METAL_FUNC void qmm_t_impl( auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -1106,11 +1233,11 @@ METAL_FUNC void qmm_n_impl( constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); + constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; // Instantiate the appropriate BlockMMA and Loader using mma_t = mlx::steel:: @@ -1132,11 +1259,11 @@ METAL_FUNC void qmm_n_impl( // Set the block const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; // Make the x loader and mma operation const short num_els = min(BM, M - y_row); @@ -2120,11 +2247,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel::BlockMMA< T, @@ -2305,13 +2431,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; static_assert( group_size % simd_size == 0, @@ -2354,8 +2480,8 @@ template biases[gindex] = static_cast(bias); } - // We accumulate 3 bytes worth for 3/6 bit so we need a uint32_t - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { @@ -2363,27 +2489,35 @@ template if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; @@ -2399,12 +2533,11 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2421,7 +2554,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; - + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2431,7 +2573,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h index c8973429..936d75bb 100644 --- a/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h +++ b/Source/Cmlx/mlx-generated/metal/reduction/reduce_row.h @@ -224,7 +224,7 @@ template < if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { // Simple loop over non_row_reductions and reduce the row in the thread. - IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { diff --git a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h index c4c0f645..8258e9c1 100644 --- a/Source/Cmlx/mlx-generated/metal/sdpa_vector.h +++ b/Source/Cmlx/mlx-generated/metal/sdpa_vector.h @@ -56,9 +56,9 @@ template const int head_idx = tid.x; const int q_seq_idx = tid.y; const int kv_head_idx = head_idx / gqa_factor; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; queries += q_offset * D + simd_lid * qk_per_thread; keys += kv_head_idx * k_head_stride + simd_gid * k_seq_stride + simd_lid * qk_per_thread; @@ -213,9 +213,9 @@ template const int block_idx = tid.z; const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int o_offset = tpg.x * q_seq_idx + head_idx; + const int o_offset = head_idx * tpg.y + q_seq_idx; const int q_offset = - query_transposed ? o_offset : head_idx * tpg.y + q_seq_idx; + query_transposed ? tpg.x * q_seq_idx + head_idx : o_offset; const int kv_head_idx = head_idx / gqa_factor; queries += q_offset * D + simd_lid * qk_per_thread; @@ -358,8 +358,8 @@ template // Adjust positions const int head_idx = tid.x; const int q_seq_idx = tid.y; - const int n_heads = tpg.x; - const int q_offset = n_heads * q_seq_idx + head_idx; + const int q_offset = head_idx * tpg.y + q_seq_idx; + ; partials += q_offset * blocks * D + simd_gid * D + simd_lid * elem_per_thread; sums += q_offset * blocks; maxs += q_offset * blocks; diff --git a/Source/Cmlx/mlx-generated/metal/softmax.h b/Source/Cmlx/mlx-generated/metal/softmax.h index b36b73bd..6ea4ac73 100644 --- a/Source/Cmlx/mlx-generated/metal/softmax.h +++ b/Source/Cmlx/mlx-generated/metal/softmax.h @@ -128,8 +128,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h index 2e27ea06..34d5bf58 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/kernels/steel_attention.h @@ -95,7 +95,7 @@ template < Q += tidl.z * params->Q_strides[0] + // Batch tidl.y * params->Q_strides[1] + // Head - tidl.x * BQ * params->Q_strides[2]; // Seqeunce + tidl.x * BQ * params->Q_strides[2]; // Sequence ulong kv_head_idx = int(tid.y) / params->gqa_factor; K += tidl.z * params->K_strides[0] + // Batch @@ -106,7 +106,7 @@ template < O += tidl.z * params->O_strides[0] + // Batch tidl.y * params->O_strides[1] + // Head - tidl.x * BQ * params->O_strides[2]; // Seqeunce + tidl.x * BQ * params->O_strides[2]; // Sequence if (has_mask) { mask += tidl.z * mask_params->M_strides[0] + // Batch diff --git a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h index 75d695e6..3b7c5166 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h +++ b/Source/Cmlx/mlx-generated/metal/steel/attn/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); @@ -240,7 +240,7 @@ struct BlockLoaderT { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h index 9261b871..c92fcf36 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/kernels/steel_conv_general.h @@ -2,6 +2,8 @@ #include "../../../steel/conv/loaders/loader_general.h" +constant bool align_C [[function_constant(200)]]; + template < typename T, int BM, @@ -118,30 +120,65 @@ implicit_gemm_conv_2d_general( // Prepare threadgroup mma operation mma_t mma_op(simd_gid, simd_lid); - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - // Load elements into threadgroup - loader_a.load_unsafe(); - loader_b.load_unsafe(); + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // Load elements into threadgroup + loader_a.load_unsafe(); + loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - // Multiply and accumulate threadgroup elements - mma_op.mma(As, Bs); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); - // Prepare for next iteration - loader_a.next(); - loader_b.next(); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + // Load elements into threadgroup + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + // Multiply and accumulate threadgroup elements + mma_op.mma(As, Bs); + // Prepare for next iteration + loader_a.next(); + loader_b.next(); + } } threadgroup_barrier(mem_flags::mem_none); // Store results to device memory { - // Adjust for simdgroup and thread locatio + // Adjust for simdgroup and thread location int offset_m = c_row + mma_op.sm; int offset_n = c_col + mma_op.sn; C += offset_n; diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h index 85a6d134..22eebe03 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_l.h @@ -381,6 +381,7 @@ struct Conv2DWeightBlockLoader { const constant MLXConvParams<2>* params; int weight_hw; + int weight_step; const int read_n; const bool do_read; @@ -402,6 +403,7 @@ struct Conv2DWeightBlockLoader { src(src_ + bi * src_ld + bj), params(params_), weight_hw(0), + weight_step(params->C / params->groups), read_n(offsets.y + bi), do_read(read_n + n_rows * TROWS <= gemm_params_->N) {} @@ -435,15 +437,15 @@ struct Conv2DWeightBlockLoader { /* Iteration helper */ METAL_FUNC void next() { if (++weight_hw < (params->wS[1] * params->wS[0])) { - src += params->wt_strides[2]; + src += weight_step; return; } weight_hw = 0; - src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2]; + src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step; } }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h index 2f12535f..b2cdea01 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_channel_n.h @@ -272,7 +272,7 @@ struct Conv2DWeightBlockLoaderSmallChannels { return; } - const device T* curr_src = src + weight_hw * params->wt_strides[2]; + const device T* curr_src = src + weight_hw * (params->C / params->groups); if (BN != 8 || do_read) { STEEL_PRAGMA_UNROLL @@ -316,4 +316,4 @@ struct Conv2DWeightBlockLoaderSmallChannels { }; } // namespace steel -} // namespace mlx \ No newline at end of file +} // namespace mlx diff --git a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h index 3f5be762..9043a3c4 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h +++ b/Source/Cmlx/mlx-generated/metal/steel/conv/loaders/loader_general.h @@ -137,6 +137,52 @@ struct Conv2DInputBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + STEEL_PRAGMA_UNROLL + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + // Find bounds + int n = read_n[i]; + + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; + + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + + // Read from input if in bounds + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + + // Zero pad otherwise + else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; @@ -262,6 +308,55 @@ struct Conv2DWeightBlockLoaderGeneral { } } + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + + if ((start_row + BN <= params->O)) { + STEEL_PRAGMA_UNROLL + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { + STEEL_PRAGMA_UNROLL + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + /* Iteration helper */ METAL_FUNC void next() { weight_w += jump_params->f_wgt_jump_w; diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h index add495d9..85830872 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/kernels/steel_gemm_fused.h @@ -33,8 +33,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h index 1846e26d..cc79de86 100644 --- a/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h +++ b/Source/Cmlx/mlx-generated/metal/steel/gemm/loader.h @@ -113,7 +113,7 @@ struct BlockLoader { tmp_val[j] = src[(tmp_idx[j] ? i * src_ld + j : 0)]; } - // Zero out uneeded values + // Zero out unneeded values STEEL_PRAGMA_UNROLL for (short j = 0; j < vec_size; j++) { tmp_val[j] = tmp_idx[j] ? tmp_val[j] : T(0); diff --git a/Source/Cmlx/mlx-generated/metal/ternary.h b/Source/Cmlx/mlx-generated/metal/ternary.h index 4b3adcc8..570f5e4d 100644 --- a/Source/Cmlx/mlx-generated/metal/ternary.h +++ b/Source/Cmlx/mlx-generated/metal/ternary.h @@ -1,25 +1,44 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } } template diff --git a/Source/Cmlx/mlx-generated/metal/unary.h b/Source/Cmlx/mlx-generated/metal/unary.h index 69828599..649ba7f2 100644 --- a/Source/Cmlx/mlx-generated/metal/unary.h +++ b/Source/Cmlx/mlx-generated/metal/unary.h @@ -1,21 +1,40 @@ // Copyright © 2024 Apple Inc. -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } } template < diff --git a/Source/Cmlx/mlx-generated/metal/utils.h b/Source/Cmlx/mlx-generated/metal/utils.h index 8fd67b89..28840a5c 100644 --- a/Source/Cmlx/mlx-generated/metal/utils.h +++ b/Source/Cmlx/mlx-generated/metal/utils.h @@ -15,6 +15,14 @@ typedef half float16_t; +// Work per thread values for different types. The values here are expected to +// match get_work_per_thread in mlx/backend/metal/utils.h +template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; + /////////////////////////////////////////////////////////////////////////////// // Type limits utils /////////////////////////////////////////////////////////////////////////////// diff --git a/Source/Cmlx/mlx-generated/quantized.cpp b/Source/Cmlx/mlx-generated/quantized.cpp index da1a4930..6bb19208 100644 --- a/Source/Cmlx/mlx-generated/quantized.cpp +++ b/Source/Cmlx/mlx-generated/quantized.cpp @@ -8,11 +8,21 @@ constant bool align_K [[function_constant(202)]]; using namespace metal; static constant constexpr const int SIMD_SIZE = 32; static constant constexpr const int QUAD_SIZE = 4; +template +inline constexpr short get_pack_factor() { + return (bits == 3 || bits == 5) ? 8 : (bits == 6 ? 4 : wsize / bits); +} +template +inline constexpr short get_bytes_per_pack() { + constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; + return power_of_2_bits ? (wsize / 8) : (bits == 5 ? 5 : 3); +} template inline U load_vector(const device T* x, thread U* x_thread) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < values_per_thread; i += 4) { @@ -46,6 +56,20 @@ inline U load_vector(const device T* x, thread U* x_thread) { x_thread[i + 3] = x[i + 3] / 4096.0f; } } + else if (bits == 5) { + for (int i = 0; i < values_per_thread; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } else if (bits == 6) { for (int i = 0; i < values_per_thread; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -66,8 +90,9 @@ inline U load_vector(const device T* x, thread U* x_thread) { template inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U sum = 0; if (bits == 2) { for (int i = 0; i < N; i += 4) { @@ -101,6 +126,20 @@ inline U load_vector_safe(const device T* x, thread U* x_thread, int N) { x_thread[i + 3] = x[i + 3] / 4096.0f; } } + else if (bits == 5) { + for (int i = 0; i < N; i += 8) { + sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3] + x[i + 4] + x[i + 5] + + x[i + 6] + x[i + 7]; + x_thread[i] = x[i]; + x_thread[i + 1] = x[i + 1] / 32.0f; + x_thread[i + 2] = x[i + 2] / 4.0f; + x_thread[i + 3] = x[i + 3] / 128.0f; + x_thread[i + 4] = x[i + 4] / 16.0f; + x_thread[i + 5] = x[i + 5] / 2.0f; + x_thread[i + 6] = x[i + 6] / 64.0f; + x_thread[i + 7] = x[i + 7] / 8.0f; + } + } else if (bits == 6) { for (int i = 0; i < N; i += 4) { sum += x[i] + x[i + 1] + x[i + 2] + x[i + 3]; @@ -129,8 +168,9 @@ inline U qdot( U bias, U sum) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (values_per_thread / 4); i++) { @@ -167,6 +207,24 @@ inline U qdot( x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { x_thread += 4 * i; @@ -195,8 +253,9 @@ inline U qdot_safe( U sum, int N) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); U accum = 0; if (bits == 2) { for (int i = 0; i < (N / 4); i++) { @@ -233,6 +292,24 @@ inline U qdot_safe( x_thread[4 * i + 3] * (ws[i] & 0xf000)); } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + x_thread += 8 * i; + w += 5 * i; + accum += (w[0] & 0x1f) * x_thread[0]; + accum += (w[0] & 0xe0) * x_thread[1]; + accum += (w[1] & 0x3) * (x_thread[1] * 256.0f); + accum += (w[1] & 0x7c) * x_thread[2]; + accum += (w[1] & 0x80) * x_thread[3]; + accum += (w[2] & 0xf) * (x_thread[3] * 256.0f); + accum += (w[2] & 0xf0) * x_thread[4]; + accum += (w[3] & 0x1) * (x_thread[4] * 256.0f); + accum += (w[3] & 0x3e) * x_thread[5]; + accum += (w[3] & 0xc0) * x_thread[6]; + accum += (w[4] & 0x7) * (x_thread[6] * 256.0f); + accum += (w[4] & 0xf8) * x_thread[7]; + } + } else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { x_thread += 4 * i; @@ -256,8 +333,9 @@ template inline void qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = {scale, scale / 4.0f, scale / 16.0f, scale / 64.0f}; for (int i = 0; i < (values_per_thread / 4); i++) { @@ -290,7 +368,29 @@ qouter(const thread uint8_t* w, U x, U scale, U bias, thread U* result) { result[2 * i] += x * (s[0] * (w[i] & 0x0f) + bias); result[2 * i + 1] += x * (s[1] * (w[i] & 0xf0) + bias); } - } else if (bits == 6) { + } + else if (bits == 5) { + for (int i = 0; i < (values_per_thread / 8); i++) { + uint8_t w0 = w[5 * i]; + uint8_t w1 = w[5 * i + 1]; + uint8_t w2 = w[5 * i + 2]; + uint8_t w3 = w[5 * i + 3]; + uint8_t w4 = w[5 * i + 4]; + result[8 * i] += x * ((w0 & 0x1f) * scale + bias); + result[8 * i + 1] += + x * ((((w0 & 0xe0) >> 5) + ((w1 & 0x3) << 3)) * scale + bias); + result[8 * i + 2] += x * (((w1 & 0x7c) >> 2) * scale + bias); + result[8 * i + 3] += + x * ((((w1 & 0x80) >> 7) + ((w2 & 0xf) << 1)) * scale + bias); + result[8 * i + 4] += + x * ((((w2 & 0xf0) >> 4) + ((w3 & 0x1) << 4)) * scale + bias); + result[8 * i + 5] += x * (((w3 & 0x3e) >> 1) * scale + bias); + result[8 * i + 6] += + x * ((((w3 & 0xc0) >> 6) + ((w4 & 0x7) << 2)) * scale + bias); + result[8 * i + 7] += x * (((w4 & 0xf8) >> 3) * scale + bias); + } + } + else if (bits == 6) { for (int i = 0; i < (values_per_thread / 4); i++) { uint8_t w0 = w[3 * i]; uint8_t w1 = w[3 * i + 1]; @@ -313,8 +413,9 @@ template inline void dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); if (bits == 2) { U s[4] = { scale, @@ -349,6 +450,20 @@ dequantize(const device uint8_t* w, U scale, U bias, threadgroup U* w_local) { w_local[2 * i + 1] = s[1] * (w[i] & 0xf0) + bias; } } + else if (bits == 5) { + for (int i = 0; i < (N / 8); i++) { + w_local += 8 * i; + w += 5 * i; + w_local[0] = (w[0] & 0x1f) * scale + bias; + w_local[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + w_local[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + w_local[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + w_local[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + w_local[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + w_local[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + w_local[7] = ((w[4] & 0xf8) >> 3) * scale + bias; + } + } else if (bits == 6) { for (int i = 0; i < (N / 4); i++) { w_local += 4 * i; @@ -382,10 +497,11 @@ struct QuantizedBlockLoader { group_size % BCOLS == 0, "The group size should be divisible by the columns"); static_assert( - bits == 2 || bits == 3 || bits == 4 || bits == 6 || bits == 8, - "Template undefined for bits not in {2, 3, 4, 6, 8}"); - static constant constexpr const short pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - static constant constexpr const short bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; + bits == 2 || bits == 3 || bits == 4 || bits == 5 || bits == 6 || + bits == 8, + "Template undefined for bits not in {2, 3, 4, 5, 6, 8}"); + static constant constexpr const short pack_factor = get_pack_factor(); + static constant constexpr const short bytes_per_pack = get_bytes_per_pack(); static constant constexpr const short BCOLS_PACKED = BCOLS / pack_factor; static constant constexpr const short n_reads = (BCOLS_PACKED * BROWS < tgp_size) ? 1 : (BCOLS_PACKED * BROWS) / tgp_size; @@ -539,12 +655,11 @@ METAL_FUNC void qmv_fast_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int packs_per_thread = bits == 2 ? 1 : 2; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -595,12 +710,11 @@ METAL_FUNC void qmv_impl( uint3 tid [[threadgroup_position_in_grid]], uint simd_gid [[simdgroup_index_in_threadgroup]], uint simd_lid [[thread_index_in_simdgroup]]) { - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; constexpr int results_per_simdgroup = 4; constexpr int packs_per_thread = 1; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 4 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_thread = pack_factor * packs_per_thread; constexpr int block_size = values_per_thread * SIMD_SIZE; constexpr int scale_step_per_thread = group_size / values_per_thread; @@ -727,8 +841,8 @@ METAL_FUNC void qvm_impl( uint simd_lid [[thread_index_in_simdgroup]]) { constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; constexpr int num_simdgroups = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 32 / bits; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int tn = 32 / pack_factor; constexpr int block_size = SIMD_SIZE; using W_T = @@ -833,9 +947,9 @@ METAL_FUNC void qmm_t_impl( (void)lid; constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); - constexpr int bytes_per_pack = (bits == 3 || bits == 6) ? 3 : 1; using mma_t = mlx::steel:: BlockMMA; using loader_x_t = @@ -854,11 +968,11 @@ METAL_FUNC void qmm_t_impl( const int y_row = tid.y * BM; const int y_col = tid.x * BN; auto wl = (const device uint8_t*)w; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * K_w; scales += y_col * K_g; biases += y_col * K_g; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; const short num_els = min(BM, M - y_row); const short num_outs = min(BN, N - y_col); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); @@ -943,11 +1057,10 @@ METAL_FUNC void qmm_n_impl( (void)lid; constexpr int WM = 2; constexpr int WN = 2; - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel:: BlockMMA; using loader_x_t = mlx::steel:: @@ -964,11 +1077,11 @@ METAL_FUNC void qmm_n_impl( auto wl = (const device uint8_t*)w; const int y_row = tid.y * BM; const int y_col = tid.x * BN; - x += y_row * K; + x += y_row * static_cast(K); wl += y_col * bytes_per_pack / pack_factor; scales += y_col / group_size; biases += y_col / group_size; - y += y_row * N + y_col; + y += y_row * static_cast(N) + y_col; const short num_els = min(BM, M - y_row); loader_x_t loader_x(x, K, Xs, simd_gid, simd_lid); loader_w_t loader_w(wl, scales, biases, N, Ws, simd_gid, simd_lid); @@ -1897,11 +2010,10 @@ template < uint3 tid [[threadgroup_position_in_grid]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint simd_lane_id [[thread_index_in_simdgroup]]) { - constexpr int pack_factor = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int BK_padded = (BK + 16 / sizeof(T)); constexpr int BN_padded = (BN + 16 / sizeof(T)); - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; using mma_t = mlx::steel::BlockMMA< T, T, @@ -2052,13 +2164,13 @@ template constexpr float eps = 1e-7; constexpr int simd_size = 32; constexpr float n_bins = (1 << bits) - 1; - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); constexpr int values_per_reduce = group_size / simd_size; - constexpr int writes_per_reduce = packs_per_int / values_per_reduce; + constexpr int writes_per_reduce = pack_factor / values_per_reduce; constexpr int writes_per_pack = - writes_per_reduce > 1 ? 1 : values_per_reduce / packs_per_int; + writes_per_reduce > 1 ? 1 : values_per_reduce / pack_factor; constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; static_assert( group_size % simd_size == 0, "Group size must be divisible by simd size."); @@ -2092,33 +2204,42 @@ template scales[gindex] = static_cast(scale); biases[gindex] = static_cast(bias); } - uint32_t output = 0; + using OutType = metal::conditional_t; + OutType output = 0; #pragma clang loop unroll(full) for (int i = 0; i < values_per_reduce; i++) { uint8_t val = min(round((w_thread[i] - bias) / scale), n_bins); if (bits == 8) { output = val; } else { - output += val << (bits * (i % packs_per_int)); + output |= val << (bits * (i % pack_factor)); } - if (packs_per_int < values_per_reduce && - i % packs_per_int == packs_per_int - 1) { - out[out_index + i / packs_per_int] = output; + if (pack_factor < values_per_reduce && i % pack_factor == pack_factor - 1) { + out[out_index + i / pack_factor] = output; output = 0; } else { #pragma clang loop unroll(full) for (int j = 1; j < writes_per_reduce; j++) { uint8_t sval = simd_shuffle_down(val, j); - output += sval << (bits * (j * values_per_reduce + i)); + output |= static_cast(sval) + << (bits * (j * values_per_reduce + i)); } } } if (bits == 3 || bits == 6) { - if (in_index % packs_per_int == 0 && out_index % bytes_per_pack == 0) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { out[out_index] = output & 0xff; out[out_index + 1] = (output & 0xff00) >> 8; out[out_index + 2] = (output & 0xff0000) >> 16; } + } else if (bits == 5) { + if (in_index % pack_factor == 0 && out_index % bytes_per_pack == 0) { + out[out_index] = output & 0xff; + out[out_index + 1] = (output & 0xff00) >> 8; + out[out_index + 2] = (output & 0xff0000) >> 16; + out[out_index + 3] = (output & 0xff000000) >> 24; + out[out_index + 4] = (output & 0xff00000000) >> 32; + } } else { if (writes_per_reduce > 0 && out_index % writes_per_reduce == 0) { out[out_index / writes_per_reduce] = output; @@ -2133,11 +2254,10 @@ template device T* out [[buffer(3)]], uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - constexpr int packs_per_int = bits == 3 ? 8 : bits == 6 ? 4 : 8 / bits; - constexpr int power_of_2_bits = (bits & (bits - 1)) == 0; - constexpr int bytes_per_pack = power_of_2_bits ? 1 : 3; + constexpr int pack_factor = get_pack_factor(); + constexpr int bytes_per_pack = get_bytes_per_pack(); size_t offset = index.x + grid_dim.x * size_t(index.y); - size_t oindex = offset * packs_per_int; + size_t oindex = offset * pack_factor; size_t gindex = oindex / group_size; T scale = scales[gindex]; T bias = biases[gindex]; @@ -2152,6 +2272,16 @@ template out[5] = (((w[1] & 0x80) >> 7) + ((w[2] & 0x3) << 1)) * scale + bias; out[6] = ((w[2] & 0x1c) >> 2) * scale + bias; out[7] = ((w[2] & 0xe0) >> 5) * scale + bias; + } else if (bits == 5) { + w += offset * bytes_per_pack; + out[0] = (w[0] & 0x1f) * scale + bias; + out[1] = (((w[0] & 0xe0) >> 5) + ((w[1] & 0x3) << 3)) * scale + bias; + out[2] = ((w[1] & 0x7c) >> 2) * scale + bias; + out[3] = (((w[1] & 0x80) >> 7) + ((w[2] & 0xf) << 1)) * scale + bias; + out[4] = (((w[2] & 0xf0) >> 4) + ((w[3] & 0x1) << 4)) * scale + bias; + out[5] = ((w[3] & 0x3e) >> 1) * scale + bias; + out[6] = (((w[3] & 0xc0) >> 6) + ((w[4] & 0x7) << 2)) * scale + bias; + out[7] = ((w[4] & 0xf8) >> 3) * scale + bias; } else if (bits == 6) { w += offset * bytes_per_pack; out[0] = (w[0] & 0x3f) * scale + bias; @@ -2161,7 +2291,7 @@ template } else { uint val = w[offset]; #pragma clang loop unroll(full) - for (int i = 0; i < packs_per_int; i++) { + for (int i = 0; i < pack_factor; i++) { uint8_t d; if (bits == 2) { d = (val >> (bits * i)) & 0x03; diff --git a/Source/Cmlx/mlx-generated/reduce.cpp b/Source/Cmlx/mlx-generated/reduce.cpp index 6785affb..ac05030e 100644 --- a/Source/Cmlx/mlx-generated/reduce.cpp +++ b/Source/Cmlx/mlx-generated/reduce.cpp @@ -574,7 +574,7 @@ template < int blocks = IdxT(row_size) / N_READS; int extra = IdxT(row_size) % N_READS; if ((non_row_reductions < 32 && row_size <= 8) || non_row_reductions <= 8) { - IdxT out_idx = tid.x + tsize.y * IdxT(tid.y); + IdxT out_idx = tid.x + tsize.x * IdxT(tid.y); in += elem_to_loc(out_idx, shape, strides, ndim); for (uint r = 0; r < non_row_reductions; r++) { row = in + loop.location(); diff --git a/Source/Cmlx/mlx-generated/scan.cpp b/Source/Cmlx/mlx-generated/scan.cpp index 05cc891e..76c654f6 100644 --- a/Source/Cmlx/mlx-generated/scan.cpp +++ b/Source/Cmlx/mlx-generated/scan.cpp @@ -210,6 +210,13 @@ struct Power { } template <> complex64_t operator()(complex64_t x, complex64_t y) { + if (x.real == 0 && x.imag == 0) { + if (metal::isnan(y.real) || metal::isnan(y.imag)) { + auto nan = metal::numeric_limits::quiet_NaN(); + return {nan, nan}; + } + return {0.0, 0.0}; + } auto x_theta = metal::atan2(x.imag, x.real); auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag); auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta); diff --git a/Source/Cmlx/mlx-generated/softmax.cpp b/Source/Cmlx/mlx-generated/softmax.cpp index 8761da62..60f3e2ad 100644 --- a/Source/Cmlx/mlx-generated/softmax.cpp +++ b/Source/Cmlx/mlx-generated/softmax.cpp @@ -112,8 +112,8 @@ template } } else { for (int i = 0; i < N_READS; i++) { - vals[i] = (offset + i < axis_size) ? AccT(in[offset + i]) - : Limits::finite_min; + vals[i] = + (offset + i < axis_size) ? AccT(in[offset + i]) : Limits::min; } } prevmax = maxval; diff --git a/Source/Cmlx/mlx-generated/steel_conv_general.cpp b/Source/Cmlx/mlx-generated/steel_conv_general.cpp index 98e34d93..aa3d00ff 100644 --- a/Source/Cmlx/mlx-generated/steel_conv_general.cpp +++ b/Source/Cmlx/mlx-generated/steel_conv_general.cpp @@ -89,6 +89,42 @@ struct Conv2DInputBlockLoaderGeneral { } } else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = T(0); + } + } + } + } + METAL_FUNC void load_safe(const short remaining_k) const { +#pragma clang loop unroll(full) + for (short i = 0, is = 0; i < n_rows; ++i, is += TROWS) { + int n = read_n[i]; + int h_flip = params->flip ? params->wS[0] - weight_h - 1 : weight_h; + int w_flip = params->flip ? params->wS[1] - weight_w - 1 : weight_w; + int ih_dil = read_ih[i] + h_flip * params->kdil[0]; + int iw_dil = read_iw[i] + w_flip * params->kdil[1]; + int ih = ih_dil / params->idil[0]; + int iw = iw_dil / params->idil[1]; + size_t offset = ih * params->in_strides[1] + iw * params->in_strides[2]; + if ((n < params->N) && (ih_dil >= 0 && ih < params->iS[0]) && + (iw_dil >= 0 && iw < params->iS[1])) { + if (bj + vec_size <= remaining_k) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; ++j) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } + } else { + for (short j = 0; j < vec_size; ++j) { + if (bj + j < remaining_k) { + dst[is * dst_ld + j] = (src[i])[offset + j]; + } else { + dst[is * dst_ld + j] = T(0); + } + } + } + } + else { #pragma clang loop unroll(full) for (short j = 0; j < vec_size; ++j) { dst[is * dst_ld + j] = T(0); @@ -184,6 +220,53 @@ struct Conv2DWeightBlockLoaderGeneral { dst[i * dst_ld + j] = curr_src[i * src_ld + j]; } } else { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } + METAL_FUNC void load_safe(const short remaining_k) const { + const device T* curr_src = src + weight_h * params->wt_strides[1] + + weight_w * params->wt_strides[2]; + if ((start_row + BN <= params->O)) { +#pragma clang loop unroll(full) + for (short i = 0; i < BN; i += TROWS) { + if (bj + vec_size <= remaining_k) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } + } else { + for (short i = 0; i < BN; i += TROWS) { + if ((start_row + i) < params->O) { + if (bj + vec_size <= remaining_k) { +#pragma clang loop unroll(full) + for (short j = 0; j < vec_size; j++) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } + } else { + for (short j = 0; j < vec_size; j++) { + if (bj + j < remaining_k) { + dst[i * dst_ld + j] = curr_src[i * src_ld + j]; + } else { + dst[i * dst_ld + j] = T(0); + } + } + } + } else { #pragma clang loop unroll(full) for (short j = 0; j < vec_size; j++) { dst[i * dst_ld + j] = T(0); @@ -209,6 +292,7 @@ struct Conv2DWeightBlockLoaderGeneral { } } +constant bool align_C [[function_constant(200)]]; template < typename T, int BM, @@ -302,16 +386,41 @@ implicit_gemm_conv_2d_general( simd_gid, simd_lid); mma_t mma_op(simd_gid, simd_lid); - int gemm_k_iterations = - base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; - for (int k = 0; k < gemm_k_iterations; k++) { - threadgroup_barrier(mem_flags::mem_threadgroup); - loader_a.load_unsafe(); - loader_b.load_unsafe(); - threadgroup_barrier(mem_flags::mem_threadgroup); - mma_op.mma(As, Bs); - loader_a.next(); - loader_b.next(); + if (align_C) { + int gemm_k_iterations = + base_wh_size * base_ww_size * gemm_params->gemm_k_iterations; + for (int k = 0; k < gemm_k_iterations; k++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + } + else { + for (int k = 1; k < gemm_params->gemm_k_iterations; k++) { + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_unsafe(); + loader_b.load_unsafe(); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } + } + const short remaining_k = params->C % BK; + for (int j = 0; j < base_wh_size * base_ww_size; j++) { + threadgroup_barrier(mem_flags::mem_threadgroup); + loader_a.load_safe(remaining_k); + loader_b.load_safe(remaining_k); + threadgroup_barrier(mem_flags::mem_threadgroup); + mma_op.mma(As, Bs); + loader_a.next(); + loader_b.next(); + } } threadgroup_barrier(mem_flags::mem_none); { diff --git a/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp b/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp index a3176dc2..6a829e9b 100644 --- a/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp +++ b/Source/Cmlx/mlx-generated/steel_gemm_fused.cpp @@ -26,8 +26,8 @@ template < device T* D [[buffer(3)]], const constant GEMMParams* params [[buffer(4)]], const constant GEMMAddMMParams* addmm_params [[buffer(5), function_constant(use_out_source)]], - const constant int* batch_shape [[buffer(6)]], - const constant int64_t* batch_strides [[buffer(7)]], + const constant int* batch_shape [[buffer(6), function_constant(has_batch)]], + const constant int64_t* batch_strides [[buffer(7), function_constant(has_batch)]], uint simd_lane_id [[thread_index_in_simdgroup]], uint simd_group_id [[simdgroup_index_in_threadgroup]], uint3 tid [[threadgroup_position_in_grid]], diff --git a/Source/Cmlx/mlx-generated/ternary.cpp b/Source/Cmlx/mlx-generated/ternary.cpp index 143ee0d4..7e760273 100644 --- a/Source/Cmlx/mlx-generated/ternary.cpp +++ b/Source/Cmlx/mlx-generated/ternary.cpp @@ -2,25 +2,44 @@ namespace mlx::core::metal { const char* ternary() { return R"preamble( -template +template ::n> [[kernel]] void ternary_v( device const bool* a, device const T* b, device const T* c, device T* d, + constant uint& size, uint index [[thread_position_in_grid]]) { - d[index] = Op()(a[index], b[index], c[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[index + i] = Op()(a[index + i], b[index + i], c[index + i]); + } + } } -template +template ::n> [[kernel]] void ternary_v2( device const bool* a, device const T* b, device const T* c, device T* d, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - d[offset] = Op()(a[offset], b[offset], c[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + d[offset + i] = Op()(a[offset + i], b[offset + i], c[offset + i]); + } + } } template [[kernel]] void ternary_g_nd1( diff --git a/Source/Cmlx/mlx-generated/unary.cpp b/Source/Cmlx/mlx-generated/unary.cpp index bb5a5867..c55daadd 100644 --- a/Source/Cmlx/mlx-generated/unary.cpp +++ b/Source/Cmlx/mlx-generated/unary.cpp @@ -2,21 +2,40 @@ namespace mlx::core::metal { const char* unary() { return R"preamble( -template +template ::n> [[kernel]] void unary_v( device const T* in, device U* out, + constant uint& size, uint index [[thread_position_in_grid]]) { - out[index] = Op()(in[index]); + index *= N; + if (N > 1 && index + N > size) { + for (int i = 0; index + i < size; ++i) { + out[index + i] = Op()(in[index + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[index + i] = Op()(in[index + i]); + } + } } -template +template ::n> [[kernel]] void unary_v2( device const T* in, device U* out, + constant int64_t& size, uint2 index [[thread_position_in_grid]], uint2 grid_dim [[threads_per_grid]]) { - auto offset = index.x + grid_dim.x * int64_t(index.y); - out[offset] = Op()(in[offset]); + int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y)); + if (N > 1 && offset + N > size) { + for (int i = 0; offset + i < size; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } else { + for (int i = 0; i < N; ++i) { + out[offset + i] = Op()(in[offset + i]); + } + } } template < typename T, diff --git a/Source/Cmlx/mlx-generated/utils.cpp b/Source/Cmlx/mlx-generated/utils.cpp index 73eebac4..e8a66b47 100644 --- a/Source/Cmlx/mlx-generated/utils.cpp +++ b/Source/Cmlx/mlx-generated/utils.cpp @@ -310,6 +310,11 @@ static constant constexpr int RMS_LOOPED_LIMIT = 4096; typedef half float16_t; template +struct WorkPerThread { + static_assert(sizeof(U) <= 8, "Type too large"); + static constexpr int constant n = 8 / sizeof(U); +}; +template struct Limits { static const constant U max = metal::numeric_limits::max(); static const constant U min = metal::numeric_limits::min();