Skip to content

Commit 57c589c

Browse files
authored
[ performance ] More stack safety in the Prelude (#2704)
1 parent b1f2eab commit 57c589c

File tree

7 files changed

+341
-59
lines changed

7 files changed

+341
-59
lines changed

libs/base/Data/List.idr

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,10 @@ public export
445445
replaceOn : Eq a => (e : a) -> (b : a) -> (l : List a) -> List a
446446
replaceOn e = replaceWhen (== e)
447447

448+
replicateTR : List a -> Nat -> a -> List a
449+
replicateTR as Z _ = as
450+
replicateTR as (S n) x = replicateTR (x :: as) n x
451+
448452
||| Construct a list with `n` copies of `x`.
449453
|||
450454
||| @ n how many copies
@@ -454,6 +458,10 @@ replicate : (n : Nat) -> (x : a) -> List a
454458
replicate Z _ = []
455459
replicate (S n) x = x :: replicate n x
456460

461+
-- Data.List.replicateTRIsReplicate proves these are equivalent.
462+
%transform "tailRecReplicate" List.replicate = List.replicateTR Nil
463+
464+
457465
||| Compute the intersect of two lists by user-supplied equality predicate.
458466
export
459467
intersectBy : (a -> a -> Bool) -> List a -> List a -> List a
@@ -1066,3 +1074,49 @@ export
10661074
mapAppend : (f : a -> b) -> (xs, ys : List a) -> map f (xs ++ ys) = map f xs ++ map f ys
10671075
mapAppend f [] ys = Refl
10681076
mapAppend f (x::xs) ys = rewrite mapAppend f xs ys in Refl
1077+
1078+
0 mapTRIsMap : (f : a -> b) -> (as : List a) -> mapTR f as === map f as
1079+
mapTRIsMap f = lemma Lin
1080+
where lemma : (sb : SnocList b)
1081+
-> (as : List a)
1082+
-> mapAppend sb f as === (sb <>> map f as)
1083+
lemma sb [] = Refl
1084+
lemma sb (x :: xs) = lemma (sb :< f x) xs
1085+
1086+
1087+
0 mapMaybeTRIsMapMaybe : (f : a -> Maybe b)
1088+
-> (as : List a)
1089+
-> mapMaybeTR f as === mapMaybe f as
1090+
mapMaybeTRIsMapMaybe f = lemma Lin
1091+
where lemma : (sb : SnocList b)
1092+
-> (as : List a)
1093+
-> mapMaybeAppend sb f as === (sb <>> mapMaybe f as)
1094+
lemma sb [] = Refl
1095+
lemma sb (x :: xs) with (f x)
1096+
lemma sb (x :: xs) | Nothing = lemma sb xs
1097+
lemma sb (x :: xs) | Just v = lemma (sb :< v) xs
1098+
1099+
0 filterTRIsFilter : (f : a -> Bool)
1100+
-> (as : List a)
1101+
-> filterTR f as === filter f as
1102+
filterTRIsFilter f = lemma Lin
1103+
1104+
where lemma : (sa : SnocList a)
1105+
-> (as : List a)
1106+
-> filterAppend sa f as === (sa <>> filter f as)
1107+
lemma sa [] = Refl
1108+
lemma sa (x :: xs) with (f x)
1109+
lemma sa (x :: xs) | False = lemma sa xs
1110+
lemma sa (x :: xs) | True = lemma (sa :< x) xs
1111+
1112+
0 replicateTRIsReplicate : (n : Nat) -> (x : a) -> replicateTR [] n x === replicate n x
1113+
replicateTRIsReplicate n x = trans (lemma [] n) (appendNilRightNeutral _)
1114+
where lemma1 : (as : List a) -> (m : Nat) -> (x :: replicate m x) ++ as === replicate m x ++ (x :: as)
1115+
lemma1 as 0 = Refl
1116+
lemma1 as (S k) = cong (x ::) (lemma1 as k)
1117+
1118+
lemma : (as : List a) -> (m : Nat) -> replicateTR as m x === replicate m x ++ as
1119+
lemma as 0 = Refl
1120+
lemma as (S k) =
1121+
let prf := lemma (x :: as) k
1122+
in trans prf (sym $ lemma1 as k)

libs/base/Data/SnocList.idr

Lines changed: 91 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,26 @@ Show a => Show (SnocList a) where
6262
show' acc (xs :< x) = show' (show x :: acc) xs
6363

6464
public export
65+
mapImpl : (a -> b) -> SnocList a -> SnocList b
66+
mapImpl f Lin = Lin
67+
mapImpl f (sx :< x) = (mapImpl f sx) :< (f x)
68+
69+
-- Utility for implementing `mapTR`
70+
mapTR' : List b -> (a -> b) -> SnocList a -> SnocList b
71+
mapTR' xs f (sx :< x) = mapTR' (f x :: xs) f sx
72+
mapTR' xs f Lin = Lin <>< xs
73+
74+
-- Tail recursive version of `map`. This is automatically used
75+
-- at runtime due to a `transform` rule.
76+
mapTR : (a -> b) -> SnocList a -> SnocList b
77+
mapTR = mapTR' []
78+
79+
-- mapTRIsMap proves these are equivalent.
80+
%transform "tailRecMapSnocList" SnocList.mapImpl = SnocList.mapTR
81+
82+
public export %inline
6583
Functor SnocList where
66-
map f Lin = Lin
67-
map f (sx :< x) = (map f sx) :< (f x)
84+
map = mapImpl
6885

6986
public export
7087
Semigroup (SnocList a) where
@@ -357,3 +374,75 @@ mapMaybeCast f (x::xs) = do
357374
mapMaybeStepLemma with (f x)
358375
_ | Nothing = rewrite appendLinLeftNeutral $ [<] <>< mapMaybe f xs in Refl
359376
_ | (Just y) = rewrite fishAsSnocAppend [<y] (mapMaybe f xs) in Refl
377+
378+
0 mapTRIsMap : (f : a -> b) -> (sa : SnocList a) -> mapTR f sa === map f sa
379+
mapTRIsMap f = lemma []
380+
where lemma : (bs : List b)
381+
-> (sa : SnocList a)
382+
-> mapTR' bs f sa === (map f sa <>< bs)
383+
lemma bs Lin = Refl
384+
lemma bs (sx :< x) = lemma (f x :: bs) sx
385+
386+
387+
0 mapMaybeTRIsMapMaybe : (f : a -> Maybe b)
388+
-> (sa : SnocList a)
389+
-> mapMaybeTR f sa === mapMaybe f sa
390+
mapMaybeTRIsMapMaybe f = lemma []
391+
where lemma : (bs : List b)
392+
-> (sa : SnocList a)
393+
-> mapMaybeAppend bs f sa === (mapMaybe f sa <>< bs)
394+
lemma bs Lin = Refl
395+
lemma bs (sx :< x) with (f x)
396+
lemma bs (sx :< x) | Nothing = lemma bs sx
397+
lemma bs (sx :< x) | Just v = lemma (v :: bs) sx
398+
399+
0 filterTRIsFilter : (f : a -> Bool)
400+
-> (sa : SnocList a)
401+
-> filterTR f sa === filter f sa
402+
filterTRIsFilter f = lemma []
403+
where lemma : (as : List a)
404+
-> (sa : SnocList a)
405+
-> filterAppend as f sa === (filter f sa <>< as)
406+
lemma as Lin = Refl
407+
lemma as (sx :< x) with (f x)
408+
lemma as (sx :< x) | False = lemma as sx
409+
lemma as (sx :< x) | True = lemma (x :: as) sx
410+
411+
-- SnocList `reverse` applied to `reverseOnto` is equivalent to swapping the
412+
-- arguments of `reverseOnto`.
413+
reverseReverseOnto : (l, r : SnocList a) ->
414+
reverse (reverseOnto l r) = reverseOnto r l
415+
reverseReverseOnto _ Lin = Refl
416+
reverseReverseOnto l (sx :< x) = reverseReverseOnto (l :< x) sx
417+
418+
||| SnocList `reverse` applied twice yields the identity function.
419+
export
420+
reverseInvolutive : (sx : SnocList a) -> reverse (reverse sx) = sx
421+
reverseInvolutive = reverseReverseOnto Lin
422+
423+
-- Appending `x` to `l` and then reversing the result onto `r` is the same as
424+
-- using (::) with `x` and the result of reversing `l` onto `r`.
425+
snocReverse : (x : a) -> (l, r : SnocList a) ->
426+
reverseOnto r l :< x = reverseOnto r (reverseOnto [<x] (reverse l))
427+
snocReverse _ Lin _ = Refl
428+
snocReverse x (sy :< y) r
429+
= rewrite snocReverse x sy (r :< y) in
430+
rewrite cong (reverseOnto r . reverse) $ snocReverse x sy [<y] in
431+
rewrite reverseInvolutive (reverseOnto [<x] (reverse sy) :< y) in
432+
Refl
433+
434+
-- Proof that it is safe to lift a (::) out of the first `tailRecAppend`
435+
-- argument.
436+
snocTailRecAppend : (x : a) -> (l, r : SnocList a) ->
437+
tailRecAppend l (r :< x) = (tailRecAppend l r) :< x
438+
snocTailRecAppend x l r =
439+
rewrite snocReverse x (reverse r) l in
440+
rewrite reverseInvolutive r in
441+
Refl
442+
443+
-- Proof that `(++)` and `tailRecAppend` do the same thing, so the %transform
444+
-- directive is safe.
445+
tailRecAppendIsAppend : (sx, sy : SnocList a) -> tailRecAppend sx sy = sx ++ sy
446+
tailRecAppendIsAppend sx Lin = Refl
447+
tailRecAppendIsAppend sx (sy :< y) =
448+
trans (snocTailRecAppend y sx sy) (cong (:< y) $ tailRecAppendIsAppend sx sy)

0 commit comments

Comments
 (0)