diff --git a/src/server/bitops_family.cc b/src/server/bitops_family.cc index c550d8cdd495..9a58e82b8c7c 100644 --- a/src/server/bitops_family.cc +++ b/src/server/bitops_family.cc @@ -728,9 +728,17 @@ ResultType Get::ApplyTo(Overflow ov, const string* bitfield) { const size_t offset = attr_.offset; auto last_byte_offset = GetByteIndex(attr_.offset + attr_.encoding_bit_size - 1); + if (GetByteIndex(offset) >= total_bytes) { + return 0; + } + + const string* result_str = bitfield; + string buff; uint32_t lsb = attr_.offset + attr_.encoding_bit_size - 1; - if (last_byte_offset > total_bytes) { - return {}; + if (last_byte_offset >= total_bytes) { + buff = *bitfield; + buff.resize(last_byte_offset + 1, 0); + result_str = &buff; } const bool is_negative = @@ -738,7 +746,7 @@ ResultType Get::ApplyTo(Overflow ov, const string* bitfield) { int64_t result = 0; for (size_t i = 0; i < attr_.encoding_bit_size; ++i) { - uint8_t byte{GetByteValue(bytes, lsb)}; + uint8_t byte{GetByteValue(*result_str, lsb)}; int32_t index = GetNormalizedBitIndex(lsb); int64_t old_bit = CheckBitStatus(byte, index); result |= old_bit << i; @@ -830,10 +838,11 @@ ResultType IncrBy::ApplyTo(Overflow ov, string* bitfield) { string& bytes = *bitfield; Get get(attr_); auto res = get.ApplyTo(ov, &bytes); + const int32_t total_bytes = static_cast(bytes.size()); + auto last_byte_offset = GetByteIndex(attr_.offset + attr_.encoding_bit_size - 1); - if (!res) { - Set set(attr_, incr_value_); - return set.ApplyTo(ov, &bytes); + if (last_byte_offset >= total_bytes) { + bytes.resize(last_byte_offset + 1, 0); } if (!HandleOverflow(ov, &*res)) { diff --git a/src/server/bitops_family_test.cc b/src/server/bitops_family_test.cc index 17bcca6a0161..82ec8b209709 100644 --- a/src/server/bitops_family_test.cc +++ b/src/server/bitops_family_test.cc @@ -805,4 +805,21 @@ TEST_F(BitOpsFamilyTest, BitFieldOperations) { ASSERT_THAT(Run({"bitfield", "foo", "get", "u1", "15"}), IntArg(1)); } +TEST_F(BitOpsFamilyTest, BitFieldLargeOffset) { + Run({"set", "foo", "bar"}); + + auto resp = Run({"bitfield", "foo", "get", "u32", "0", "overflow", "fail", "incrby", "u32", "0", + "4294967295"}); + EXPECT_THAT(resp, RespArray(ElementsAre(IntArg(1650553344), ArgType(RespExpr::NIL)))); + + resp = Run({"strlen", "foo"}); + EXPECT_THAT(resp, 4); + + resp = Run({"get", "foo"}); + EXPECT_THAT(ToSV(resp.GetBuf()), Eq(std::string_view("bar\0", 4))); + + resp = Run({"bitfield", "foo", "get", "u32", "4294967295"}); + EXPECT_THAT(resp, 0); +} + } // end of namespace dfly