91
91
end
92
92
93
93
94
- @generated function * (A:: StaticMatrix , b:: StaticVector )
95
- TA = eltype (A)
96
- Tb = eltype (b)
94
# Unrolled matrix-vector product for statically sized operands.
#
# This is a generated function: inside the generator `A` and `b` are bound to the
# argument *types*, so `size` and the element types `TA`/`Tb` are available while
# the code is being built. The returned expression computes every output entry as
# an explicit sum of scalar products and passes the resulting tuple to the
# constructor of the output type.
#
# The output has length `size(A, 1)` and element type `promote_op(matprod, TA, Tb)`.
# The type of `b` is reused when neither its size nor its element type changes;
# otherwise the output type comes from `similar_type`.
#
# Raises an error (" Dimension mismatch") when `length(b) != size(A, 2)`.
@generated function *{TA,Tb}(A::StaticMatrix{TA}, b::StaticVector{Tb})
    sA = size(A)
    sb = size(b)

    s = (sA[1],)                      # size of the product vector
    T = promote_op(matprod, TA, Tb)   # element type of the product vector

    if sb[1] != sA[2]
        error(" Dimension mismatch")
    end

    # Pick the output type, asking `similar_type` only for what actually changes.
    same_size = (s == sb)
    same_eltype = (T == Tb)
    newtype = if same_size && same_eltype
        b
    elseif same_size
        similar_type(b, T)
    elseif same_eltype
        similar_type(b, s)
    else
        similar_type(b, T, s)
    end

    nrows = sA[1]
    ncols = sA[2]

    if ncols != 0
        # Row `row` of the result is the nested sum A[row, 1]*b[1] + A[row, 2]*b[2] + ...
        # `A` is addressed by linear index so a single scalar `getindex` is emitted.
        add = (lhs, rhs) -> :(+($lhs, $rhs))
        term = (row, col) -> :(A[$(sub2ind(sA, row, col))] * b[$col])
        exprs = [reduce(add, [term(row, col) for col = 1:ncols]) for row = 1:nrows]
    else
        # No columns: every sum is empty, so each entry is a literal zero.
        exprs = [zero(T) for row = 1:nrows]
    end

    constructor_call = Expr(:call, newtype, Expr(:tuple, exprs...))

    return quote
        $(Expr(:meta, :inline))
        @inbounds return $constructor_call
    end
end
131
+
132
+ # Needed to resolve a method ambiguity relating to the two functions below
133
+ @generated function * {TA<: Base.LinAlg.BlasFloat ,Tb}(A:: StaticMatrix{TA} , b:: StaticVector{Tb} )
97
134
sA = size (A)
98
135
sb = size (b)
99
136
132
169
end
133
170
134
171
# This happens to be size-inferrable from A
135
- @generated function * (A:: StaticMatrix , b:: AbstractVector )
136
- TA = eltype (A)
137
- Tb = eltype (b)
172
+ @generated function * {TA,Tb}(A:: StaticMatrix{TA} , b:: AbstractVector{Tb} )
138
173
sA = size (A)
139
174
# sb = size(b)
140
175
@@ -457,11 +492,52 @@ end
457
492
7
458
493
459
494
# TODO aliasing problems if c === b?
460
- @generated function A_mul_B! (c:: StaticVector , A:: StaticMatrix , b:: StaticVector )
495
# In-place unrolled matrix-vector product: stores `A*b` into `c` and returns `c`.
#
# This is a generated function: inside the generator `c`, `A` and `b` are bound
# to the argument *types*, so all sizes are known while the code is being built
# and the dimension check happens at code-generation time.
#
# Raises an error (" Dimension mismatch") unless `length(b) == size(A, 2)` and
# `length(c) == size(A, 1)`.
#
# NOTE: entries of `c` are stored one row at a time while later rows still read
# from `b`, so passing the same object as both `c` and `b` is not safe.
@generated function A_mul_B!{T1,T2,T3}(c::StaticVector{T1}, A::StaticMatrix{T2}, b::StaticVector{T3})
    sA = size(A)
    sb = size(b)
    s = size(c)

    if sb[1] != sA[2] || s[1] != sA[1]
        error(" Dimension mismatch")
    end

    if sA[2] != 0
        # c[k] = A[k, 1]*b[1] + A[k, 2]*b[2] + ..., with `A` addressed by linear index.
        plus = (ex1, ex2) -> :(+($ex1, $ex2))
        rowsum = k -> reduce(plus, [:(A[$(sub2ind(sA, k, j))] * b[$j]) for j = 1:sA[2]])
        exprs = [:(c[$k] = $(rowsum(k))) for k = 1:sA[1]]
    else
        # No columns: every sum is empty, so each entry becomes zero(T1).
        exprs = [:(c[$k] = $(zero(T1))) for k = 1:sA[1]]
    end

    return quote
        $(Expr(:meta, :inline))
        @inbounds $(Expr(:block, exprs...))
        # Return the destination, as Base's `A_mul_B!(Y, A, B) -> Y` does. Without
        # this the call evaluated to the last stored element (or `nothing` when
        # `A` has no rows).
        return c
    end
