|
48 | 48 | b1 = Flux.OneHotMatrix([1, 1, 2, 2], 3)
|
49 | 49 | b2 = Flux.OneHotMatrix([2, 4, 1, 3], 5)
|
50 | 50 | b3 = Flux.OneHotMatrix([1, 1, 2], 4)
|
| 51 | + b4 = reshape(Flux.OneHotMatrix([1 2 3; 2 2 1], 3), 3, :) |
| 52 | + b5 = reshape(b4, 6, :) |
| 53 | + b6 = reshape(Flux.OneHotMatrix([1 2 2; 2 2 1], 2), 3, :) |
| 54 | + b7 = reshape(Flux.OneHotMatrix([1 2 3; 1 2 3], 3), 6, :) |
51 | 55 |
|
52 | 56 | @test A * b1 == A[:,[1, 1, 2, 2]]
|
53 | 57 | @test b1' * A == Array(b1') * A
|
54 | 58 | @test A' * b1 == A' * Array(b1)
|
55 | 59 | @test A * b3' == A * Array(b3')
|
56 | 60 | @test transpose(X) * b2 == transpose(X) * Array(b2)
|
| 61 | + @test A * b4 == A[:,[1, 2, 2, 2, 3, 1]] |
| 62 | + @test_broken A * b5' == A[:,[1, 2, 2, 2, 3, 1]] |
| 63 | + @test A * b6 == hcat(A[:,1], 2*A[:,2], A[:,2], A[:,1]+A[:,2]) |
| 64 | + @test_broken A * b7' |
| 65 | + |
57 | 66 | @test_throws DimensionMismatch A*b1'
|
58 | 67 | @test_throws DimensionMismatch A*b2
|
59 | 68 | @test_throws DimensionMismatch A*b2'
|
| 69 | + @test_throws DimensionMismatch A*b6' |
| 70 | + @test_throws DimensionMismatch A*b7 |
60 | 71 | end
|
61 | 72 |
|
62 | 73 | @testset "OneHotArray" begin
|
|
0 commit comments