From 44cb670d61132739ab0e77cd82f969052ab63a4e Mon Sep 17 00:00:00 2001 From: the123saurav <42906612+the123saurav@users.noreply.github.com> Date: Fri, 17 Dec 2021 12:20:36 +0530 Subject: [PATCH 1/2] Implement ZAdd and Zscore (#79) --- .../java/dev/keva/core/aof/AOFContainer.java | 1 + .../dev/keva/core/command/impl/zset/ZAdd.java | 123 +++++++++++ .../keva/core/command/impl/zset/ZScore.java | 36 ++++ .../java/dev/keva/core/server/AOFTest.java | 2 +- .../keva/core/server/AbstractServerTest.java | 98 +++++++++ docs/src/guide/overview/commands.md | 8 + .../keva/protocol/resp/reply/BulkReply.java | 10 +- .../keva/protocol/resp/reply/ErrorReply.java | 7 + .../java/dev/keva/store/KevaDatabase.java | 8 + .../keva/store/impl/OffHeapDatabaseImpl.java | 113 +++++++++- .../keva/store/impl/OnHeapDatabaseImpl.java | 112 +++++++++- .../main/java/dev/keva/store/type/ZSet.java | 203 ++++++++++++++++++ .../main/java/dev/keva/util/Constants.java | 14 ++ .../main/java/dev/keva/util/DoubleUtil.java | 16 ++ 14 files changed, 742 insertions(+), 9 deletions(-) create mode 100644 core/src/main/java/dev/keva/core/command/impl/zset/ZAdd.java create mode 100644 core/src/main/java/dev/keva/core/command/impl/zset/ZScore.java create mode 100644 store/src/main/java/dev/keva/store/type/ZSet.java create mode 100644 util/src/main/java/dev/keva/util/Constants.java create mode 100644 util/src/main/java/dev/keva/util/DoubleUtil.java diff --git a/core/src/main/java/dev/keva/core/aof/AOFContainer.java b/core/src/main/java/dev/keva/core/aof/AOFContainer.java index de696e2a..e8721b67 100644 --- a/core/src/main/java/dev/keva/core/aof/AOFContainer.java +++ b/core/src/main/java/dev/keva/core/aof/AOFContainer.java @@ -114,6 +114,7 @@ public List read() throws IOException { byte[][] objects = (byte[][]) input.readObject(); commands.add(Command.newInstance(objects, false)); } catch (EOFException e) { + log.error("Error while reading AOF command", e); fis.close(); return commands; } catch (ClassNotFoundException e) { diff --git a/core/src/main/java/dev/keva/core/command/impl/zset/ZAdd.java b/core/src/main/java/dev/keva/core/command/impl/zset/ZAdd.java new file mode 100644 index 00000000..f58c8790 --- /dev/null +++ b/core/src/main/java/dev/keva/core/command/impl/zset/ZAdd.java @@ -0,0 +1,123 @@ +package dev.keva.core.command.impl.zset; + +import dev.keva.core.command.annotation.CommandImpl; +import dev.keva.core.command.annotation.Execute; +import dev.keva.core.command.annotation.Mutate; +import dev.keva.core.command.annotation.ParamLength; +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.BulkReply; +import dev.keva.protocol.resp.reply.ErrorReply; +import dev.keva.protocol.resp.reply.IntegerReply; +import dev.keva.protocol.resp.reply.Reply; +import dev.keva.store.KevaDatabase; +import dev.keva.util.DoubleUtil; +import dev.keva.util.hashbytes.BytesKey; + +import java.nio.charset.StandardCharsets; +import java.util.AbstractMap.SimpleEntry; + +import static dev.keva.util.Constants.FLAG_CH; +import static dev.keva.util.Constants.FLAG_GT; +import static dev.keva.util.Constants.FLAG_INCR; +import static dev.keva.util.Constants.FLAG_LT; +import static dev.keva.util.Constants.FLAG_NX; +import static dev.keva.util.Constants.FLAG_XX; + +@Component +@CommandImpl("zadd") +@ParamLength(type = ParamLength.Type.AT_LEAST, value = 3) +@Mutate +public final class ZAdd { + private static final String XX = "xx"; + private static final String NX = "nx"; + private static final String GT = "gt"; + private static final String LT = "lt"; + private static final String INCR = "incr"; + private static final String CH = "ch"; + + private final KevaDatabase database; + + @Autowired + public ZAdd(KevaDatabase database) { + this.database = database; + } + + @Execute + public Reply execute(byte[][] params) { + // Parse the flags, if any + boolean xx = false, nx = false, gt = false, lt = false, incr = false; + int argPos = 1, flags = 0; + String arg; + while (argPos < params.length) { + arg = new String(params[argPos], StandardCharsets.UTF_8); + if (XX.equalsIgnoreCase(arg)) { + xx = true; + flags |= FLAG_XX; + } else if (NX.equalsIgnoreCase(arg)) { + nx = true; + flags |= FLAG_NX; + } else if (GT.equalsIgnoreCase(arg)) { + gt = true; + flags |= FLAG_GT; + } else if (LT.equalsIgnoreCase(arg)) { + lt = true; + flags |= FLAG_LT; + } else if (INCR.equalsIgnoreCase(arg)) { + incr = true; + flags |= FLAG_INCR; + } else if (CH.equalsIgnoreCase(arg)) { + flags |= FLAG_CH; + } else { + break; + } + ++argPos; + } + + int numMembers = params.length - argPos; + if (numMembers % 2 != 0) { + return ErrorReply.SYNTAX_ERROR; + } + numMembers /= 2; + + if (nx && xx) { + return ErrorReply.ZADD_NX_XX_ERROR; + } + if ((gt && nx) || (lt && nx) || (gt && lt)) { + return ErrorReply.ZADD_GT_LT_NX_ERROR; + } + if (incr && numMembers > 1) { + return ErrorReply.ZADD_INCR_ERROR; + } + + // Parse the key and value + final SimpleEntry[] members = new SimpleEntry[numMembers]; + double score; + String rawScore; + for (int memberIndex = 0; memberIndex < numMembers; ++memberIndex) { + try { + rawScore = new String(params[argPos++], StandardCharsets.UTF_8); + if (rawScore.equalsIgnoreCase("inf") || rawScore.equalsIgnoreCase("infinity") + || rawScore.equalsIgnoreCase("+inf") || rawScore.equalsIgnoreCase("+infinity") + ) { + score = Double.POSITIVE_INFINITY; + } else if (rawScore.equalsIgnoreCase("-inf") || rawScore.equalsIgnoreCase("-infinity")) { + score = Double.NEGATIVE_INFINITY; + } else { + score = Double.parseDouble(rawScore); + } + } catch (final NumberFormatException ignored) { + // return on first bad input + return ErrorReply.ZADD_SCORE_FLOAT_ERROR; + } + members[memberIndex] = new SimpleEntry<>(score, new BytesKey(params[argPos++])); + } + + if (incr) { + Double result = database.zincrby(params[0], members[0].getKey(), members[0].getValue(), flags); + return result == null ? BulkReply.NIL_REPLY : new BulkReply(DoubleUtil.toString(result)); + } + int result = database.zadd(params[0], members, flags); + return new IntegerReply(result); + } +} diff --git a/core/src/main/java/dev/keva/core/command/impl/zset/ZScore.java b/core/src/main/java/dev/keva/core/command/impl/zset/ZScore.java new file mode 100644 index 00000000..fae864de --- /dev/null +++ b/core/src/main/java/dev/keva/core/command/impl/zset/ZScore.java @@ -0,0 +1,36 @@ +package dev.keva.core.command.impl.zset; + +import dev.keva.core.command.annotation.CommandImpl; +import dev.keva.core.command.annotation.Execute; +import dev.keva.core.command.annotation.ParamLength; +import dev.keva.ioc.annotation.Autowired; +import dev.keva.ioc.annotation.Component; +import dev.keva.protocol.resp.reply.BulkReply; +import dev.keva.store.KevaDatabase; + +@Component +@CommandImpl("zscore") +@ParamLength(type = ParamLength.Type.EXACT, value = 2) +public final class ZScore { + private final KevaDatabase database; + + @Autowired + public ZScore(KevaDatabase database) { + this.database = database; + } + + @Execute + public BulkReply execute(byte[] key, byte[] member) { + final Double result = database.zscore(key, member); + if(result == null){ + return BulkReply.NIL_REPLY; + } + if (result.equals(Double.POSITIVE_INFINITY)) { + return BulkReply.POSITIVE_INFINITY_REPLY; + } + if (result.equals(Double.NEGATIVE_INFINITY)) { + return BulkReply.NEGATIVE_INFINITY_REPLY; + } + return new BulkReply(result.toString()); + } +} diff --git a/core/src/test/java/dev/keva/core/server/AOFTest.java b/core/src/test/java/dev/keva/core/server/AOFTest.java index 665725af..3f559226 100644 --- a/core/src/test/java/dev/keva/core/server/AOFTest.java +++ b/core/src/test/java/dev/keva/core/server/AOFTest.java @@ -24,7 +24,7 @@ Server startServer(int port) throws Exception { .persistence(false) .aof(true) .aofInterval(1000) - .workDirectory("./") + .workDirectory(System.getProperty("java.io.tmpdir")) .build(); val server = KevaServer.of(config); new Thread(() -> { diff --git a/core/src/test/java/dev/keva/core/server/AbstractServerTest.java b/core/src/test/java/dev/keva/core/server/AbstractServerTest.java index f7f1cbc4..c0f71cfc 100644 --- a/core/src/test/java/dev/keva/core/server/AbstractServerTest.java +++ b/core/src/test/java/dev/keva/core/server/AbstractServerTest.java @@ -9,10 +9,13 @@ import java.util.Arrays; import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import lombok.var; +import redis.clients.jedis.params.ZAddParams; import static org.junit.jupiter.api.Assertions.*; @@ -827,6 +830,101 @@ void setrange() { } } + @Test + void zaddWithXXAndNXErrs() { + assertThrows(JedisDataException.class, () -> { + jedis.zadd("zset", 1.0, "val", new ZAddParams().xx().nx()); + }); + } + + @Test + void zaddSingleWithNxAndGtErrs() { + assertThrows(JedisDataException.class, () -> { + jedis.zadd("zset", 1.0, "val", new ZAddParams().gt().nx()); + }); + } + + @Test + void zaddSingleWithNxAndLtErrs() { + assertThrows(JedisDataException.class, () -> { + jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().nx()); + }); + } + + @Test + void zaddSingleWithGtAndLtErrs() { + assertThrows(JedisDataException.class, () -> { + jedis.zadd("zset", 1.0, "val", new ZAddParams().lt().gt()); + }); + } + + @Test + void zaddSingleWithoutOptions() { + try { + var result = jedis.zadd("zset", 1.0, "val"); + assertEquals(1, result); + + result = jedis.zadd("zset", 1.0, "val"); + assertEquals(0, result); + } catch (Exception e) { + fail(e); + } + } + + @Test + void zaddMultipleWithoutOptions() { + try { + Map members = new HashMap<>(); + int numMembers = 100; + for(int i=0; i +
+ SortedSet + +- ZADD +- ZSCORE + +
+
Pub/Sub diff --git a/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/BulkReply.java b/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/BulkReply.java index 3018a572..7d2b90db 100644 --- a/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/BulkReply.java +++ b/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/BulkReply.java @@ -11,6 +11,8 @@ public class BulkReply implements Reply { public static final BulkReply NIL_REPLY = new BulkReply(); + public static final BulkReply POSITIVE_INFINITY_REPLY = new BulkReply("inf"); + public static final BulkReply NEGATIVE_INFINITY_REPLY = new BulkReply("-inf"); public static final char MARKER = '$'; private final ByteBuf bytes; @@ -22,11 +24,7 @@ private BulkReply() { } public BulkReply(byte[] bytes) { - if (bytes.length == 0) { - this.bytes = Unpooled.EMPTY_BUFFER; - } else { - this.bytes = Unpooled.wrappedBuffer(bytes); - } + this.bytes = Unpooled.wrappedBuffer(bytes); capacity = bytes.length; } @@ -59,7 +57,7 @@ public void write(ByteBuf os) throws IOException { os.writeByte(MARKER); os.writeBytes(numToBytes(capacity, true)); if (capacity > 0) { - os.writeBytes(bytes); + os.writeBytes(bytes.array()); os.writeBytes(CRLF); } if (capacity == 0) { diff --git a/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/ErrorReply.java b/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/ErrorReply.java index b4d927c1..e723e498 100644 --- a/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/ErrorReply.java +++ b/resp-protocol/src/main/java/dev/keva/protocol/resp/reply/ErrorReply.java @@ -7,6 +7,13 @@ public class ErrorReply implements Reply { public static final char MARKER = '-'; + // Pre-defined errors + public static final ErrorReply SYNTAX_ERROR = new ErrorReply("ERR syntax error"); + public static final ErrorReply ZADD_NX_XX_ERROR = new ErrorReply("ERR XX and NX options at the same time are not compatible"); + public static final ErrorReply ZADD_GT_LT_NX_ERROR = new ErrorReply("GT, LT, and/or NX options at the same time are not compatible"); + public static final ErrorReply ZADD_INCR_ERROR = new ErrorReply("INCR option supports a single increment-element pair"); + public static final ErrorReply ZADD_SCORE_FLOAT_ERROR = new ErrorReply("value is not a valid float"); + private final String error; public ErrorReply(String error) { diff --git a/store/src/main/java/dev/keva/store/KevaDatabase.java b/store/src/main/java/dev/keva/store/KevaDatabase.java index a5cedb2d..d18bb0bd 100644 --- a/store/src/main/java/dev/keva/store/KevaDatabase.java +++ b/store/src/main/java/dev/keva/store/KevaDatabase.java @@ -1,5 +1,8 @@ package dev.keva.store; +import dev.keva.util.hashbytes.BytesKey; + +import java.util.AbstractMap; import java.util.concurrent.locks.Lock; public interface KevaDatabase { @@ -69,4 +72,9 @@ public interface KevaDatabase { byte[][] mget(byte[]... keys); + int zadd(byte[] key, AbstractMap.SimpleEntry[] members, int flags); + + Double zincrby(byte[] key, Double score, BytesKey e, int flags); + + Double zscore(byte[] key, byte[] member); } diff --git a/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java b/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java index d1cb266a..7412ffc8 100644 --- a/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java +++ b/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java @@ -1,11 +1,13 @@ package dev.keva.store.impl; +import dev.keva.store.type.ZSet; import dev.keva.util.hashbytes.BytesKey; import dev.keva.util.hashbytes.BytesValue; import dev.keva.store.DatabaseConfig; import dev.keva.store.KevaDatabase; import dev.keva.store.lock.SpinLock; import lombok.Getter; +import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; import lombok.var; @@ -16,9 +18,22 @@ import java.io.File; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; import java.util.concurrent.locks.Lock; +import static dev.keva.util.Constants.FLAG_CH; +import static dev.keva.util.Constants.FLAG_GT; +import static dev.keva.util.Constants.FLAG_LT; +import static dev.keva.util.Constants.FLAG_NX; +import static dev.keva.util.Constants.FLAG_XX; + @Slf4j public class OffHeapDatabaseImpl implements KevaDatabase { @Getter @@ -716,4 +731,100 @@ public byte[][] mget(byte[]... keys) { lock.unlock(); } } + + @Override + public int zadd(final byte[] key, @NonNull final AbstractMap.SimpleEntry[] members, final int flags) { + boolean xx = (flags & FLAG_XX) != 0; + boolean nx = (flags & FLAG_NX) != 0; + boolean lt = (flags & FLAG_LT) != 0; + boolean gt = (flags & FLAG_GT) != 0; + boolean ch = (flags & FLAG_CH) != 0; + + // Track both to eliminate conditional branch + int added = 0, changed = 0; + + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + ZSet zSet; + zSet = value == null ? new ZSet() : (ZSet) SerializationUtils.deserialize(value); + for (AbstractMap.SimpleEntry member : members) { + Double currScore = zSet.getScore(member.getValue()); + if (currScore == null) { + if (xx) { + continue; + } + currScore = member.getKey(); + zSet.add(new AbstractMap.SimpleEntry<>(currScore, member.getValue())); + ++added; + ++changed; + continue; + } + Double newScore = member.getKey(); + if(nx || (lt && newScore >= currScore) || (gt && newScore <= currScore)) { + continue; + } + if (!newScore.equals(currScore)) { + zSet.removeByKey(member.getValue()); + zSet.add(member); + ++changed; + } + } + chronicleMap.put(key, SerializationUtils.serialize(zSet)); + return ch ? changed : added; + } finally { + lock.unlock(); + } + } + + @Override + public Double zincrby(byte[] key, Double incr, BytesKey e, int flags) { + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + ZSet zSet; + zSet = value == null ? new ZSet() : (ZSet) SerializationUtils.deserialize(value); + Double currentScore = zSet.getScore(e); + if (currentScore == null) { + if ((flags & FLAG_XX) != 0) { + return null; + } + currentScore = incr; + zSet.add(new AbstractMap.SimpleEntry<>(currentScore, e)); + chronicleMap.put(key, SerializationUtils.serialize(zSet)); + return currentScore; + } + if ((flags & FLAG_NX) != 0) { + return null; + } + if ((flags & FLAG_LT) != 0 && (incr >= 0 || currentScore.isInfinite())) { + return null; + } + if ((flags & FLAG_GT) != 0 && (incr <= 0 || currentScore.isInfinite())) { + return null; + } + zSet.remove(new AbstractMap.SimpleEntry<>(currentScore, e)); + currentScore += incr; + zSet.add(new AbstractMap.SimpleEntry<>(currentScore, e)); + chronicleMap.put(key, SerializationUtils.serialize(zSet)); + return currentScore; + } finally { + lock.unlock(); + } + } + + @Override + public Double zscore(byte[] key, byte[] member) { + lock.lock(); + try { + byte[] value = chronicleMap.get(key); + if (value == null) { + return null; + } + ZSet zset = (ZSet) SerializationUtils.deserialize(value); + return zset.getScore(new BytesKey(member)); + } finally { + lock.unlock(); + } + } } diff --git a/store/src/main/java/dev/keva/store/impl/OnHeapDatabaseImpl.java b/store/src/main/java/dev/keva/store/impl/OnHeapDatabaseImpl.java index 75f3ca81..e28b1c80 100644 --- a/store/src/main/java/dev/keva/store/impl/OnHeapDatabaseImpl.java +++ b/store/src/main/java/dev/keva/store/impl/OnHeapDatabaseImpl.java @@ -1,5 +1,6 @@ package dev.keva.store.impl; +import dev.keva.store.type.ZSet; import dev.keva.util.hashbytes.BytesKey; import dev.keva.util.hashbytes.BytesValue; import dev.keva.store.KevaDatabase; @@ -10,9 +11,19 @@ import org.apache.commons.lang3.SerializationUtils; import java.nio.charset.StandardCharsets; -import java.util.*; +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; import java.util.concurrent.locks.Lock; +import static dev.keva.util.Constants.*; +import static dev.keva.util.Constants.FLAG_GT; + public class OnHeapDatabaseImpl implements KevaDatabase { @Getter private final Lock lock = new SpinLock(); @@ -674,4 +685,103 @@ public byte[][] mget(byte[]... keys) { lock.unlock(); } } + + @Override + public int zadd(byte[] key, AbstractMap.SimpleEntry[] members, int flags) { + boolean xx = (flags & FLAG_XX) != 0; + boolean nx = (flags & FLAG_NX) != 0; + boolean lt = (flags & FLAG_LT) != 0; + boolean gt = (flags & FLAG_GT) != 0; + boolean ch = (flags & FLAG_CH) != 0; + + // Track both to eliminate conditional branch + int added = 0, changed = 0; + + lock.lock(); + try { + final BytesKey mapKey = new BytesKey(key); + byte[] value = map.get(mapKey).getBytes(); + ZSet zSet; + zSet = value == null ? new ZSet() : (ZSet) SerializationUtils.deserialize(value); + for (AbstractMap.SimpleEntry member : members) { + Double currScore = zSet.getScore(member.getValue()); + if (currScore == null) { + if (xx) { + continue; + } + currScore = member.getKey(); + zSet.add(new AbstractMap.SimpleEntry<>(currScore, member.getValue())); + ++added; + ++changed; + continue; + } + Double newScore = member.getKey(); + if(nx || (lt && newScore >= currScore) || (gt && newScore <= currScore)) { + continue; + } + if (!newScore.equals(currScore)) { + zSet.removeByKey(member.getValue()); + zSet.add(member); + ++changed; + } + } + map.put(mapKey, new BytesValue(SerializationUtils.serialize(zSet))); + return ch ? changed : added; + } finally { + lock.unlock(); + } + } + + @Override + public Double zincrby(byte[] key, Double incr, BytesKey e, int flags) { + lock.lock(); + try { + final BytesKey mapKey = new BytesKey(key); + byte[] value = map.get(mapKey).getBytes(); + ZSet zSet; + zSet = value == null ? new ZSet() : (ZSet) SerializationUtils.deserialize(value); + Double currentScore = zSet.getScore(e); + if (currentScore == null) { + if ((flags & FLAG_XX) != 0) { + return null; + } + currentScore = incr; + zSet.add(new AbstractMap.SimpleEntry<>(currentScore, e)); + map.put(mapKey, new BytesValue(SerializationUtils.serialize(zSet))); + return currentScore; + } + if ((flags & FLAG_NX) != 0) { + return null; + } + if ((flags & FLAG_LT) != 0 && incr >= 0) { + return null; + } + if ((flags & FLAG_GT) != 0 && incr <= 0) { + return null; + } + zSet.remove(new AbstractMap.SimpleEntry<>(currentScore, e)); + currentScore += incr; + zSet.add(new AbstractMap.SimpleEntry<>(currentScore, e)); + map.put(mapKey, new BytesValue(SerializationUtils.serialize(zSet))); + return currentScore; + } finally { + lock.unlock(); + } + + } + + @Override + public Double zscore(byte[] key, byte[] member) { + lock.lock(); + try { + byte[] value = map.get(new BytesKey(key)).getBytes(); + if (value == null) { + return null; + } + ZSet zset = (ZSet) SerializationUtils.deserialize(value); + return zset.getScore(new BytesKey(member)); + } finally { + lock.unlock(); + } + } } diff --git a/store/src/main/java/dev/keva/store/type/ZSet.java b/store/src/main/java/dev/keva/store/type/ZSet.java new file mode 100644 index 00000000..ff55cb1a --- /dev/null +++ b/store/src/main/java/dev/keva/store/type/ZSet.java @@ -0,0 +1,203 @@ +package dev.keva.store.type; + +import dev.keva.util.hashbytes.BytesKey; +import lombok.NonNull; + +import java.io.Serializable; +import java.util.AbstractMap.SimpleEntry; +import java.util.AbstractSet; +import java.util.Comparator; +import java.util.HashMap; +import java.util.Iterator; +import java.util.NavigableSet; +import java.util.TreeSet; + +/** + * A SortedSet implementation tailor made for Redis, and hence no generic definition. + * The current implementation uses TreeSet which internally used Balanced BST. + * In the future, if needed, we can implement a SkipList. + */ +public class ZSet extends AbstractSet> implements NavigableSet>, Serializable { + + private static final long serialVersionUID = 1L; + + private final HashMap keys = new HashMap<>(); + + private final TreeSet> scores = new TreeSet<>((Comparator> & Serializable)(e1, e2) -> { + int cmp = e1.getKey().compareTo(e2.getKey()); + if (cmp != 0) { + return cmp; + } + return e1.getValue().compareTo(e2.getValue()); + }); + + @Override + public SimpleEntry lower(SimpleEntry entry) { + return scores.lower(entry); + } + + @Override + public SimpleEntry floor(SimpleEntry entry) { + return scores.floor(entry); + } + + @Override + public SimpleEntry ceiling(SimpleEntry entry) { + return scores.ceiling(entry); + } + + @Override + public SimpleEntry higher(SimpleEntry entry) { + return scores.higher(entry); + } + + @Override + public SimpleEntry pollFirst() { + return scores.pollFirst(); + } + + @Override + public SimpleEntry pollLast() { + return scores.pollLast(); + } + + @Override + public int size() { + return keys.size(); + } + + @Override + public boolean isEmpty() { + return keys.isEmpty(); + } + + @Override + public boolean contains(Object o) { + return scores.contains(o); + } + + @Override + public Iterator> iterator() { + return scores.iterator(); + } + + @Override + public Object[] toArray() { + return scores.toArray(); + } + + @Override + public T[] toArray(@NonNull T[] ts) { + return scores.toArray(ts); + } + + @Override + public boolean add(SimpleEntry entry) { + boolean result = true; + if (keys.containsKey(entry.getValue())){ + result = false; + scores.remove(new SimpleEntry<>(keys.get(entry.getValue()), entry.getValue())); + } + scores.add(new SimpleEntry<>(entry.getKey(), entry.getValue())); + keys.put(entry.getValue(), entry.getKey()); + return result; + } + + @Override + public boolean remove(@NonNull Object o) { + SimpleEntry entry = (SimpleEntry) o; + if (keys.containsKey(entry.getValue())){ + scores.remove(new SimpleEntry<>(entry.getKey(), entry.getValue())); + keys.remove(entry.getValue()); + return true; + } + return false; + } + + public boolean removeByKey(@NonNull BytesKey key) { + if (keys.containsKey(key)) { + scores.remove(new SimpleEntry<>(keys.get(key), key)); + keys.remove(key); + return true; + } + return false; + } + + @Override + public synchronized void clear() { + scores.clear(); + keys.clear(); + } + + @NonNull + @Override + public NavigableSet> descendingSet() { + return scores.descendingSet(); + } + + @NonNull + @Override + public Iterator> descendingIterator() { + return scores.descendingIterator(); + } + + @NonNull + @Override + public NavigableSet> subSet(SimpleEntry start, boolean b1, SimpleEntry end, boolean b2) { + return scores.subSet(start, b1, end, b2); + } + + @NonNull + @Override + public NavigableSet> headSet(SimpleEntry entry, boolean b) { + return scores.headSet(entry, b); + } + + @NonNull + @Override + public NavigableSet> tailSet(SimpleEntry entry, boolean b) { + return scores.tailSet(entry, b); + } + + @Override + public Comparator> comparator() { + return scores.comparator(); + } + + @NonNull + @Override + public java.util.SortedSet> subSet(SimpleEntry begin, SimpleEntry end) { + return scores.subSet(begin, end); + } + + @NonNull + @Override + public java.util.SortedSet> headSet(SimpleEntry entry) { + return scores.headSet(entry); + } + + @NonNull + @Override + public java.util.SortedSet> tailSet(SimpleEntry entry) { + return scores.tailSet(entry); + } + + @Override + public SimpleEntry first() { + return scores.first(); + } + + @Override + public SimpleEntry last() { + return scores.last(); + } + + public Double getScore(BytesKey key) { + return keys.get(key); + } + + @Override + public String toString() { + return keys.toString(); + } +} diff --git a/util/src/main/java/dev/keva/util/Constants.java b/util/src/main/java/dev/keva/util/Constants.java new file mode 100644 index 00000000..2ad4f815 --- /dev/null +++ b/util/src/main/java/dev/keva/util/Constants.java @@ -0,0 +1,14 @@ +package dev.keva.util; + +public final class Constants { + + public static final int FLAG_XX = 1; + public static final int FLAG_NX = 1 << 1; + public static final int FLAG_GT = 1 << 2; + public static final int FLAG_LT = 1 << 3; + public static final int FLAG_INCR = 1 << 4; + public static final int FLAG_CH = 1 << 5; + + private Constants() { + } +} diff --git a/util/src/main/java/dev/keva/util/DoubleUtil.java b/util/src/main/java/dev/keva/util/DoubleUtil.java new file mode 100644 index 00000000..fe194648 --- /dev/null +++ b/util/src/main/java/dev/keva/util/DoubleUtil.java @@ -0,0 +1,16 @@ +package dev.keva.util; + +public final class DoubleUtil { + + private DoubleUtil(){} + + public static String toString(Double d){ + if (d.equals(Double.POSITIVE_INFINITY)) { + return "inf"; + } + if (d.equals(Double.NEGATIVE_INFINITY)) { + return "-inf"; + } + return d.toString(); + } +} From f22c0268e6f50728612dba8cf026b9acd940e414 Mon Sep 17 00:00:00 2001 From: the123saurav Date: Sat, 18 Dec 2021 23:05:15 +0530 Subject: [PATCH 2/2] Initial commit for RW lock --- .../impl/key/manager/ExpirationManager.java | 7 +- .../manager/TransactionContext.java | 7 +- .../core/command/mapping/CommandMapper.java | 39 +++-- .../java/dev/keva/store/KevaDatabase.java | 3 +- .../keva/store/impl/OffHeapDatabaseImpl.java | 144 +++++++++--------- .../keva/store/impl/OnHeapDatabaseImpl.java | 138 ++++++++--------- .../java/dev/keva/store/lock/SpinLock.java | 24 ++- 7 files changed, 200 insertions(+), 162 deletions(-) diff --git a/core/src/main/java/dev/keva/core/command/impl/key/manager/ExpirationManager.java b/core/src/main/java/dev/keva/core/command/impl/key/manager/ExpirationManager.java index f15e9137..c21b782f 100644 --- a/core/src/main/java/dev/keva/core/command/impl/key/manager/ExpirationManager.java +++ b/core/src/main/java/dev/keva/core/command/impl/key/manager/ExpirationManager.java @@ -8,6 +8,7 @@ import dev.keva.ioc.annotation.Component; import dev.keva.protocol.resp.Command; import dev.keva.store.KevaDatabase; +import dev.keva.store.lock.SpinLock; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -71,14 +72,14 @@ public void executeExpire(byte[] key) { data[0] = "delete".getBytes(); data[1] = key; Command command = Command.newInstance(data, false); - Lock lock = database.getLock(); - lock.lock(); + SpinLock lock = database.getLock(); + lock.exclusiveLock(); try { aof.write(command); database.remove(key); clearExpiration(key); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } else { database.remove(key); diff --git a/core/src/main/java/dev/keva/core/command/impl/transaction/manager/TransactionContext.java b/core/src/main/java/dev/keva/core/command/impl/transaction/manager/TransactionContext.java index 2f5f5bdd..dbda205a 100644 --- a/core/src/main/java/dev/keva/core/command/impl/transaction/manager/TransactionContext.java +++ b/core/src/main/java/dev/keva/core/command/impl/transaction/manager/TransactionContext.java @@ -2,6 +2,7 @@ import dev.keva.core.command.mapping.CommandMapper; import dev.keva.protocol.resp.Command; +import dev.keva.store.lock.SpinLock; import dev.keva.util.hashbytes.BytesKey; import dev.keva.util.hashbytes.BytesValue; import dev.keva.protocol.resp.reply.MultiBulkReply; @@ -42,8 +43,8 @@ public void discard() { isQueuing = false; } - public Reply exec(ChannelHandlerContext ctx, Lock txLock) throws InterruptedException { - txLock.lock(); + public Reply exec(ChannelHandlerContext ctx, SpinLock txLock) throws InterruptedException { + txLock.exclusiveLock(); try { for (val watch : watchMap.entrySet()) { val key = watch.getKey(); @@ -74,7 +75,7 @@ public Reply exec(ChannelHandlerContext ctx, Lock txLock) throws InterruptedE return new MultiBulkReply(replies); } finally { - txLock.unlock(); + txLock.exclusiveUnlock(); } } } diff --git a/core/src/main/java/dev/keva/core/command/mapping/CommandMapper.java b/core/src/main/java/dev/keva/core/command/mapping/CommandMapper.java index c4d22910..17ec894f 100644 --- a/core/src/main/java/dev/keva/core/command/mapping/CommandMapper.java +++ b/core/src/main/java/dev/keva/core/command/mapping/CommandMapper.java @@ -25,11 +25,17 @@ import java.lang.reflect.InvocationTargetException; import java.util.Arrays; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; +import java.util.Set; @Component @Slf4j public class CommandMapper { + + private static final Set EXCLUSIVE_COMMANDS = new HashSet<>(Arrays.asList( + "exec", "expire", "expireat", "restore", "flushdb")); + @Getter private final Map methods = new HashMap<>(); @@ -98,17 +104,26 @@ public void init() { try { val lock = database.getLock(); - lock.lock(); + boolean locked = false, exclusive = false, writeToAOF = isAoF && isMutate; try { - if (ctx != null && isAoF && isMutate) { - try { - aof.write(command); - } catch (Exception e) { - log.error("Error writing to AOF", e); - } - } Object[] objects = new Object[types.length]; command.toArguments(objects, types, ctx); + if (ctx != null) { + locked = true; + if (isMutate || EXCLUSIVE_COMMANDS.contains(name)) { + lock.exclusiveLock(); + exclusive = true; + if (writeToAOF) { + try { + aof.write(command); + } catch (Exception e) { + log.error("Error writing to AOF", e); + } + } + } else { + lock.sharedLock(); + } + } // If not in AOF mode, then recycle(), // else, the command will be recycled in the AOF dump if (!kevaConfig.getAof()) { @@ -116,7 +131,13 @@ public void init() { } return (Reply) method.invoke(instance, objects); } finally { - lock.unlock(); + if (locked) { + if (exclusive) { + lock.exclusiveUnlock(); + } else { + lock.sharedUnlock(); + } + } } } catch (Exception e) { log.error(e.getMessage(), e); diff --git a/store/src/main/java/dev/keva/store/KevaDatabase.java b/store/src/main/java/dev/keva/store/KevaDatabase.java index d18bb0bd..5d0e01f5 100644 --- a/store/src/main/java/dev/keva/store/KevaDatabase.java +++ b/store/src/main/java/dev/keva/store/KevaDatabase.java @@ -1,12 +1,13 @@ package dev.keva.store; +import dev.keva.store.lock.SpinLock; import dev.keva.util.hashbytes.BytesKey; import java.util.AbstractMap; import java.util.concurrent.locks.Lock; public interface KevaDatabase { - Lock getLock(); + SpinLock getLock(); void flush(); diff --git a/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java b/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java index 7412ffc8..4f7bf089 100644 --- a/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java +++ b/store/src/main/java/dev/keva/store/impl/OffHeapDatabaseImpl.java @@ -27,6 +27,8 @@ import java.util.List; import java.util.Map; import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import static dev.keva.util.Constants.FLAG_CH; import static dev.keva.util.Constants.FLAG_GT; @@ -37,7 +39,7 @@ @Slf4j public class OffHeapDatabaseImpl implements KevaDatabase { @Getter - private final Lock lock = new SpinLock(); + private final SpinLock lock = new SpinLock(); private ChronicleMap chronicleMap; @@ -65,47 +67,47 @@ public OffHeapDatabaseImpl(DatabaseConfig config) { @Override public void flush() { - lock.lock(); + lock.exclusiveLock(); try { chronicleMap.clear(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public byte[] get(byte[] key) { - lock.lock(); + lock.sharedLock(); try { return chronicleMap.get(key); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override public void put(byte[] key, byte[] val) { - lock.lock(); + lock.exclusiveLock(); try { chronicleMap.put(key, val); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public boolean remove(byte[] key) { - lock.lock(); + lock.exclusiveLock(); try { return chronicleMap.remove(key) != null; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public byte[] incrBy(byte[] key, long amount) { - lock.lock(); + lock.exclusiveLock(); try { return chronicleMap.compute(key, (k, oldVal) -> { long curVal = 0L; @@ -116,14 +118,14 @@ public byte[] incrBy(byte[] key, long amount) { return Long.toString(curVal).getBytes(StandardCharsets.UTF_8); }); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] hget(byte[] key, byte[] field) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -133,14 +135,14 @@ public byte[] hget(byte[] key, byte[] field) { BytesValue got = map.get(new BytesKey(field)); return got == null ? null : got.getBytes(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] hgetAll(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -155,14 +157,14 @@ public byte[][] hgetAll(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] hkeys(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -176,14 +178,14 @@ public byte[][] hkeys(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] hvals(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -197,14 +199,14 @@ public byte[][] hvals(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public void hset(byte[] key, byte[] field, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { chronicleMap.compute(key, (k, oldVal) -> { HashMap map; @@ -217,14 +219,14 @@ public void hset(byte[] key, byte[] field, byte[] value) { return SerializationUtils.serialize(map); }); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public boolean hdel(byte[] key, byte[] field) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -237,14 +239,14 @@ public boolean hdel(byte[] key, byte[] field) { } return result; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int lpush(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); LinkedList list; @@ -255,14 +257,14 @@ public int lpush(byte[] key, byte[]... values) { chronicleMap.put(key, SerializationUtils.serialize(list)); return list.size(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int rpush(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); LinkedList list; @@ -273,14 +275,14 @@ public int rpush(byte[] key, byte[]... values) { chronicleMap.put(key, SerializationUtils.serialize(list)); return list.size(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] lpop(byte[] key) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -294,14 +296,14 @@ public byte[] lpop(byte[] key) { chronicleMap.put(key, SerializationUtils.serialize(list)); return result; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] rpop(byte[] key) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -315,14 +317,14 @@ public byte[] rpop(byte[] key) { chronicleMap.put(key, SerializationUtils.serialize(list)); return result; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int llen(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -331,14 +333,14 @@ public int llen(byte[] key) { LinkedList list = (LinkedList) SerializationUtils.deserialize(value); return list.size(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] lrange(byte[] key, int start, int end) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -372,14 +374,14 @@ public byte[][] lrange(byte[] key, int start, int end) { } return result.toArray(new byte[0][0]); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] lindex(byte[] key, int index) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -394,14 +396,14 @@ public byte[] lindex(byte[] key, int index) { } return list.get(index).getBytes(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public void lset(byte[] key, int index, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value1 = chronicleMap.get(key); if (value1 == null) { @@ -417,14 +419,14 @@ public void lset(byte[] key, int index, byte[] value) { list.set(index, new BytesValue(value)); chronicleMap.put(key, SerializationUtils.serialize(list)); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int lrem(byte[] key, int count, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value1 = chronicleMap.get(key); if (value1 == null) { @@ -467,14 +469,14 @@ public int lrem(byte[] key, int count, byte[] value) { chronicleMap.put(key, SerializationUtils.serialize(list)); return result; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int sadd(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); HashSet set; @@ -489,14 +491,14 @@ public int sadd(byte[] key, byte[]... values) { chronicleMap.put(key, SerializationUtils.serialize(set)); return count; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] smembers(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -510,14 +512,14 @@ public byte[][] smembers(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public boolean sismember(byte[] key, byte[] value) { - lock.lock(); + lock.sharedLock(); try { byte[] got = chronicleMap.get(key); if (got == null) { @@ -526,14 +528,14 @@ public boolean sismember(byte[] key, byte[] value) { HashSet set = (HashSet) SerializationUtils.deserialize(got); return set.contains(new BytesKey(value)); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public int scard(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -542,14 +544,14 @@ public int scard(byte[] key) { HashSet set = (HashSet) SerializationUtils.deserialize(value); return set.size(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] sdiff(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { HashSet set = new HashSet<>(); for (byte[] key : keys) { @@ -568,14 +570,14 @@ public byte[][] sdiff(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] sinter(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { HashSet set = new HashSet<>(); for (byte[] key : keys) { @@ -594,14 +596,14 @@ public byte[][] sinter(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] sunion(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { HashSet set = new HashSet<>(); for (byte[] key : keys) { @@ -618,14 +620,14 @@ public byte[][] sunion(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public int smove(byte[] source, byte[] destination, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { byte[] sourceValue = chronicleMap.get(source); if (sourceValue == null) { @@ -647,14 +649,14 @@ public int smove(byte[] source, byte[] destination, byte[] value) { } return 0; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int srem(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -674,13 +676,13 @@ public int srem(byte[] key, byte[]... values) { } return result; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public int strlen(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = chronicleMap.get(key); if (value == null) { @@ -688,13 +690,13 @@ public int strlen(byte[] key) { } return new String(value, StandardCharsets.UTF_8).length(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override public int setrange(byte[] key, byte[] offset, byte[] val) { - lock.lock(); + lock.exclusiveLock(); try { var offsetPosition = Integer.parseInt(new String(offset, StandardCharsets.UTF_8)); byte[] oldVal = chronicleMap.get(key); @@ -712,13 +714,13 @@ public int setrange(byte[] key, byte[] offset, byte[] val) { chronicleMap.put(key, newVal); return newValLength; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public byte[][] mget(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { byte[][] result = new byte[keys.length][]; for (int i = 0; i < keys.length; i++) { @@ -728,7 +730,7 @@ public byte[][] mget(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @@ -743,7 +745,7 @@ public int zadd(final byte[] key, @NonNull final AbstractMap.SimpleEntry map = new HashMap<>(100); @@ -37,39 +37,39 @@ public void flush() { @Override public void put(byte[] key, byte[] val) { - lock.lock(); + lock.exclusiveLock(); try { map.put(new BytesKey(key), new BytesValue(val)); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public byte[] get(byte[] key) { - lock.lock(); + lock.sharedLock(); try { BytesValue got = map.get(new BytesKey(key)); return got != null ? got.getBytes() : null; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override public boolean remove(byte[] key) { - lock.lock(); + lock.exclusiveLock(); try { BytesValue removed = map.remove(new BytesKey(key)); return removed != null; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public byte[] incrBy(byte[] key, long amount) { - lock.lock(); + lock.exclusiveLock(); try { return map.compute(new BytesKey(key), (k, oldVal) -> { long curVal = 0L; @@ -80,14 +80,14 @@ public byte[] incrBy(byte[] key, long amount) { return new BytesValue(Long.toString(curVal).getBytes(StandardCharsets.UTF_8)); }).getBytes(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] hget(byte[] key, byte[] field) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -97,14 +97,14 @@ public byte[] hget(byte[] key, byte[] field) { BytesValue got = map.get(new BytesKey(field)); return got != null ? got.getBytes() : null; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] hgetAll(byte[] key) { - lock.lock(); + lock.sharedLock(); try { BytesValue value = map.get(new BytesKey(key)); if (value == null) { @@ -119,14 +119,14 @@ public byte[][] hgetAll(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] hkeys(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -140,14 +140,14 @@ public byte[][] hkeys(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] hvals(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -161,14 +161,14 @@ public byte[][] hvals(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public void hset(byte[] key, byte[] field, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { map.compute(new BytesKey(key), (k, oldVal) -> { HashMap map; @@ -181,14 +181,14 @@ public void hset(byte[] key, byte[] field, byte[] value) { return new BytesValue(SerializationUtils.serialize(map)); }); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public boolean hdel(byte[] key, byte[] field) { - lock.lock(); + lock.exclusiveLock(); try { BytesValue value = map.get(new BytesKey(key)); if (value == null) { @@ -201,14 +201,14 @@ public boolean hdel(byte[] key, byte[] field) { } return removed; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int lpush(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); LinkedList list; @@ -219,14 +219,14 @@ public int lpush(byte[] key, byte[]... values) { map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(list))); return list.size(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int rpush(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); LinkedList list; @@ -237,14 +237,14 @@ public int rpush(byte[] key, byte[]... values) { map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(list))); return list.size(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] lpop(byte[] key) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); LinkedList list; @@ -256,14 +256,14 @@ public byte[] lpop(byte[] key) { map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(list))); return v.getBytes(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] rpop(byte[] key) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); LinkedList list; @@ -275,28 +275,28 @@ public byte[] rpop(byte[] key) { map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(list))); return v.getBytes(); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int llen(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); LinkedList list; list = value == null ? new LinkedList<>() : (LinkedList) SerializationUtils.deserialize(value); return list.size(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] lrange(byte[] key, int start, int end) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -330,14 +330,14 @@ public byte[][] lrange(byte[] key, int start, int end) { } return result.toArray(new byte[0][0]); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[] lindex(byte[] key, int index) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); LinkedList list; @@ -350,14 +350,14 @@ public byte[] lindex(byte[] key, int index) { } return list.get(index).getBytes(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public void lset(byte[] key, int index, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { byte[] v = map.get(new BytesKey(key)).getBytes(); LinkedList list; @@ -371,14 +371,14 @@ public void lset(byte[] key, int index, byte[] value) { list.set(index, new BytesValue(value)); map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(list))); } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int lrem(byte[] key, int count, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value1 = map.get(new BytesKey(key)).getBytes(); if (value1 == null) { @@ -421,14 +421,14 @@ public int lrem(byte[] key, int count, byte[] value) { map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(list))); return result; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int sadd(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); HashSet set; @@ -443,14 +443,14 @@ public int sadd(byte[] key, byte[]... values) { map.put(new BytesKey(key), new BytesValue(SerializationUtils.serialize(set))); return count; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] smembers(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -464,14 +464,14 @@ public byte[][] smembers(byte[] key) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public boolean sismember(byte[] key, byte[] value) { - lock.lock(); + lock.sharedLock(); try { byte[] got = map.get(new BytesKey(key)).getBytes(); if (got == null) { @@ -480,14 +480,14 @@ public boolean sismember(byte[] key, byte[] value) { HashSet set = (HashSet) SerializationUtils.deserialize(got); return set.contains(new BytesKey(value)); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public int scard(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -496,14 +496,14 @@ public int scard(byte[] key) { HashSet set = (HashSet) SerializationUtils.deserialize(value); return set.size(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] sdiff(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { HashSet set = new HashSet<>(); for (byte[] key : keys) { @@ -522,14 +522,14 @@ public byte[][] sdiff(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] sinter(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { HashSet set = new HashSet<>(); for (byte[] key : keys) { @@ -548,14 +548,14 @@ public byte[][] sinter(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public byte[][] sunion(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { HashSet set = new HashSet<>(); for (byte[] key : keys) { @@ -572,14 +572,14 @@ public byte[][] sunion(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override @SuppressWarnings("unchecked") public int smove(byte[] source, byte[] destination, byte[] value) { - lock.lock(); + lock.exclusiveLock(); try { byte[] sourceValue = map.get(new BytesKey(source)).getBytes(); if (sourceValue == null) { @@ -601,14 +601,14 @@ public int smove(byte[] source, byte[] destination, byte[] value) { } return 0; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override @SuppressWarnings("unchecked") public int srem(byte[] key, byte[]... values) { - lock.lock(); + lock.exclusiveLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -628,13 +628,13 @@ public int srem(byte[] key, byte[]... values) { } return count; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public int strlen(byte[] key) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -642,13 +642,13 @@ public int strlen(byte[] key) { } return new String(value, StandardCharsets.UTF_8).length(); } finally { - lock.unlock(); + lock.sharedUnlock(); } } @Override public int setrange(byte[] key, byte[] offset, byte[] val) { - lock.lock(); + lock.exclusiveLock(); try { var offsetPosition = Integer.parseInt(new String(offset, StandardCharsets.UTF_8)); byte[] oldVal = map.get(new BytesKey(key)).getBytes(); @@ -666,13 +666,13 @@ public int setrange(byte[] key, byte[] offset, byte[] val) { map.put(new BytesKey(key), new BytesValue(newVal)); return newValLength; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public byte[][] mget(byte[]... keys) { - lock.lock(); + lock.sharedLock(); try { byte[][] result = new byte[keys.length][]; for (int i = 0; i < keys.length; i++) { @@ -682,7 +682,7 @@ public byte[][] mget(byte[]... keys) { } return result; } finally { - lock.unlock(); + lock.sharedUnlock(); } } @@ -697,7 +697,7 @@ public int zadd(byte[] key, AbstractMap.SimpleEntry[] members, // Track both to eliminate conditional branch int added = 0, changed = 0; - lock.lock(); + lock.exclusiveLock(); try { final BytesKey mapKey = new BytesKey(key); byte[] value = map.get(mapKey).getBytes(); @@ -728,13 +728,13 @@ public int zadd(byte[] key, AbstractMap.SimpleEntry[] members, map.put(mapKey, new BytesValue(SerializationUtils.serialize(zSet))); return ch ? changed : added; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public Double zincrby(byte[] key, Double incr, BytesKey e, int flags) { - lock.lock(); + lock.exclusiveLock(); try { final BytesKey mapKey = new BytesKey(key); byte[] value = map.get(mapKey).getBytes(); @@ -765,14 +765,14 @@ public Double zincrby(byte[] key, Double incr, BytesKey e, int flags) { map.put(mapKey, new BytesValue(SerializationUtils.serialize(zSet))); return currentScore; } finally { - lock.unlock(); + lock.exclusiveUnlock(); } } @Override public Double zscore(byte[] key, byte[] member) { - lock.lock(); + lock.sharedLock(); try { byte[] value = map.get(new BytesKey(key)).getBytes(); if (value == null) { @@ -781,7 +781,7 @@ public Double zscore(byte[] key, byte[] member) { ZSet zset = (ZSet) SerializationUtils.deserialize(value); return zset.getScore(new BytesKey(member)); } finally { - lock.unlock(); + lock.sharedUnlock(); } } } diff --git a/store/src/main/java/dev/keva/store/lock/SpinLock.java b/store/src/main/java/dev/keva/store/lock/SpinLock.java index a3254b28..c8079669 100644 --- a/store/src/main/java/dev/keva/store/lock/SpinLock.java +++ b/store/src/main/java/dev/keva/store/lock/SpinLock.java @@ -1,18 +1,30 @@ package dev.keva.store.lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; -public class SpinLock extends ReentrantLock { +public class SpinLock extends ReentrantReadWriteLock { + private final ReentrantReadWriteLock.ReadLock readLock; + private final ReentrantReadWriteLock.WriteLock writeLock; public SpinLock() { super(true); + readLock = readLock(); + writeLock = writeLock(); } - public void lock() { - while (!tryLock()) { - } + public void sharedLock() { + while (!readLock.tryLock()); } - public void unlock() { - super.unlock(); + public void sharedUnlock() { + readLock.unlock(); + } + + public void exclusiveLock() { + while (!writeLock.tryLock()); + } + + public void exclusiveUnlock() { + writeLock.unlock(); } }