Skip to content

chore: update mlx to 0.26.2 #253

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,5 @@ fastlane/test_output

iOSInjectionProject/
.swiftpm

.vscode/
6 changes: 4 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ let package = Package(
"mlx/tests",

// opt-out of these backends (using metal)
"mlx/mlx/backend/no_metal",
"mlx/mlx/backend/no_cpu",
"mlx/mlx/backend/no_gpu",
"mlx/mlx/backend/cuda",
"mlx/mlx/backend/metal/no_metal.cpp",

// build variants (we are opting _out_ of these)
"mlx/mlx/io/no_safetensors.cpp",
Expand Down Expand Up @@ -111,7 +113,7 @@ let package = Package(
.define("_METAL_"),
.define("SWIFTPM_BUNDLE", to: "\"mlx-swift_Cmlx\""),
.define("METAL_PATH", to: "\"default.metallib\""),
.define("MLX_VERSION", to: "\"0.24.2\""),
.define("MLX_VERSION", to: "\"0.26.2\""),
],
linkerSettings: [
.linkedFramework("Foundation"),
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
Submodule mlx updated 289 files
87 changes: 72 additions & 15 deletions Source/Cmlx/mlx-generated/binary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,59 +10,116 @@ template <typename T, typename U, typename Op>
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
// Elementwise binary kernel: scalar `a` broadcast over vector `b`
// (c[i] = Op(a[0], b[i])). Each thread now handles N consecutive output
// elements; N defaults to WorkPerThread<T>::n.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
// NOTE(review): the next line appears to be the pre-update (scalar, one
// element per thread) statement retained by this stripped diff view; the
// vectorized loop below supersedes it — confirm against upstream mlx.
c[index] = Op()(a[0], b[index]);
// Scale the thread id to the start of this thread's N-element chunk.
index *= N;
// Tail chunk: the last thread's span may run past `size`, so clamp.
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
} else {
// Full chunk: fixed-trip loop of N elements (unroll-friendly).
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[0], b[index + i]);
}
}
}
template <typename T, typename U, typename Op>
// Elementwise binary kernel: vector `a` against scalar `b`
// (c[i] = Op(a[i], b[0])). Each thread handles N consecutive elements.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
// NOTE(review): the next line looks like the pre-update statement kept by
// this stripped diff view; the vectorized loop below supersedes it.
c[index] = Op()(a[index], b[0]);
// Start of this thread's N-element chunk.
index *= N;
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[0]);
}
}
}
template <typename T, typename U, typename Op>
// Elementwise binary kernel: vector `a` against vector `b`
// (c[i] = Op(a[i], b[i])). Each thread handles N consecutive elements.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
constant uint& size,
uint index [[thread_position_in_grid]]) {
// NOTE(review): the next line looks like the pre-update statement kept by
// this stripped diff view; the vectorized loop below supersedes it.
c[index] = Op()(a[index], b[index]);
// Start of this thread's N-element chunk.
index *= N;
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
c[index + i] = Op()(a[index + i], b[index + i]);
}
}
}
template <typename T, typename U, typename Op>
// 2-D grid variant of binary_sv for arrays too large for a 1-D 32-bit grid:
// the linear offset is built from a uint2 thread position in 64-bit math.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the next two lines appear to be the pre-update body kept by
// this stripped diff view (they would redeclare `offset`); the vectorized
// code below supersedes them — confirm against upstream mlx.
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[0], b[offset]);
// 64-bit start of this thread's N-element chunk.
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[0], b[offset + i]);
}
}
}
template <typename T, typename U, typename Op>
// 2-D grid variant of binary_vs (vector `a`, scalar `b`) using a 64-bit
// linear offset derived from a uint2 thread position.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the next two lines look like the pre-update body kept by
// this stripped diff view (duplicate `offset`); the code below supersedes.
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[0]);
// 64-bit start of this thread's N-element chunk.
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[0]);
}
}
}
template <typename T, typename U, typename Op>
// 2-D grid variant of binary_vv (vector `a`, vector `b`) using a 64-bit
// linear offset derived from a uint2 thread position.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the next two lines look like the pre-update body kept by
// this stripped diff view (duplicate `offset`); the code below supersedes.
int64_t offset = index.x + grid_dim.x * int64_t(index.y);
c[offset] = Op()(a[offset], b[offset]);
// 64-bit start of this thread's N-element chunk.
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
c[offset + i] = Op()(a[offset + i], b[offset + i]);
}
}
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
Expand Down
7 changes: 7 additions & 0 deletions Source/Cmlx/mlx-generated/binary_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ struct Power {
}
template <>
complex64_t operator()(complex64_t x, complex64_t y) {
if (x.real == 0 && x.imag == 0) {
if (metal::isnan(y.real) || metal::isnan(y.imag)) {
auto nan = metal::numeric_limits<float>::quiet_NaN();
return {nan, nan};
}
return {0.0, 0.0};
}
auto x_theta = metal::atan2(x.imag, x.real);
auto x_ln_r = 0.5 * metal::log(x.real * x.real + x.imag * x.imag);
auto mag = metal::exp(y.real * x_ln_r - y.imag * x_theta);
Expand Down
123 changes: 96 additions & 27 deletions Source/Cmlx/mlx-generated/binary_two.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,77 +13,146 @@ template <typename T, typename U, typename Op>
c[index] = out[0];
d[index] = out[1];
}
template <typename T, typename U, typename Op>
// Two-output binary kernel (binary_two.cpp): scalar `a` broadcast over
// vector `b`; Op returns an indexable pair written to `c` and `d`.
// Each thread handles N consecutive elements.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
// NOTE(review): the next three lines appear to be the pre-update body kept
// by this stripped diff view; the vectorized loops below supersede them.
auto out = Op()(a[0], b[index]);
c[index] = out[0];
d[index] = out[1];
// Start of this thread's N-element chunk.
index *= N;
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
template <typename T, typename U, typename Op>
// Two-output binary kernel: vector `a` against scalar `b`; Op returns an
// indexable pair written to `c` and `d`. N elements per thread.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
// NOTE(review): the next three lines appear to be the pre-update body kept
// by this stripped diff view; the vectorized loops below supersede them.
auto out = Op()(a[index], b[0]);
c[index] = out[0];
d[index] = out[1];
// Start of this thread's N-element chunk.
index *= N;
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[0]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
template <typename T, typename U, typename Op>
// Two-output binary kernel: vector `a` against vector `b`; Op returns an
// indexable pair written to `c` and `d`. N elements per thread.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant uint& size,
uint index [[thread_position_in_grid]]) {
// NOTE(review): the next three lines appear to be the pre-update body kept
// by this stripped diff view; the vectorized loops below supersede them.
auto out = Op()(a[index], b[index]);
c[index] = out[0];
d[index] = out[1];
// Start of this thread's N-element chunk.
index *= N;
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && index + N > size) {
for (int i = 0; index + i < size; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
auto out = Op()(a[index + i], b[index + i]);
c[index + i] = out[0];
d[index + i] = out[1];
}
}
}
template <typename T, typename U, typename Op>
// 2-D grid, two-output variant of binary_sv for arrays beyond a 1-D 32-bit
// grid: 64-bit linear offset from a uint2 thread position; N per thread.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_sv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the next four lines appear to be the pre-update body kept
// by this stripped diff view (they would redeclare `offset`); the
// vectorized code below supersedes them — confirm against upstream mlx.
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[0], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
// 64-bit start of this thread's N-element chunk.
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
auto out = Op()(a[0], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
template <typename T, typename U, typename Op>
// 2-D grid, two-output variant of binary_vs (vector `a`, scalar `b`):
// 64-bit linear offset from a uint2 thread position; N per thread.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vs2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the next four lines appear to be the pre-update body kept
// by this stripped diff view (duplicate `offset`); the code below
// supersedes them.
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[0]);
c[offset] = out[0];
d[offset] = out[1];
// 64-bit start of this thread's N-element chunk.
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[0]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
template <typename T, typename U, typename Op>
// 2-D grid, two-output variant of binary_vv (vector `a`, vector `b`):
// 64-bit linear offset from a uint2 thread position; N per thread.
template <typename T, typename U, typename Op, int N = WorkPerThread<T>::n>
[[kernel]] void binary_vv2(
device const T* a,
device const T* b,
device U* c,
device U* d,
constant int64_t& size,
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
// NOTE(review): the next four lines appear to be the pre-update body kept
// by this stripped diff view (duplicate `offset`); the code below
// supersedes them.
auto offset = index.x + grid_dim.x * int64_t(index.y);
auto out = Op()(a[offset], b[offset]);
c[offset] = out[0];
d[offset] = out[1];
// 64-bit start of this thread's N-element chunk.
int64_t offset = N * (index.x + grid_dim.x * int64_t(index.y));
// Tail chunk: clamp to `size` for the final partial span.
if (N > 1 && offset + N > size) {
for (int i = 0; offset + i < size; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
} else {
// Full chunk of exactly N elements.
for (int i = 0; i < N; ++i) {
auto out = Op()(a[offset + i], b[offset + i]);
c[offset + i] = out[0];
d[offset + i] = out[1];
}
}
}
template <typename T, typename U, typename Op, typename IdxT = int64_t>
[[kernel]] void binary_g_nd1(
Expand Down
8 changes: 5 additions & 3 deletions Source/Cmlx/mlx-generated/conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ struct Conv2DWeightBlockLoader {
const device T* src;
const constant MLXConvParams<2>* params;
int weight_hw;
int weight_step;
const int read_n;
const bool do_read;
METAL_FUNC Conv2DWeightBlockLoader(
Expand All @@ -371,6 +372,7 @@ struct Conv2DWeightBlockLoader {
src(src_ + bi * src_ld + bj),
params(params_),
weight_hw(0),
weight_step(params->C / params->groups),
read_n(offsets.y + bi),
do_read(read_n + n_rows * TROWS <= gemm_params_->N) {}
METAL_FUNC void load_unsafe() const {
Expand Down Expand Up @@ -400,11 +402,11 @@ struct Conv2DWeightBlockLoader {
}
// Advance the weight loader to the next filter tap (weight_hw), stepping
// `src` by one tap; after the last tap, wrap to the next K-block.
METAL_FUNC void next() {
if (++weight_hw < (params->wS[1] * params->wS[0])) {
// NOTE(review): the next line appears to be the pre-update statement kept
// by this stripped diff view — the PR replaces the wt_strides[2] step with
// the precomputed `weight_step` (params->C / params->groups) for grouped
// conv support; confirm against upstream mlx.
src += params->wt_strides[2];
src += weight_step;
return;
}
// All taps consumed: reset and jump forward by BK minus the distance
// already walked across the (wS[0]*wS[1] - 1) tap steps.
weight_hw = 0;
// NOTE(review): pre-update line retained by the diff view, superseded by
// the `weight_step` line that follows.
src += BK - (params->wS[1] * params->wS[0] - 1) * params->wt_strides[2];
src += BK - (params->wS[1] * params->wS[0] - 1) * weight_step;
}
};
}
Expand Down Expand Up @@ -604,7 +606,7 @@ struct Conv2DWeightBlockLoaderSmallChannels {
}
return;
}
const device T* curr_src = src + weight_hw * params->wt_strides[2];
const device T* curr_src = src + weight_hw * (params->C / params->groups);
if (BN != 8 || do_read) {
#pragma clang loop unroll(full)
for (short i = 0; i < BROWS; i += TROWS) {
Expand Down
Loading