Skip to content

Commit 5c81de7

Browse files
committed
Express sums using the new literal type
1 parent a7bc65a commit 5c81de7

File tree

6 files changed

+37
-33
lines changed

6 files changed

+37
-33
lines changed

docs/src/examples.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,29 +108,29 @@ expr = sum(x .* y)
108108
109109
# output
110110
111-
x¹y¹δ¹
111+
x¹y¹1
112112
```
113113
```jldoctest usage
114114
expr = sum(A * x)
115115
116116
# output
117117
118-
A¹₄x⁴δ¹
118+
A¹₄x⁴1
119119
```
120120
#### Vector Norms
121121
```jldoctest usage
122122
expr = norm2(A * x)
123123
124124
# output
125125
126-
((A¹₄x⁴).^2δ¹₁).^1//2
126+
((A¹₄x⁴).^21₁).^1//2
127127
```
128128
```jldoctest usage
129129
expr = norm1(A * x)
130130
131131
# output
132132
133-
|A¹₄x⁴|δ¹
133+
|A¹₄x⁴|1
134134
```
135135
#### Matrix Trace
136136
```jldoctest usage

src/ir.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,9 @@ function to_ir(arg::KrD)
243243

244244
ids = get_indices(arg)
245245

246-
if typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
246+
if flip(ids[1]) == ids[2]
247+
return ir.Trace(ir.Identity())
248+
elseif typeof(ids[1]) == Upper && typeof(ids[2]) == Lower
247249
return ir.Identity()
248250
elseif typeof(ids[1]) == Lower && typeof(ids[2]) == Upper
249251
return ir.Transpose(ir.Identity())
@@ -372,9 +374,8 @@ function to_ir(arg::BinaryOperation{Mult})
372374
else
373375
return ir.Transpose(ir.Vec(ir.Const(1)))
374376
end
375-
elseif (arg.arg1 isa KrD && is_trace(arg.arg1)) ||
376-
(arg.arg2 isa KrD && is_trace(arg.arg2))
377-
tensor = if arg.arg1 isa KrD
377+
elseif arg.arg1 isa Literal || arg.arg2 isa Literal
378+
tensor = if arg.arg1 isa Literal
378379
arg.arg2
379380
else
380381
arg.arg1
@@ -392,8 +393,8 @@ function to_ir(arg::BinaryOperation{Mult})
392393
return ir.Trace(ir.Product(to_ir(arg.arg1), to_ir(arg.arg2)))
393394
end
394395

395-
if isempty(target_indices) && (typeof(terms[1]) == KrD || typeof(terms[2]) == KrD)
396-
tensor = if typeof(first(terms)) == KrD
396+
if isempty(target_indices) && (first(terms) isa Literal || last(terms) isa Literal)
397+
tensor = if first(terms) == Literal
397398
last(terms)
398399
else
399400
first(terms)

src/ricci.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ function Base.sum(arg::Tensor)
330330
throw(DomainError("Sum is defined only for vectors"))
331331
end
332332

333-
return BinaryOperation{Mult}(arg, KrD(first(free_ids), flip(first(free_ids))))
333+
return BinaryOperation{Mult}(arg, Literal(1, flip(only(free_ids))))
334334
end
335335

336336
function Base.broadcasted(::typeof(abs), arg::Tensor)

src/simplify.jl

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ function to_binary_operation(op::Op, terms::AbstractArray) where {Op}
7171
return BinaryOperation{Op}(to_binary_operation(op, terms[1:(end-1)]), terms[end])
7272
end
7373

74-
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
74+
function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::Literal)
7575
if is_diag(arg1)
7676
d = get_diag_delta(arg1)
7777

@@ -103,21 +103,16 @@ function simplify(::Mult, arg1::BinaryOperation{Mult}, arg2::KrD)
103103
return to_binary_operation(Mult(), reshaped)
104104
end
105105

