Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/google/protobuf/compiler/cpp/message.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5295,7 +5295,7 @@ void MessageGenerator::GenerateByteSize(io::Printer* p) {
$WeakDescriptorSelfPin$;
$annotate_bytesize$;
// @@protoc_insertion_point(message_set_byte_size_start:$full_name$)
::size_t total_size = this_.$extensions$.MessageSetByteSize();
::size_t total_size = this_.$extensions$.MessageSetByteSize(&this_);
if (this_.$have_unknown_fields$) {
total_size += ::_pbi::ComputeUnknownMessageSetItemsSize(
this_.$unknown_fields$);
Expand Down Expand Up @@ -5353,7 +5353,7 @@ void MessageGenerator::GenerateByteSize(io::Printer* p) {
[&] {
if (descriptor_->extension_range_count() == 0) return;
p->Emit(R"cc(
total_size += this_.$extensions$.ByteSize();
total_size += this_.$extensions$.ByteSize(&this_);
)cc");
}},
{"prefetch",
Expand Down
24 changes: 12 additions & 12 deletions src/google/protobuf/descriptor.pb.cc

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

105 changes: 76 additions & 29 deletions src/google/protobuf/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <cstddef>
#include <cstdint>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <variant>
Expand Down Expand Up @@ -375,11 +374,18 @@ void* ExtensionSet::MutableRawRepeatedField(int number) {
// -------------------------------------------------------------------
// Enums

size_t ExtensionSet::GetMessageByteSizeLong(int number) const {
size_t ExtensionSet::GetMessageByteSizeLong(const MessageLite* extendee,
int number) const {
const Extension* extension = FindOrNull(number);
ABSL_CHECK(extension != nullptr) << "not present";
ABSL_DCHECK_TYPE(*extension, OPTIONAL_FIELD, MESSAGE);
return extension->is_lazy ? extension->ptr.lazymessage_value->ByteSizeLong()
ABSL_DCHECK(!extension->is_lazy ||
GetOrFindPrototypeForLazyMessage(*extension, extendee, number) !=
nullptr);
return extension->is_lazy ? extension->ptr.lazymessage_value->ByteSizeLong(
GetOrFindPrototypeForLazyMessage(
*extension, extendee, number),
arena_)
: extension->ptr.message_value->ByteSizeLong();
}

Expand Down Expand Up @@ -946,8 +952,8 @@ void ExtensionSet::InternalExtensionMergeFrom(Arena* arena,
Arena* other_arena) {
DebugAssertArenaMatches(arena);
Extension* dst_extension;
bool is_new = MaybeNewExtension(arena, number, other_extension.descriptor,
&dst_extension);
bool is_new = MaybeNewExtension(
arena, number, other_extension.descriptor_or_prototype, &dst_extension);
if (is_new) {
InternalExtensionMergeFromIntoUninitializedExtension(
arena, *dst_extension, extendee, number, other_extension, other_arena);
Expand Down Expand Up @@ -1009,9 +1015,12 @@ void ExtensionSet::InternalExtensionMergeFrom(Arena* arena,
ABSL_DCHECK(!dst_extension->is_repeated);
if (other_extension.is_lazy) {
if (dst_extension->is_lazy) {
const MessageLite* prototype = GetOrFindPrototypeForLazyMessage(
other_extension, extendee, number);
ABSL_DCHECK_NE(prototype, nullptr);
dst_extension->ptr.lazymessage_value->MergeFrom(
GetPrototypeForLazyMessage(extendee, number),
*other_extension.ptr.lazymessage_value, arena, other_arena);
prototype, *other_extension.ptr.lazymessage_value, arena,
other_arena);
} else {
dst_extension->ptr.message_value->CheckTypeAndMergeFrom(
other_extension.ptr.lazymessage_value->GetMessage(
Expand Down Expand Up @@ -1236,11 +1245,11 @@ uint8_t* ExtensionSet::InternalSerializeMessageSetWithCachedSizesToArray(
return target;
}

size_t ExtensionSet::ByteSize() const {
size_t ExtensionSet::ByteSize(const MessageLite* extendee) const {
size_t total_size = 0;
ForEach(
[&total_size](int number, const Extension& ext) {
total_size += ext.ByteSize(number);
[&](int number, const Extension& ext) {
total_size += ext.ByteSize(extendee, number, arena_);
},
Prefetch{});
return total_size;
Expand All @@ -1250,12 +1259,15 @@ size_t ExtensionSet::ByteSize() const {
// Defined in extension_set_heavy.cc.
// int ExtensionSet::SpaceUsedExcludingSelf() const

bool ExtensionSet::MaybeNewExtension(Arena* arena, int number,
const FieldDescriptor* descriptor,
Extension** result) {
bool extension_is_new = false;
std::tie(*result, extension_is_new) = Insert(arena, number);
(*result)->descriptor = descriptor;
bool ExtensionSet::MaybeNewExtension(
Arena* arena, int number,
Extension::DescriptorOrPrototype descriptor_or_prototype,
Extension** result_ptr) {
auto [result, extension_is_new] = Insert(arena, number);
*result_ptr = result;
if (extension_is_new) {
result->descriptor_or_prototype = descriptor_or_prototype;
}
return extension_is_new;
}

Expand Down Expand Up @@ -1327,7 +1339,8 @@ void ExtensionSet::Extension::Clear() {
}
}

size_t ExtensionSet::Extension::ByteSize(int number) const {
size_t ExtensionSet::Extension::ByteSize(const MessageLite* extendee,
int number, Arena* arena) const {
size_t result = 0;

if (is_repeated) {
Expand Down Expand Up @@ -1443,7 +1456,10 @@ size_t ExtensionSet::Extension::ByteSize(int number) const {
#undef HANDLE_TYPE
case WireFormatLite::TYPE_MESSAGE: {
result += WireFormatLite::LengthDelimitedSize(
is_lazy ? ptr.lazymessage_value->ByteSizeLong()
is_lazy ? ptr.lazymessage_value->ByteSizeLong(
ExtensionSet::GetOrFindPrototypeForLazyMessage(
*this, extendee, number),
arena)
: ptr.message_value->ByteSizeLong());
break;
}
Expand Down Expand Up @@ -1554,7 +1570,7 @@ bool ExtensionSet::Extension::IsInitialized(const ExtensionSet* ext_set,
if (!is_lazy) return ptr.message_value->IsInitialized();

const MessageLite* prototype =
ext_set->GetPrototypeForLazyMessage(extendee, number);
ext_set->GetOrFindPrototypeForLazyMessage(*this, extendee, number);
ABSL_DCHECK_NE(prototype, nullptr)
<< "extendee: " << extendee->GetTypeName() << "; number: " << number;
return ptr.lazymessage_value->IsInitialized(prototype, arena);
Expand Down Expand Up @@ -1857,7 +1873,9 @@ uint8_t* ExtensionSet::Extension::InternalSerializeFieldWithCachedSizesToArray(
case WireFormatLite::TYPE_MESSAGE:
if (is_lazy) {
const auto* prototype =
extension_set->GetPrototypeForLazyMessage(extendee, number);
extension_set->GetOrFindPrototypeForLazyMessage(*this, extendee,
number);
ABSL_DCHECK_NE(prototype, nullptr);
target = ptr.lazymessage_value->WriteMessageToArray(prototype, number,
target, stream);
} else {
Expand All @@ -1871,7 +1889,8 @@ uint8_t* ExtensionSet::Extension::InternalSerializeFieldWithCachedSizesToArray(
return target;
}

const MessageLite* ExtensionSet::GetPrototypeForLazyMessage(
template <bool mustBeGenerated>
const MessageLite* ExtensionSet::FindPrototypeForLazyMessageImpl(
const MessageLite* extendee, int number) {
GeneratedExtensionFinder finder(extendee);
bool was_packed_on_wire = false;
Expand All @@ -1881,9 +1900,32 @@ const MessageLite* ExtensionSet::GetPrototypeForLazyMessage(
&extension_info, &was_packed_on_wire)) {
return nullptr;
}
if constexpr (mustBeGenerated) {
if (extension_info.lazy_eager_verify_func == nullptr) {
return nullptr;
}
}
return extension_info.message_info.prototype;
}

const MessageLite* ExtensionSet::FindPrototypeForLazyMessage(
const MessageLite* extendee, int number) {
return FindPrototypeForLazyMessageImpl</*mustBeGenerated=*/false>(extendee,
number);
}

const MessageLite*
ExtensionSet::FindPrototypeFromGeneratedFactoryForLazyMessage(
const MessageLite* extendee, int number) {
return FindPrototypeForLazyMessageImpl</*mustBeGenerated=*/true>(extendee,
number);
}

template const MessageLite* ExtensionSet::FindPrototypeForLazyMessageImpl<true>(
const MessageLite*, int);
template const MessageLite*
ExtensionSet::FindPrototypeForLazyMessageImpl<false>(const MessageLite*, int);

uint8_t*
ExtensionSet::Extension::InternalSerializeMessageSetItemWithCachedSizesToArray(
const MessageLite* extendee, const ExtensionSet* extension_set, int number,
Expand All @@ -1906,8 +1948,9 @@ ExtensionSet::Extension::InternalSerializeMessageSetItemWithCachedSizesToArray(
WireFormatLite::kMessageSetTypeIdNumber, number, target);
// Write message.
if (is_lazy) {
const auto* prototype =
extension_set->GetPrototypeForLazyMessage(extendee, number);
const auto* prototype = extension_set->GetOrFindPrototypeForLazyMessage(
*this, extendee, number);
ABSL_DCHECK_NE(prototype, nullptr);
target = ptr.lazymessage_value->WriteMessageToArray(
prototype, WireFormatLite::kMessageSetMessageNumber, target, stream);
} else {
Expand All @@ -1922,11 +1965,12 @@ ExtensionSet::Extension::InternalSerializeMessageSetItemWithCachedSizesToArray(
return target;
}

size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
size_t ExtensionSet::Extension::MessageSetItemByteSize(
const MessageLite* extendee, int number, Arena* arena) const {
if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
// Not a valid MessageSet extension, but compute the byte size for it the
// normal way.
return ByteSize(number);
return ByteSize(extendee, number, arena);
}

if (is_cleared) return 0;
Expand All @@ -1938,17 +1982,20 @@ size_t ExtensionSet::Extension::MessageSetItemByteSize(int number) const {

// message
our_size += WireFormatLite::LengthDelimitedSize(
is_lazy ? ptr.lazymessage_value->ByteSizeLong()
is_lazy ? ptr.lazymessage_value->ByteSizeLong(
ExtensionSet::GetOrFindPrototypeForLazyMessage(
*this, extendee, number),
arena)
: ptr.message_value->ByteSizeLong());

return our_size;
}

size_t ExtensionSet::MessageSetByteSize() const {
size_t ExtensionSet::MessageSetByteSize(const MessageLite* extendee) const {
size_t total_size = 0;
ForEach(
[&total_size](int number, const Extension& ext) {
total_size += ext.MessageSetItemByteSize(number);
[&](int number, const Extension& ext) {
total_size += ext.MessageSetItemByteSize(extendee, number, arena_);
},
Prefetch{});
return total_size;
Expand Down
Loading
Loading