Skip to content

Commit d82d294

Browse files
committed
Apply code review
1 parent 7c9af0b commit d82d294

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/onehot.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -187,19 +187,19 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
187187
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
188188
return A[:, onecold(B)]
189189
end
190-
for math_op in [:Adjoint, :Transpose]
190+
for wrapper in [:Adjoint, :Transpose]
191191
@eval begin
192-
function Base.:*(A::$math_op{T1,<:AbstractArray{T,2}}, b::OneHotVector) where {T1,T}
193-
if size(A, 2) != b.of
194-
throw(DimensionMismatch("Second element of Adjoint matrix size $(size(A)) must correspond with OneHotVector size $(size(b))"))
195-
end
192+
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix}, b::OneHotVector{<:Any, L}) where L
193+
size(A, 2) != L ||
194+
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
195+
196196
return A[:, b.ix]
197197
end
198198

199-
function Base.:*(A::$math_op{T1,<:AbstractArray{T,1}}, b::OneHotVector) where {T1<:Number,T}
200-
if size(A, 2) != b.of
201-
throw(DimensionMismatch("Second element of Adjoint matrix size $(size(A)) must correspond with OneHotVector size $(size(b))"))
202-
end
199+
function Base.:*(A::$wrapper{<:Number, <:AbstractVector}, b::OneHotVector{<:Any, L}) where L
200+
size(A, 2) != L ||
201+
throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))
202+
203203
return A[b.ix]
204204
end
205205
end

0 commit comments

Comments
 (0)