106-
if is_trace(arg2)
107-
s = first(arg2.indices)
108-
106+
if can_contract(arg1, arg2) && length(get_free_indices(arg2)) == 1
109107
elwise_ids = elementwise_indices(arg1.arg1, arg1.arg2)
110-
last_index = get_last_letter(union(get_free_indices(arg1), get_free_indices(arg2)))
111-
112-
if !isempty(elwise_ids)
113-
if s elwise_ids || flip(s) elwise_ids
114-
if last_index get_free_indices(arg1.arg1)
115-
return evaluate(BinaryOperation{Mult}(arg1.arg1, adjoint(arg1.arg2)))
116-
elseif last_index get_free_indices(arg1.arg2)
117-
return evaluate(BinaryOperation{Mult}(adjoint(arg1.arg1), arg1.arg2))
118-
end
119-
120-
@assert false "Unreachable"
108+
remaining_index =
109+
eliminate_indices(union(get_free_indices(arg1), get_free_indices(arg2)))
110+
111+
if length(elwise_ids) == 1 && length(remaining_index) == 1
112+
if only(remaining_index) get_free_indices(arg1.arg1)
113+
return evaluate(BinaryOperation{Mult}(arg1.arg1, adjoint(arg1.arg2)))
114+
elseif only(remaining_index) get_free_indices(arg1.arg2)
115+
return evaluate(BinaryOperation{Mult}(adjoint(arg1.arg1), arg1.arg2))
121116
end
122117
end
123118
end

test/StdStrTest.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
@test to_std(gradient(sum(2 * cos.(A * x + y)), x)) == "(-2)Aᵀsin(Ax + y)"
3232
@test to_std(gradient(sum(x .^ 2), x)) == "2x"
3333
@test to_std(gradient(sum(x .^ 3), x)) == "3x^2"
34-
@test to_std(gradient(sum(x)^2, x)) == "2sum(xᵀ)Iᵀvec(1)" # TODO: Simplify and remove Iᵀ
34+
@test to_std(gradient(sum(x)^2, x)) == "2sum(xᵀ)vec(1)"
3535
@test to_std(gradient(sum(x .^ 2)^2, x)) == "4sum(xᵀ^2)x"
3636
@test to_std(gradient(sum((x + y) .^ 2), x)) == "2(x + y)"
3737
@test to_std(gradient(sum((x .* y) .^ 2), x)) == "2(x ⊙ y ⊙ y)"
@@ -61,7 +61,7 @@ end
6161
@matrix A B C X
6262
@vector x y z
6363

64-
@test to_std(derivative(sum(-y .* (X*z)), X)) == "(-1)zyᵀ"
64+
@test_broken to_std(derivative(sum(-y .* (X*z)), X)) == "(-1)zyᵀ"
6565
@test to_std(derivative(sum((A .* B) * C * x), x)) == "vec(1)ᵀ(A ⊙ B)C"
6666
end
6767

test/StdTest.jl

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,18 +247,26 @@ end
247247
@test to_std(mul(mul(z, v), mul(x, y))) == "z ⊙ v ⊙ x ⊙ y"
248248
end
249249

250-
@testset "to_std output is correct with vector sum" begin
250+
@testset "to_std output is correct with vector and trace of identity" begin
251251
x = Variable("x", Upper(1))
252252
y = Variable("y", Lower(2))
253253

254254
function mul(l, r)
255255
return dc.BinaryOperation{dc.Mult}(l, r)
256256
end
257257

258-
@test to_std(mul(x, KrD(Upper(1), Lower(1)))) == "sum(x)"
259-
@test to_std(mul(KrD(Upper(1), Lower(1)), x)) == "sum(x)"
260-
@test to_std(mul(y, KrD(Upper(2), Lower(2)))) == "sum(yᵀ)"
261-
@test to_std(mul(KrD(Upper(2), Lower(2)), y)) == "sum(yᵀ)"
258+
# Such expression are not ever created by * but if they were,
259+
# the output should be this.
260+
@test to_std(mul(x, KrD(Upper(1), Lower(1)))) == "tr(I)x"
261+
@test to_std(mul(KrD(Upper(1), Lower(1)), x)) == "tr(I)x"
262+
@test to_std(mul(y, KrD(Upper(2), Lower(2)))) == "tr(I)yᵀ"
263+
@test to_std(mul(KrD(Upper(2), Lower(2)), y)) == "tr(I)yᵀ"
264+
end
265+
266+
@testset "to_std output is correct with vector sum" begin
267+
x = Variable("x", Upper(1))
268+
y = Variable("y", Lower(2))
269+
262270
@test to_std(sum(x)) == "sum(x)"
263271
@test to_std(sum(y)) == "sum(yᵀ)"
264272
end

0 commit comments

Comments
 (0)