Skip to content

Commit 1884e71

Browse files
authored
Merge pull request #1163 from typelevel/topic/lru-cache
Replace SemispaceCache with LRU Cache
2 parents b314c69 + 68ed3ae commit 1884e71

File tree

9 files changed

+310
-242
lines changed

9 files changed

+310
-242
lines changed

modules/core/shared/src/main/scala/Session.scala

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -428,9 +428,9 @@ object Session {
428428
ssl: SSL = SSL.None,
429429
parameters: Map[String, String] = Session.DefaultConnectionParameters,
430430
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
431-
commandCache: Int = 1024,
432-
queryCache: Int = 1024,
433-
parseCache: Int = 1024,
431+
commandCache: Int = 2048,
432+
queryCache: Int = 2048,
433+
parseCache: Int = 2048,
434434
readTimeout: Duration = Duration.Inf,
435435
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
436436
): Resource[F, Resource[F, Session[F]]] = {
@@ -470,9 +470,9 @@ object Session {
470470
ssl: SSL = SSL.None,
471471
parameters: Map[String, String] = Session.DefaultConnectionParameters,
472472
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
473-
commandCache: Int = 1024,
474-
queryCache: Int = 1024,
475-
parseCache: Int = 1024,
473+
commandCache: Int = 2048,
474+
queryCache: Int = 2048,
475+
parseCache: Int = 2048,
476476
readTimeout: Duration = Duration.Inf,
477477
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
478478
): Resource[F, Tracer[F] => Resource[F, Session[F]]] = {
@@ -508,9 +508,9 @@ object Session {
508508
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
509509
ssl: SSL = SSL.None,
510510
parameters: Map[String, String] = Session.DefaultConnectionParameters,
511-
commandCache: Int = 1024,
512-
queryCache: Int = 1024,
513-
parseCache: Int = 1024,
511+
commandCache: Int = 2048,
512+
queryCache: Int = 2048,
513+
parseCache: Int = 2048,
514514
readTimeout: Duration = Duration.Inf,
515515
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
516516
): Resource[F, Session[F]] =
@@ -532,9 +532,9 @@ object Session {
532532
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
533533
ssl: SSL = SSL.None,
534534
parameters: Map[String, String] = Session.DefaultConnectionParameters,
535-
commandCache: Int = 1024,
536-
queryCache: Int = 1024,
537-
parseCache: Int = 1024,
535+
commandCache: Int = 2048,
536+
queryCache: Int = 2048,
537+
parseCache: Int = 2048,
538538
readTimeout: Duration = Duration.Inf,
539539
redactionStrategy: RedactionStrategy = RedactionStrategy.OptIn,
540540
): Tracer[F] => Resource[F, Session[F]] =
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright (c) 2018-2024 by Rob Norris and Contributors
2+
// This software is licensed under the MIT License (MIT).
3+
// For more information see LICENSE or https://opensource.org/licenses/MIT
4+
5+
package skunk.data
6+
7+
/**
8+
* Immutable, least recently used cache.
9+
*
10+
* Entries are stored in the `entries` hash map. A numeric stamp is assigned to
11+
* each entry and stored in the `usages` field, which provides a bidirectional
12+
* mapping between stamp and key, sorted by stamp. The `entries` and `usages`
13+
* fields always have the same size.
14+
*
15+
* Upon put and get of an entry, a new stamp is assigned and `usages`
16+
* is updated. Stamps are assigned in ascending order and each stamp is used only once.
17+
* Hence, the head of `usages` contains the least recently used key.
18+
*/
19+
sealed abstract case class Cache[K, V] private (
20+
max: Int,
21+
entries: Map[K, V]
22+
)(usages: SortedBiMap[Long, K],
23+
stamp: Long
24+
) {
25+
// Invariant: every cached key has exactly one usage stamp, and vice versa.
assert(entries.size == usages.size)
26+
27+
/** Number of entries currently held in the cache. */
def size: Int = entries.size
28+
29+
/** Whether the key is present; unlike `get`, this does not count as an access. */
def contains(k: K): Boolean = entries.contains(k)
30+
31+
/**
32+
* Gets the value associated with the specified key.
33+
*
34+
* Accessing an entry makes it the most recently used entry, hence a new cache
35+
* is returned with the target entry updated to reflect the recent access.
36+
*/
37+
def get(k: K): Option[(Cache[K, V], V)] =
38+
entries.get(k) match {
39+
case Some(v) =>
40+
// Re-stamp k with the current (largest) stamp; SortedBiMap.put drops the
// stamp previously paired with k, so `usages` keeps one stamp per key.
val newUsages = usages + (stamp -> k)
41+
val newCache = Cache(max, entries, newUsages, stamp + 1)
42+
Some(newCache -> v)
43+
case None =>
44+
None
45+
}
46+
47+
/**
48+
* Returns a new cache with the specified entry added along with the
49+
* entry that was evicted, if any.
50+
*
51+
* The evicted value is defined under two circumstances:
52+
* - the cache already contains a different value for the specified key,
53+
* in which case the old pairing is returned
54+
* - the cache has reached its max size, in which case the least recently
55+
* used value is evicted
56+
*
57+
* Note: if the cache contains (k, v), calling `put(k, v)` does NOT result
58+
* in an eviction, but calling `put(k, v2)` where `v != v2` does.
59+
*/
60+
def put(k: K, v: V): (Cache[K, V], Option[(K, V)]) =
61+
if (max <= 0) {
62+
// max is 0 so immediately evict the new entry
63+
(this, Some((k, v)))
64+
} else if (entries.size >= max && !contains(k)) {
65+
// at max size already and we need to add a new key, hence we must evict
66+
// the least recently used entry
67+
val (lruStamp, lruKey) = usages.head
68+
val newEntries = entries - lruKey + (k -> v)
69+
val newUsages = usages - lruStamp + (stamp -> k)
70+
val newCache = Cache(max, newEntries, newUsages, stamp + 1)
71+
// The evicted value is read from the original `entries`, which still holds lruKey.
(newCache, Some(lruKey -> entries(lruKey)))
72+
} else {
73+
// not growing past max size at this point, so only need to evict if
74+
// the new entry is replacing an existing entry with different value
75+
val newEntries = entries + (k -> v)
76+
val newUsages = usages + (stamp -> k)
77+
val newCache = Cache(max, newEntries, newUsages, stamp + 1)
78+
val evicted = entries.get(k).filter(_ != v).map(k -> _)
79+
(newCache, evicted)
80+
}
81+
82+
/** All cached values, in the iteration order of the underlying `entries` map. */
def values: Iterable[V] = entries.values
83+
84+
// Renders entries in ascending stamp order, i.e. least recently used first.
override def toString: String =
85+
usages.entries.iterator.map { case (_, k) => s"$k -> ${entries(k)}" }.mkString("Cache(", ", ", ")")
86+
}
87+
88+
object Cache {
89+
// Private constructor: the only way to build a Cache, so callers can never
// supply `entries` and `usages` that disagree with each other.
private def apply[K, V](max: Int, entries: Map[K, V], usages: SortedBiMap[Long, K], stamp: Long): Cache[K, V] =
90+
new Cache(max, entries)(usages, stamp) {}
91+
92+
/** Creates an empty cache holding at most `max` entries; a negative `max` is clamped to 0. */
def empty[K, V](max: Int): Cache[K, V] =
93+
apply(max max 0, Map.empty, SortedBiMap.empty, 0L)
94+
}
95+
96+

modules/core/shared/src/main/scala/data/SemispaceCache.scala

Lines changed: 0 additions & 83 deletions
This file was deleted.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright (c) 2018-2024 by Rob Norris and Contributors
2+
// This software is licensed under the MIT License (MIT).
3+
// For more information see LICENSE or https://opensource.org/licenses/MIT
4+
5+
package skunk.data
6+
7+
import scala.collection.immutable.SortedMap
8+
import scala.math.Ordering
9+
10+
/** Immutable bi-directional map that is sorted by key. */
11+
sealed abstract case class SortedBiMap[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]) {
12+
// Invariant: the forward and inverse maps describe the same one-to-one pairing.
assert(entries.size == inverse.size)
13+
14+
/** Number of key/value pairs. */
def size: Int = entries.size
15+
16+
/** The pair with the smallest key; delegates to `entries.head`, so it must not be called on an empty map. */
def head: (K, V) = entries.head
17+
18+
/** Looks up the value paired with the given key. */
def get(k: K): Option[V] = entries.get(k)
19+
20+
/** Pairs `k` with `v`, first removing any existing pairing that involves either `k` or `v`. */
def put(k: K, v: V): SortedBiMap[K, V] =
21+
// nb: couple important properties here:
22+
// - SortedBiMap(k0 -> v, v -> k0).put(k1, v) == SortedBiMap(k1 -> v, v -> k1)
23+
// - SortedBiMap(k -> v0, v0 -> k).put(k, v1) == SortedBiMap(k -> v1, v1 -> k)
24+
SortedBiMap(
25+
// forward map: drop the key currently paired with v (if any), then add k -> v
inverse.get(v).fold(entries)(entries - _) + (k -> v),
26+
// inverse map: drop the value currently paired with k (if any), then add v -> k
entries.get(k).fold(inverse)(inverse - _) + (v -> k))
27+
28+
/** Operator alias for `put`. */
def +(kv: (K, V)): SortedBiMap[K, V] = put(kv._1, kv._2)
29+
30+
/** Removes the pairing for key `k` from both directions; returns this map unchanged if `k` is absent. */
def -(k: K): SortedBiMap[K, V] =
31+
get(k) match {
32+
case Some(v) => SortedBiMap(entries - k, inverse - v)
33+
case None => this
34+
}
35+
36+
/** Looks up the key paired with the given value. */
def inverseGet(v: V): Option[K] = inverse.get(v)
37+
38+
// Renders pairs in the iteration order of the sorted forward map.
override def toString: String =
39+
entries.iterator.map { case (k, v) => s"$k <-> $v" }.mkString("SortedBiMap(", ", ", ")")
40+
}
41+
42+
object SortedBiMap {
43+
// Private constructor: instances are only created here and by the class's own
// methods, which keep `entries` and `inverse` in sync.
private def apply[K: Ordering, V](entries: SortedMap[K, V], inverse: Map[V, K]): SortedBiMap[K, V] =
44+
new SortedBiMap[K, V](entries, inverse) {}
45+
46+
/** Creates an empty bi-directional map ordered by the implicit `Ordering[K]`. */
def empty[K: Ordering, V]: SortedBiMap[K, V] = apply(SortedMap.empty, Map.empty)
47+
}
48+

modules/core/shared/src/main/scala/util/StatementCache.scala

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import cats.{ Functor, ~> }
88
import cats.syntax.all._
99
import skunk.Statement
1010
import cats.effect.kernel.Ref
11-
import skunk.data.SemispaceCache
11+
import skunk.data.Cache
1212

1313
/** An LRU (by access) cache, keyed by statement `CacheKey`. */
1414
sealed trait StatementCache[F[_], V] { outer =>
@@ -35,31 +35,42 @@ sealed trait StatementCache[F[_], V] { outer =>
3535
object StatementCache {
3636

3737
def empty[F[_]: Functor: Ref.Make, V](max: Int, trackEviction: Boolean): F[StatementCache[F, V]] =
38-
Ref[F].of(SemispaceCache.empty[Statement.CacheKey, V](max, trackEviction)).map { ref =>
38+
// State is the cache and a set of evicted values; the evicted set only grows when trackEviction is true
39+
Ref[F].of((Cache.empty[Statement.CacheKey, V](max), Set.empty[V])).map { ref =>
3940
new StatementCache[F, V] {
4041

4142
def get(k: Statement[_]): F[Option[V]] =
42-
ref.modify { c =>
43-
c.lookup(k.cacheKey) match {
44-
case Some((cʹ, v)) => (cʹ, Some(v))
45-
case None => (c, None)
43+
ref.modify { case (c, evicted) =>
44+
c.get(k.cacheKey) match {
45+
case Some((cʹ, v)) => (cʹ -> evicted, Some(v))
46+
case None => (c -> evicted, None)
4647
}
4748
}
4849

4950
def put(k: Statement[_], v: V): F[Unit] =
50-
ref.update(_.insert(k.cacheKey, v))
51+
ref.update { case (c, evicted) =>
52+
val (c2, e) = c.put(k.cacheKey, v)
53+
// Remove the value we just inserted from the evicted set and add the newly evicted value, if any
54+
val evicted2 = e.filter(_ => trackEviction).fold(evicted - v) { case (_, ev) => evicted - v + ev }
55+
(c2, evicted2)
56+
}
5157

5258
def containsKey(k: Statement[_]): F[Boolean] =
53-
ref.get.map(_.containsKey(k.cacheKey))
59+
ref.get.map(_._1.contains(k.cacheKey))
5460

5561
def clear: F[Unit] =
56-
ref.update(_.evictAll)
62+
ref.update { case (c, evicted) =>
63+
val evicted2 = if (trackEviction) evicted ++ c.values else evicted
64+
(Cache.empty[Statement.CacheKey, V](max), evicted2)
65+
}
5766

5867
def values: F[List[V]] =
59-
ref.get.map(_.values)
68+
ref.get.map(_._1.values.toList)
6069

61-
def clearEvicted: F[List[V]] =
62-
ref.modify(_.clearEvicted)
70+
def clearEvicted: F[List[V]] =
71+
ref.modify { case (c, evicted) =>
72+
(c, Set.empty[V]) -> evicted.toList
73+
}
6374
}
6475
}
6576
}

modules/tests/shared/src/test/scala/PrepareCacheTest.scala

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import skunk.implicits._
88
import skunk.codec.numeric.int8
99
import skunk.codec.text
1010
import skunk.codec.boolean
11-
import cats.syntax.all.*
11+
import cats.syntax.all._
1212

1313
class PrepareCacheTest extends SkunkTest {
1414

@@ -17,16 +17,8 @@ class PrepareCacheTest extends SkunkTest {
1717
private val pgStatementsCountByStatement = sql"select count(*) from pg_prepared_statements where statement = ${text.text}".query(int8)
1818
private val pgStatementsCount = sql"select count(*) from pg_prepared_statements".query(int8)
1919
private val pgStatements = sql"select statement from pg_prepared_statements order by prepare_time".query(text.text)
20-
21-
pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p =>
22-
List.fill(4)(
23-
p.use { s =>
24-
s.execute(pgStatementsByName)("foo").void >> s.execute(pgStatementsByStatement)("bar").void >> s.execute(pgStatementsCountByStatement)("baz").void
25-
}
26-
).sequence
27-
}
28-
29-
pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 1) { p =>
20+
21+
pooledTest("prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 2) { p =>
3022
p.use { s =>
3123
s.execute(pgStatementsByName)("foo").void >>
3224
s.execute(pgStatementsByStatement)("bar").void >>
@@ -49,7 +41,7 @@ class PrepareCacheTest extends SkunkTest {
4941
}
5042
}
5143

52-
pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 1) { p =>
44+
pooledTest("prepared statements via prepare shouldn't get evicted until they go out of scope", max = 1, parseCacheSize = 2) { p =>
5345
p.use { s =>
5446
// creates entry in cache
5547
s.prepare(pgStatementsByName)
@@ -97,4 +89,14 @@ class PrepareCacheTest extends SkunkTest {
9789
}
9890
}
9991
}
92+
93+
pooledTest("concurrent prepare cache should close evicted prepared statements at end of session", max = 1, parseCacheSize = 4) { p =>
94+
List.fill(8)(
95+
p.use { s =>
96+
s.execute(pgStatementsByName)("foo").void >>
97+
s.execute(pgStatementsByStatement)("bar").void >>
98+
s.execute(pgStatementsCountByStatement)("baz").void
99+
}
100+
).sequence
101+
}
100102
}

0 commit comments

Comments
 (0)