end
515
+
516
+ # These two methods are needed to resolve an ambiguity with a BLAS-calling function
517
# In-place unrolled matrix-vector product for `Float32`/`Float64` operands that
# all share the element type `T`: stores `A*b` into `c` and returns `c`.
#
# This is a generated function: inside the generator `c`, `A` and `b` are bound
# to the argument *types*, so all sizes are known while the code is being built
# and the dimension check happens at code-generation time.
#
# Raises an error (" Dimension mismatch") unless `length(b) == size(A, 2)` and
# `length(c) == size(A, 1)`.
#
# NOTE: entries of `c` are stored one row at a time while later rows still read
# from `b`, so passing the same object as both `c` and `b` is not safe.
@generated function A_mul_B!{T<:Union{Float32, Float64}}(c::StaticVector{T}, A::StaticMatrix{T}, b::StaticVector{T})
    sA = size(A)
    sb = size(b)
    s = size(c)

    if sb[1] != sA[2] || s[1] != sA[1]
        error(" Dimension mismatch")
    end

    if sA[2] != 0
        # c[k] = A[k, 1]*b[1] + A[k, 2]*b[2] + ..., with `A` addressed by linear index.
        plus = (ex1, ex2) -> :(+($ex1, $ex2))
        rowsum = k -> reduce(plus, [:(A[$(sub2ind(sA, k, j))] * b[$j]) for j = 1:sA[2]])
        exprs = [:(c[$k] = $(rowsum(k))) for k = 1:sA[1]]
    else
        # No columns: every sum is empty, so each entry becomes zero(T).
        exprs = [:(c[$k] = $(zero(T))) for k = 1:sA[1]]
    end

    return quote
        $(Expr(:meta, :inline))
        @inbounds $(Expr(:block, exprs...))
        # Return the destination, as Base's `A_mul_B!(Y, A, B) -> Y` does. Without
        # this the call evaluated to the last stored element (or `nothing` when
        # `A` has no rows).
        return c
    end
end
537
+ @generated function A_mul_B! {T<:Union{Complex{Float32}, Complex{Float64}}} (c:: StaticVector{T} , A:: StaticMatrix{T} , b:: StaticVector{T} )
461
538
sA = size (A)
462
539
sb = size (b)
463
540
s = size (c)
464
- T = eltype (c)
465
541
466
542
if sb[1 ] != sA[2 ] || s[1 ] != sA[1 ]
467
543
error (" Dimension mismatch" )
480
556
end
481
557
482
558
# The unrolled code is inferrable from the size of A
483
- @generated function A_mul_B! (c:: AbstractVector , A:: StaticMatrix , b:: AbstractVector )
559
+ @generated function A_mul_B! {T1,T2,T3} (c:: AbstractVector{T1} , A:: StaticMatrix{T2} , b:: AbstractVector{T3} )
484
560
sA = size (A)
485
561
T = eltype (c)
486
562
@@ -500,11 +576,30 @@ end
500
576
end
501
577
502
578
# Ambiguity with a BLAS specialized function
503
- @generated function Base. A_mul_B! {T<:Base.LinAlg.BlasFloat} (c:: StridedVector{T} , A:: StaticMatrix{T} , b:: StridedVector{T} )
579
+ # Also possible bug makes this harder to resolve (see https://github.com/JuliaLang/julia/issues/19124)
580
+ # (problem being that I can't use T<:BlasFloat)
581
# In-place unrolled product of a static matrix with a strided vector, for
# `Float32`/`Float64` operands that all share the element type `T`: stores `A*b`
# into `c` and returns `c`.
#
# This is a generated function, so only the size of `A` is known while the code
# is being built. The lengths of `b` and `c` are therefore checked at run time,
# inside the generated body, before any element is written.
#
# Raises an error (" Dimension mismatch") unless `length(b) == size(A, 2)` and
# `length(c) == size(A, 1)`.
#
# NOTE: entries of `c` are stored one row at a time while later rows still read
# from `b`, so passing the same array as both `c` and `b` is not safe.
@generated function A_mul_B!{T<:Union{Float64,Float32}}(c::StridedVector{T}, A::StaticMatrix{T}, b::StridedVector{T})
    sA = size(A)

    if sA[2] != 0
        # c[k] = A[k, 1]*b[1] + A[k, 2]*b[2] + ..., with `A` addressed by linear index.
        plus = (ex1, ex2) -> :(+($ex1, $ex2))
        rowsum = k -> reduce(plus, [:(A[$(sub2ind(sA, k, j))] * b[$j]) for j = 1:sA[2]])
        exprs = [:(c[$k] = $(rowsum(k))) for k = 1:sA[1]]
    else
        # No columns: every sum is empty, so each entry becomes zero(T).
        exprs = [:(c[$k] = $(zero(T))) for k = 1:sA[1]]
    end

    return quote
        $(Expr(:meta, :inline))
        if length(b) != $(sA[2]) || length(c) != $(sA[1])
            error(" Dimension mismatch")
        end
        @inbounds $(Expr(:block, exprs...))
        # Return the destination, as Base's `A_mul_B!(Y, A, B) -> Y` does. Without
        # this the call evaluated to the last stored element (or `nothing` when
        # `A` has no rows).
        return c
    end
end
598
+ @generated function A_mul_B! {T<:Union{Complex{Float64},Complex{Float32}}} (c:: StridedVector{T} , A:: StaticMatrix{T} , b:: StridedVector{T} )
504
599
sA = size (A)
505
600
506
601
if sA[2 ] != 0
507
- exprs = [:(c[$ k] = $ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(2 * A[$ (sub2ind (sA, k, j))]* b[$ j]) for j = 1 : sA[2 ]]))) for k = 1 : sA[1 ]]
602
+ exprs = [:(c[$ k] = $ (reduce ((ex1,ex2) -> :(+ ($ ex1,$ ex2)), [:(A[$ (sub2ind (sA, k, j))]* b[$ j]) for j = 1 : sA[2 ]]))) for k = 1 : sA[1 ]]
508
603
else
509
604
exprs = [:(c[$ k] = $ (zero (T))) for k = 1 : sA[1 ]]
510
605
end
0 commit comments