diff --git a/sklearn/tree/_classes.py b/sklearn/tree/_classes.py index 2ce58759d8253..dc08bee65d3d9 100644 --- a/sklearn/tree/_classes.py +++ b/sklearn/tree/_classes.py @@ -401,7 +401,7 @@ def _fit( max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes if sample_weight is not None: - sample_weight = _check_sample_weight(sample_weight, X, DOUBLE) + sample_weight = _check_sample_weight(sample_weight, X, dtype=DOUBLE) if y is not None and expanded_class_weight is not None: if sample_weight is not None: diff --git a/sklearn/utils/_cython_blas.pyx b/sklearn/utils/_cython_blas.pyx index c242e59e1b9de..ac23d0c4000ff 100644 --- a/sklearn/utils/_cython_blas.pyx +++ b/sklearn/utils/_cython_blas.pyx @@ -126,8 +126,8 @@ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha, floating beta, floating *y, int incy) noexcept nogil: """y := alpha * op(A).x + beta * y""" cdef char ta_ = ta - if order == RowMajor: - ta_ = NoTrans if ta == Trans else Trans + if order == BLAS_Order.RowMajor: + ta_ = BLAS_Trans.NoTrans if ta == BLAS_Trans.Trans else BLAS_Trans.Trans if floating is float: sgemv(&ta_, &n, &m, &alpha, A, &lda, x, &incx, &beta, y, &incy) @@ -148,8 +148,10 @@ cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A, cdef: int m = A.shape[0] int n = A.shape[1] - BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor - int lda = m if order == ColMajor else n + BLAS_Order order = ( + BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor + ) + int lda = m if order == BLAS_Order.ColMajor else n _gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1) @@ -158,7 +160,7 @@ cdef void _ger(BLAS_Order order, int m, int n, floating alpha, const floating *x, int incx, const floating *y, int incy, floating *A, int lda) noexcept nogil: """A := alpha * x.y.T + A""" - if order == RowMajor: + if order == BLAS_Order.RowMajor: if floating is float: sger(&n, &m, &alpha, y, &incy, x, &incx, A, &lda) else: @@ -175,8 +177,10 @@ cpdef _ger_memview(floating alpha, const floating[::1] x, cdef: int m = A.shape[0] int n = A.shape[1] - BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor - int lda = m if order == ColMajor else n + BLAS_Order order = ( + BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor + ) + int lda = m if order == BLAS_Order.ColMajor else n _ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda) @@ -194,7 +198,7 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n, cdef: char ta_ = ta char tb_ = tb - if order == RowMajor: + if order == BLAS_Order.RowMajor: if floating is float: sgemm(&tb_, &ta_, &n, &m, &k, &alpha, B, &ldb, A, &lda, &beta, C, &ldc) @@ -214,19 +218,21 @@ cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha, const floating[:, :] A, const floating[:, :] B, floating beta, floating[:, :] C): cdef: - int m = A.shape[0] if ta == NoTrans else A.shape[1] - int n = B.shape[1] if tb == NoTrans else B.shape[0] - int k = A.shape[1] if ta == NoTrans else A.shape[0] + int m = A.shape[0] if ta == BLAS_Trans.NoTrans else A.shape[1] + int n = B.shape[1] if tb == BLAS_Trans.NoTrans else B.shape[0] + int k = A.shape[1] if ta == BLAS_Trans.NoTrans else A.shape[0] int lda, ldb, ldc - BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor + BLAS_Order order = ( + BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor + ) - if order == RowMajor: - lda = k if ta == NoTrans else m - ldb = n if tb == NoTrans else k + if order == BLAS_Order.RowMajor: + lda = k if ta == BLAS_Trans.NoTrans else m + ldb = n if tb == BLAS_Trans.NoTrans else k ldc = n else: - lda = m if ta == NoTrans else k - ldb = k if tb == NoTrans else n + lda = m if ta == BLAS_Trans.NoTrans else k + ldb = k if tb == BLAS_Trans.NoTrans else n ldc = m _gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],