Skip to content

Commit 6c1c4e7

Browse files
committed
fixup GPU printing
1 parent da9c875 commit 6c1c4e7

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

src/onehot.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,17 +57,16 @@ end
5757

5858
# this is from /LinearAlgebra/src/diagonal.jl, official way to print the dots:
5959
function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::AbstractString)
60-
# CUDA.@allowscalar(x[i,j]) ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
6160
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
6261
end
6362
function Base.replace_in_print_matrix(x::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike}, i::Integer, j::Integer, s::AbstractString)
64-
CUDA.@allowscalar(x[i,j]) ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s
63+
x[i,j] ? s : _isonehot(parent(x)) ? Base.replace_with_centered_mark(s) : s
6564
end
6665

67-
# Base.show(io::IO, x::OneHotLike) = show(io, convert(Array{Bool}, cpu(x))) # helps string(cu(y))
68-
6966
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
7067
Base.print_array(io, cpu(X))
68+
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
69+
Base.print_array(io, cpu(X))
7170

7271
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
7372
_onehot_bool_type(x::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}

0 commit comments

Comments
 (0)