Skip to content

FIX add argument to _check_sample_weight #73

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: submodulev3
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sklearn/tree/_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _fit(
max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X, DOUBLE)
sample_weight = _check_sample_weight(sample_weight, X, dtype=DOUBLE)

if y is not None and expanded_class_weight is not None:
if sample_weight is not None:
Expand Down
40 changes: 23 additions & 17 deletions sklearn/utils/_cython_blas.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ cdef void _gemv(BLAS_Order order, BLAS_Trans ta, int m, int n, floating alpha,
floating beta, floating *y, int incy) noexcept nogil:
"""y := alpha * op(A).x + beta * y"""
cdef char ta_ = ta
if order == RowMajor:
ta_ = NoTrans if ta == Trans else Trans
if order == BLAS_Order.RowMajor:
ta_ = BLAS_Trans.NoTrans if ta == BLAS_Trans.Trans else BLAS_Trans.Trans
if floating is float:
sgemv(&ta_, &n, &m, &alpha, <float *> A, &lda, <float *> x,
&incx, &beta, y, &incy)
Expand All @@ -148,8 +148,10 @@ cpdef _gemv_memview(BLAS_Trans ta, floating alpha, const floating[:, :] A,
cdef:
int m = A.shape[0]
int n = A.shape[1]
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
int lda = m if order == ColMajor else n
BLAS_Order order = (
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
)
int lda = m if order == BLAS_Order.ColMajor else n

_gemv(order, ta, m, n, alpha, &A[0, 0], lda, &x[0], 1, beta, &y[0], 1)

Expand All @@ -158,7 +160,7 @@ cdef void _ger(BLAS_Order order, int m, int n, floating alpha,
const floating *x, int incx, const floating *y,
int incy, floating *A, int lda) noexcept nogil:
"""A := alpha * x.y.T + A"""
if order == RowMajor:
if order == BLAS_Order.RowMajor:
if floating is float:
sger(&n, &m, &alpha, <float *> y, &incy, <float *> x, &incx, A, &lda)
else:
Expand All @@ -175,8 +177,10 @@ cpdef _ger_memview(floating alpha, const floating[::1] x,
cdef:
int m = A.shape[0]
int n = A.shape[1]
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
int lda = m if order == ColMajor else n
BLAS_Order order = (
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
)
int lda = m if order == BLAS_Order.ColMajor else n

_ger(order, m, n, alpha, &x[0], 1, &y[0], 1, &A[0, 0], lda)

Expand All @@ -194,7 +198,7 @@ cdef void _gemm(BLAS_Order order, BLAS_Trans ta, BLAS_Trans tb, int m, int n,
cdef:
char ta_ = ta
char tb_ = tb
if order == RowMajor:
if order == BLAS_Order.RowMajor:
if floating is float:
sgemm(&tb_, &ta_, &n, &m, &k, &alpha, <float*>B,
&ldb, <float*>A, &lda, &beta, C, &ldc)
Expand All @@ -214,19 +218,21 @@ cpdef _gemm_memview(BLAS_Trans ta, BLAS_Trans tb, floating alpha,
const floating[:, :] A, const floating[:, :] B, floating beta,
floating[:, :] C):
cdef:
int m = A.shape[0] if ta == NoTrans else A.shape[1]
int n = B.shape[1] if tb == NoTrans else B.shape[0]
int k = A.shape[1] if ta == NoTrans else A.shape[0]
int m = A.shape[0] if ta == BLAS_Trans.NoTrans else A.shape[1]
int n = B.shape[1] if tb == BLAS_Trans.NoTrans else B.shape[0]
int k = A.shape[1] if ta == BLAS_Trans.NoTrans else A.shape[0]
int lda, ldb, ldc
BLAS_Order order = ColMajor if A.strides[0] == A.itemsize else RowMajor
BLAS_Order order = (
BLAS_Order.ColMajor if A.strides[0] == A.itemsize else BLAS_Order.RowMajor
)

if order == RowMajor:
lda = k if ta == NoTrans else m
ldb = n if tb == NoTrans else k
if order == BLAS_Order.RowMajor:
lda = k if ta == BLAS_Trans.NoTrans else m
ldb = n if tb == BLAS_Trans.NoTrans else k
ldc = n
else:
lda = m if ta == NoTrans else k
ldb = k if tb == NoTrans else n
lda = m if ta == BLAS_Trans.NoTrans else k
ldb = k if tb == BLAS_Trans.NoTrans else n
ldc = m

_gemm(order, ta, tb, m, n, k, alpha, &A[0, 0],
Expand Down