diff --git a/ydb/public/lib/ydb_cli/commands/ydb_service_import.cpp b/ydb/public/lib/ydb_cli/commands/ydb_service_import.cpp index e83f89e11c66..509a7ccc7302 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_service_import.cpp +++ b/ydb/public/lib/ydb_cli/commands/ydb_service_import.cpp @@ -6,7 +6,6 @@ #include #include #include -#include #include #include @@ -302,6 +301,17 @@ void TCommandImportFromCsv::Config(TConfig& config) { config.Opts->AddLongOption("newline-delimited", "No newline characters inside records, enables some import optimizations (see docs)") .StoreTrue(&NewlineDelimited); + TStringStream description; + description << "Format that data will be serialized to before sending to YDB. Available options: "; + NColorizer::TColors colors = NColorizer::AutoColors(Cout); + description << "\n " << colors.BoldColor() << "tvalue" << colors.OldColor() + << "\n " << "A default YDB protobuf format"; + description << "\n " << colors.BoldColor() << "arrow" << colors.OldColor() + << "\n " << "Apache Arrow format"; + description << "\nDefault: " << colors.CyanColor() << "\"" << "tvalue" << "\"" << colors.OldColor() << "."; + config.Opts->AddLongOption("send-format", description.Str()) + .RequiredArgument("STRING").StoreResult(&SendFormat) + .Hidden(); if (InputFormat == EDataFormat::Csv) { config.Opts->AddLongOption("delimiter", "Field delimiter in rows") .RequiredArgument("STRING").StoreResult(&Delimiter).DefaultValue(Delimiter); @@ -325,6 +335,7 @@ int TCommandImportFromCsv::Run(TConfig& config) { settings.HeaderRow(HeaderRow); settings.NullValue(NullValue); settings.Verbose(config.IsVerbose()); + settings.SendFormat(SendFormat); if (Delimiter.size() != 1) { throw TMisuseException() diff --git a/ydb/public/lib/ydb_cli/commands/ydb_service_import.h b/ydb/public/lib/ydb_cli/commands/ydb_service_import.h index c6eb4f57e6c4..dd1f85fa18b1 100644 --- a/ydb/public/lib/ydb_cli/commands/ydb_service_import.h +++ b/ydb/public/lib/ydb_cli/commands/ydb_service_import.h @@ -7,6 +7,7 @@ #include #include #include +#include namespace NYdb::NConsoleClient { @@ -86,6 +87,7 @@ class TCommandImportFromCsv : public TCommandImportFileBase { ui32 SkipRows = 0; bool Header = false; bool NewlineDelimited = true; + NConsoleClient::ESendFormat SendFormat = NConsoleClient::ESendFormat::Default; }; class TCommandImportFromTsv : public TCommandImportFromCsv { diff --git a/ydb/public/lib/ydb_cli/common/csv_parser.cpp b/ydb/public/lib/ydb_cli/common/csv_parser.cpp index 76d0dce81e3f..6da0db843cf4 100644 --- a/ydb/public/lib/ydb_cli/common/csv_parser.cpp +++ b/ydb/public/lib/ydb_cli/common/csv_parser.cpp @@ -15,9 +15,7 @@ class TCsvToYdbConverter { public: explicit TCsvToYdbConverter(TTypeParser& parser, const std::optional& nullValue) : Parser(parser) - , NullValue(nullValue) - { - } + , NullValue(nullValue) {} template && std::is_signed_v, std::nullptr_t> = nullptr> static i64 StringToArithmetic(const TString& token, size_t& cnt) { @@ -165,7 +163,7 @@ class TCsvToYdbConverter { } case EPrimitiveType::Interval64: Builder.Interval64(GetArithmetic(token)); - break; + break; case EPrimitiveType::TzDate: Builder.TzDate(token); break; @@ -441,7 +439,7 @@ TStringBuf Consume(NCsvFormat::CsvSplitter& splitter, TCsvParser::TCsvParser(TString&& headerRow, const char delimeter, const std::optional& nullValue, const std::map* destinationTypes, - const std::map* paramSources) + const std::map* paramSources) : HeaderRow(std::move(headerRow)) , Delimeter(delimeter) , NullValue(nullValue) @@ -454,7 +452,7 @@ TCsvParser::TCsvParser(TString&& headerRow, const char delimeter, const std::opt TCsvParser::TCsvParser(TVector&& header, const char delimeter, const std::optional& nullValue, const std::map* destinationTypes, - const std::map* paramSources) + const std::map* paramSources) : Header(std::move(header)) , Delimeter(delimeter) , NullValue(nullValue) @@ -529,41 +527,91 @@ void TCsvParser::BuildValue(TString& data, TValueBuilder& builder, const TType& builder.EndStruct(); } -TValue TCsvParser::BuildList(std::vector& lines, const TString& filename, std::optional row) const { +TValue TCsvParser::BuildList(const std::vector& lines, const TString& filename, std::optional row) const { std::vector> columnTypeParsers; columnTypeParsers.reserve(ResultColumnCount); for (const TType* type : ResultLineTypesSorted) { columnTypeParsers.push_back(std::make_unique(*type)); } - - Ydb::Value listValue; - auto* listItems = listValue.mutable_items(); + + // Original path with local value object + Ydb::Value listProtoValue; + auto* listItems = listProtoValue.mutable_items(); listItems->Reserve(lines.size()); - for (auto& line : lines) { - NCsvFormat::CsvSplitter splitter(line, Delimeter); - TParseMetadata meta {row, filename}; - auto* structItems = listItems->Add()->mutable_items(); - structItems->Reserve(ResultColumnCount); - auto headerIt = Header.cbegin(); - auto skipIt = SkipBitMap.begin(); - auto typeParserIt = columnTypeParsers.begin(); - do { - if (headerIt == Header.cend()) { // SkipBitMap has same size as Header - throw FormatError(yexception() << "Header contains less fields than data. Header: \"" << HeaderRow << "\", data: \"" << line << "\"", meta); - } - TStringBuf nextField = Consume(splitter, meta, *headerIt); - if (!*skipIt) { - *structItems->Add() = FieldToValue(*typeParserIt->get(), nextField, NullValue, meta, *headerIt).GetProto(); - ++typeParserIt; - } - ++headerIt; - ++skipIt; - } while (splitter.Step()); + + for (const auto& line : lines) { + ProcessCsvLine(line, listItems, columnTypeParsers, row, filename); if (row.has_value()) { ++row.value(); } } - return TValue(ResultListType.value(), std::move(listValue)); + + // Return a TValue that takes ownership via move + return TValue(ResultListType.value(), std::move(listProtoValue)); +} + +TValue TCsvParser::BuildListOnArena( + const std::vector& lines, + const TString& filename, + google::protobuf::Arena* arena, + std::optional row +) const { + Y_ASSERT(arena != nullptr); + + std::vector> columnTypeParsers; + columnTypeParsers.reserve(ResultColumnCount); + for (const TType* type : ResultLineTypesSorted) { + columnTypeParsers.push_back(std::make_unique(*type)); + } + + // allocate Ydb::Value on arena + Ydb::Value* listProtoValue = google::protobuf::Arena::CreateMessage(arena); + auto* listItems = listProtoValue->mutable_items(); + listItems->Reserve(lines.size()); + + for (const auto& line : lines) { + ProcessCsvLine(line, listItems, columnTypeParsers, row, filename); + if (row.has_value()) { + ++row.value(); + } + } + + // Return a TValue that references the arena-allocated message + return TValue(ResultListType.value(), listProtoValue); +} + +// Helper method to process a single CSV line +void TCsvParser::ProcessCsvLine( + const TString& line, + google::protobuf::RepeatedPtrField* listItems, + const std::vector>& columnTypeParsers, + std::optional currentRow, + const TString& filename +) const { + NCsvFormat::CsvSplitter splitter(line, Delimeter); + auto* structItems = listItems->Add()->mutable_items(); + structItems->Reserve(ResultColumnCount); + + const TParseMetadata meta {currentRow, filename}; + + auto headerIt = Header.cbegin(); + auto skipIt = SkipBitMap.begin(); + auto typeParserIt = columnTypeParsers.begin(); + + do { + if (headerIt == Header.cend()) { // SkipBitMap has same size as Header + throw FormatError(yexception() << "Header contains less fields than data. Header: \"" << HeaderRow << "\", data: \"" << line << "\"", meta); + } + TStringBuf nextField = Consume(splitter, meta, *headerIt); + if (!*skipIt) { + TValue builtValue = FieldToValue(*typeParserIt->get(), nextField, NullValue, meta, *headerIt); + *structItems->Add() = std::move(builtValue.GetProto()); + + ++typeParserIt; + } + ++headerIt; + ++skipIt; + } while (splitter.Step()); } void TCsvParser::BuildLineType() { @@ -607,5 +655,10 @@ const TVector& TCsvParser::GetHeader() { return Header; } +const TString& TCsvParser::GetHeaderRow() const { + return HeaderRow; +} + } } + diff --git a/ydb/public/lib/ydb_cli/common/csv_parser.h b/ydb/public/lib/ydb_cli/common/csv_parser.h index 05c3def83d9c..7cad2fcb5d09 100644 --- a/ydb/public/lib/ydb_cli/common/csv_parser.h +++ b/ydb/public/lib/ydb_cli/common/csv_parser.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include @@ -35,10 +36,20 @@ class TCsvParser { void BuildParams(TString& data, TParamsBuilder& builder, const TParseMetadata& meta) const; void BuildValue(TString& data, TValueBuilder& builder, const TType& type, const TParseMetadata& meta) const; - TValue BuildList(std::vector& lines, const TString& filename, + + TValue BuildList(const std::vector& lines, const TString& filename, std::optional row = std::nullopt) const; + + TValue BuildListOnArena( + const std::vector& lines, + const TString& filename, + google::protobuf::Arena* arena, + std::optional row = std::nullopt + ) const; + void BuildLineType(); const TVector& GetHeader(); + const TString& GetHeaderRow() const; private: TVector Header; @@ -60,6 +71,15 @@ class TCsvParser { // Types of columns in a single row in resulting TValue. // Column order according to the header, though can have less elements than the Header std::vector ResultLineTypesSorted; + + // Helper method to process a single line of CSV data + void ProcessCsvLine( + const TString& line, + google::protobuf::RepeatedPtrField* listItems, + const std::vector>& columnTypeParsers, + std::optional currentRow, + const TString& filename + ) const; }; } diff --git a/ydb/public/lib/ydb_cli/common/csv_parser_ut.cpp b/ydb/public/lib/ydb_cli/common/csv_parser_ut.cpp index eccfacbf92ce..7d2f1e1593a6 100644 --- a/ydb/public/lib/ydb_cli/common/csv_parser_ut.cpp +++ b/ydb/public/lib/ydb_cli/common/csv_parser_ut.cpp @@ -317,7 +317,7 @@ Y_UNIT_TEST_SUITE(YdbCliCsvParserTests) { {"col2", TTypeBuilder().BeginOptional().Primitive(EPrimitiveType::Int64).EndOptional().Build()}, {"col3", TTypeBuilder().Primitive(EPrimitiveType::Bool).Build()}, }; - + TString csvHeader = "col4,col3,col5,col1,col6"; std::vector data = { "col4 unused value,true,col5 unused value,col1 value,col6 unused value" diff --git a/ydb/public/lib/ydb_cli/import/import.cpp b/ydb/public/lib/ydb_cli/import/import.cpp index beb067c7e13f..5a798618160d 100644 --- a/ydb/public/lib/ydb_cli/import/import.cpp +++ b/ydb/public/lib/ydb_cli/import/import.cpp @@ -11,6 +11,8 @@ #include #include +#include + #include #include #include @@ -25,6 +27,7 @@ #include #include #include +#include #include #include @@ -50,6 +53,7 @@ #include #endif +#include namespace NYdb { namespace NConsoleClient { @@ -555,6 +559,16 @@ class TImportFileClient::TImpl { std::shared_ptr progressFile); TAsyncStatus UpsertTValueBuffer(const TString& dbPath, TValueBuilder& builder); TAsyncStatus UpsertTValueBuffer(const TString& dbPath, std::function&& buildFunc); + + TAsyncStatus UpsertTValueBufferParquet( + const TString& dbPath, + std::shared_ptr batch, + const arrow::ipc::IpcWriteOptions& writeOptions + ); + + TAsyncStatus UpsertTValueBufferOnArena( + const TString& dbPath, std::function&& buildFunc); + TStatus UpsertJson(IInputStream &input, const TString &dbPath, std::optional inputSizeHint, ProgressCallbackFunc & progressCallback); TStatus UpsertParquet(const TString& filename, const TString& dbPath, ProgressCallbackFunc & progressCallback); @@ -672,7 +686,7 @@ TStatus TImportFileClient::TImpl::Import(const TVector& filePaths, cons auto start = TInstant::Now(); - TThreadPool jobPool; + TThreadPool jobPool(IThreadPool::TParams().SetThreadNamePrefix("FileWorker")); jobPool.Start(filePathsSize); TVector> asyncResults; @@ -925,6 +939,99 @@ TAsyncStatus TImportFileClient::TImpl::UpsertTValueBuffer(const TString& dbPath, }); } +inline TAsyncStatus TImportFileClient::TImpl::UpsertTValueBufferParquet( + const TString& dbPath, + std::shared_ptr batch, + const arrow::ipc::IpcWriteOptions& writeOptions +) { + if (!RequestsInflight->try_acquire()) { + if (Settings.Verbose_ && Settings.NewlineDelimited_) { + if (!InformedAboutLimit.exchange(true)) { + Cerr << (TStringBuilder() << "@ (each '@' means max request inflight is reached and a worker thread is waiting for " + "any response from database)" << Endl); + } else { + Cerr << '@'; + } + } + RequestsInflight->acquire(); + } + + auto retryFunc = [parquet = NYdb_cli::NArrow::SerializeBatch(batch, writeOptions), + schema = NYdb_cli::NArrow::SerializeSchema(*batch->schema()), + dbPath](NTable::TTableClient& client) { + return client.BulkUpsert(dbPath, NTable::EDataFormat::ApacheArrow, parquet, schema) + .Apply([](const NTable::TAsyncBulkUpsertResult& result) { + return TStatus(result.GetValueSync()); + }); + }; + + return TableClient->RetryOperation(std::move(retryFunc), RetrySettings) + .Apply([this](const TAsyncStatus& asyncStatus) { + NYdb::TStatus status = asyncStatus.GetValueSync(); + if (!status.IsSuccess()) { + if (!Failed.exchange(true)) { + ErrorStatus = MakeHolder(status); + } + } + RequestsInflight->release(); + return asyncStatus; + }); +} + +inline TAsyncStatus TImportFileClient::TImpl::UpsertTValueBufferOnArena( + const TString& dbPath, std::function&& buildFunc) { + auto arena = std::make_shared(); + + // For the first attempt values are built before acquiring request inflight semaphore + std::optional prebuiltValue = buildFunc(arena.get()); + + auto retryFunc = [this, &dbPath, buildFunc = std::move(buildFunc), + prebuiltValue = std::move(prebuiltValue), arena = std::move(arena)] + (NYdb::NTable::TTableClient& tableClient) mutable -> TAsyncStatus { + auto buildTValueAndSendRequest = [this, &buildFunc, &dbPath, &tableClient, &prebuiltValue, arena]() { + // For every retry attempt after first request build value from strings again + // to prevent copying data in retryFunc in a happy way when there is only one request + TValue builtValue = prebuiltValue.has_value() ? std::move(prebuiltValue.value()) : buildFunc(arena.get()); + prebuiltValue = std::nullopt; + + auto settings = UpsertSettings; + settings.Arena(arena.get()); + return tableClient.BulkUpsert( + dbPath, std::move(builtValue), settings) + .Apply([](const NYdb::NTable::TAsyncBulkUpsertResult& bulkUpsertResult) { + NYdb::TStatus status = bulkUpsertResult.GetValueSync(); + return NThreading::MakeFuture(status); + }); + }; + // Running heavy building task on processing pool: + return NThreading::Async(std::move(buildTValueAndSendRequest), *ProcessingPool); + }; + + if (!RequestsInflight->try_acquire()) { + if (Settings.Verbose_ && Settings.NewlineDelimited_) { + if (!InformedAboutLimit.exchange(true)) { + Cerr << (TStringBuilder() << "@ (each '@' means max request inflight is reached and a worker thread is waiting for " + "any response from database)" << Endl); + } else { + Cerr << '@'; + } + } + RequestsInflight->acquire(); + } + + return TableClient->RetryOperation(std::move(retryFunc), RetrySettings) + .Apply([this](const TAsyncStatus& asyncStatus) { + NYdb::TStatus status = asyncStatus.GetValueSync(); + if (!status.IsSuccess()) { + if (!Failed.exchange(true)) { + ErrorStatus = MakeHolder(status); + } + } + RequestsInflight->release(); + return asyncStatus; + }); +} + TStatus TImportFileClient::TImpl::UpsertCsv(IInputStream& input, const TString& dbPath, const TString& filePath, @@ -986,30 +1093,108 @@ TStatus TImportFileClient::TImpl::UpsertCsv(IInputStream& input, } }; + // Note: table = dbPath (path to the table on the server) + auto columns = DbTableInfo->GetTableColumns(); + + const Ydb::Formats::CsvSettings csvSettings = ([this]() { + Ydb::Formats::CsvSettings settings; + settings.set_delimiter(Settings.Delimiter_); + settings.set_header(Settings.Header_); + if (Settings.NullValue_.has_value()) { + settings.set_null_value(Settings.NullValue_.value()); + } + settings.set_skip_rows(Settings.SkipRows_); + return settings; + }()); + + auto writeOptions = arrow::ipc::IpcWriteOptions::Defaults(); + constexpr auto codecType = arrow::Compression::type::ZSTD; + writeOptions.codec = *arrow::util::Codec::Create(codecType); + auto upsertCsvFunc = [&](std::vector&& buffer, ui64 row, std::shared_ptr batchStatus) { - auto buildFunc = [&, buffer = std::move(buffer), row, this] () mutable { - try { - return parser.BuildList(buffer, filePath, row); - } catch (const std::exception& e) { - if (!Failed.exchange(true)) { - ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, e.what())); + switch (Settings.SendFormat_) { + case ESendFormat::Default: + case ESendFormat::TValue: + { + auto buildOnArenaFunc = [&, buffer = std::move(buffer), row, this] (google::protobuf::Arena* arena) mutable { + try { + return parser.BuildListOnArena(buffer, filePath, arena, row); + } catch (const std::exception& e) { + if (!Failed.exchange(true)) { + ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, e.what())); + } + jobInflightManager->ReleaseJob(); + throw; + } + }; + + UpsertTValueBufferOnArena(dbPath, std::move(buildOnArenaFunc)) + .Apply([&, batchStatus](const TAsyncStatus& asyncStatus) { + jobInflightManager->ReleaseJob(); + if (asyncStatus.GetValueSync().IsSuccess()) { + batchStatus->Completed = true; + if (!FileProgressPool->AddFunc(saveProgressIfAny) && !Failed.exchange(true)) { + ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, + "Couldn't add worker func to save progress")); + } + } + return asyncStatus; + }); } - jobInflightManager->ReleaseJob(); - throw; - } - }; - UpsertTValueBuffer(dbPath, std::move(buildFunc)) - .Apply([&, batchStatus](const TAsyncStatus& asyncStatus) { - jobInflightManager->ReleaseJob(); - if (asyncStatus.GetValueSync().IsSuccess()) { - batchStatus->Completed = true; - if (!FileProgressPool->AddFunc(saveProgressIfAny) && !Failed.exchange(true)) { - ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, - "Couldn't add worker func to save progress")); + break; + case ESendFormat::ApacheArrow: + { + const i64 estimatedCsvLineLength = (!buffer.empty() ? 2 * buffer.front().size() : 10'000); + TStringBuilder data; + data.reserve((buffer.size() + (Settings.Header_ ? 1 : 0)) * estimatedCsvLineLength); + // insert header if it is present in the given csv file + if (Settings.Header_) { + data << parser.GetHeaderRow() << Endl; + } + data << JoinSeq("\n", buffer) << Endl; + + // if header is present, it is expected to be the first line of the data + TString error; + auto arrowCsv = NKikimr::NFormats::TArrowCSVTable::Create(columns, Settings.Header_); + if (arrowCsv.ok()) { + if (auto batch = arrowCsv->ReadSingleBatch(data, csvSettings, error)) { + if (!error) { + // batch was read successfully, sending data via Apache Arrow + UpsertTValueBufferParquet(dbPath, std::move(batch), writeOptions) + .Apply([&, batchStatus](const TAsyncStatus& asyncStatus) { + jobInflightManager->ReleaseJob(); + if (asyncStatus.GetValueSync().IsSuccess()) { + batchStatus->Completed = true; + if (!FileProgressPool->AddFunc(saveProgressIfAny) && !Failed.exchange(true)) { + ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, + "Couldn't add worker func to save progress")); + } + } + return asyncStatus; + }); + } else { + error = "Error while reading a batch from Apache Arrow: " + error; + } + } else { + error = "Could not read a batch from Apache Arrow"; + } + } else { + error = arrowCsv.status().ToString(); + } + + if (!error.empty()) { + if (!Failed.exchange(true)) { + ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, error)); + } } } - return asyncStatus; - }); + break; + default: + if (!Failed.exchange(true)) { + ErrorStatus = MakeHolder(MakeStatus(EStatus::INTERNAL_ERROR, + (TStringBuilder() << "Unknown send format: " << Settings.SendFormat_).c_str())); + } + } }; for (ui32 i = 0; i < rowsToSkip; ++i) { @@ -1037,7 +1222,7 @@ TStatus TImportFileClient::TImpl::UpsertCsv(IInputStream& input, line.erase(line.size() - Settings.Delimiter_.size()); } - buffer.push_back(line); + buffer.push_back(std::move(line)); if (readBytes >= nextReadBorder && Settings.Verbose_) { nextReadBorder += VerboseModeStepSize; diff --git a/ydb/public/lib/ydb_cli/import/import.h b/ydb/public/lib/ydb_cli/import/import.h index 814298dff73e..86b3df069881 100644 --- a/ydb/public/lib/ydb_cli/import/import.h +++ b/ydb/public/lib/ydb_cli/import/import.h @@ -35,6 +35,13 @@ class TImportClient; namespace NConsoleClient { +// EDataFormat to be used in operations related to structured data +enum class ESendFormat { + Default /* "default" */, + TValue /* "tvalue" */, + ApacheArrow /* "arrow" */ +}; + struct TImportFileSettings : public TOperationRequestSettings { using TSelf = TImportFileSettings; @@ -60,6 +67,7 @@ struct TImportFileSettings : public TOperationRequestSettings, NullValue, std::nullopt); FLUENT_SETTING_DEFAULT(bool, Verbose, false); + FLUENT_SETTING_DEFAULT(ESendFormat, SendFormat, ESendFormat::Default); }; class TImportFileClient { diff --git a/ydb/public/lib/ydb_cli/import/ya.make b/ydb/public/lib/ydb_cli/import/ya.make index b6b4ff8c1754..9e28d5f8ac87 100644 --- a/ydb/public/lib/ydb_cli/import/ya.make +++ b/ydb/public/lib/ydb_cli/import/ya.make @@ -14,4 +14,6 @@ PEERDIR( library/cpp/string_utils/csv ) +GENERATE_ENUM_SERIALIZATION(import.h) + END() diff --git a/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/table/table.h b/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/table/table.h index 01bccbc00c03..0e7aa8e81d6b 100644 --- a/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/table/table.h +++ b/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/table/table.h @@ -1162,6 +1162,8 @@ struct TBulkUpsertSettings : public TOperationRequestSettings { diff --git a/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/value/value.h b/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/value/value.h index a43f7f4aacba..9003d13c5dea 100644 --- a/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/value/value.h +++ b/ydb/public/sdk/cpp/include/ydb-cpp-sdk/client/value/value.h @@ -276,13 +276,19 @@ class TValue { public: TValue(const TType& type, const Ydb::Value& valueProto); TValue(const TType& type, Ydb::Value&& valueProto); + /** + * Lifetime of the arena, and hence the `Ydb::Value`, is expected to be managed by the caller. + * The `Ydb::Value` is expected to be arena-allocated. + * + * See: https://protobuf.dev/reference/cpp/arenas + */ + TValue(const TType& type, Ydb::Value* arenaAllocatedValueProto); const TType& GetType() const; - TType & GetType(); + TType& GetType(); const Ydb::Value& GetProto() const; Ydb::Value& GetProto(); - private: class TImpl; std::shared_ptr Impl_; diff --git a/ydb/public/sdk/cpp/src/client/impl/ydb_internal/grpc_connections/grpc_connections.h b/ydb/public/sdk/cpp/src/client/impl/ydb_internal/grpc_connections/grpc_connections.h index f1a23631c94e..d65d408bf022 100644 --- a/ydb/public/sdk/cpp/src/client/impl/ydb_internal/grpc_connections/grpc_connections.h +++ b/ydb/public/sdk/cpp/src/client/impl/ydb_internal/grpc_connections/grpc_connections.h @@ -140,9 +140,55 @@ class TGRpcConnectionsImpl TRequest, TResponse>::TAsyncRequest; + template + class TRequestWrapper { + public: + // Implicit conversion from rvalue reference + TRequestWrapper(TRequest&& request) + : Storage_(std::move(request)) + {} + + // Implicit conversion from pointer. Means that request is allocated on Arena + TRequestWrapper(TRequest* request) + : Storage_(request) + {} + + // Copy constructor + TRequestWrapper(const TRequestWrapper& other) = default; + + // Move constructor + TRequestWrapper(TRequestWrapper&& other) = default; + + // Copy assignment + TRequestWrapper& operator=(const TRequestWrapper& other) = default; + + // Move assignment + TRequestWrapper& operator=(TRequestWrapper&& other) = default; + + template + void DoRequest( + std::unique_ptr>& serviceConnection, + NYdbGrpc::TAdvancedResponseCallback&& responseCbLow, + typename NYdbGrpc::TSimpleRequestProcessor::TAsyncRequest rpc, + const TCallMeta& meta, + IQueueClientContext* context) + { + if (auto ptr = std::get_if(&Storage_)) { + serviceConnection->DoAdvancedRequest(**ptr, + std::move(responseCbLow), rpc, meta, context); + } else { + serviceConnection->DoAdvancedRequest(std::move(std::get(Storage_)), + std::move(responseCbLow), rpc, meta, context); + } + } + + private: + std::variant Storage_; + }; + template void Run( - TRequest&& request, + TRequestWrapper&& requestWrapper, TResponseCb&& userResponseCb, TSimpleRpc rpc, TDbDriverStatePtr dbState, @@ -174,7 +220,8 @@ class TGRpcConnectionsImpl } WithServiceConnection( - [this, request = std::move(request), userResponseCb = std::move(userResponseCb), rpc, requestSettings, context = std::move(context), dbState] + [this, requestWrapper = std::move(requestWrapper), userResponseCb = std::move(userResponseCb), rpc, + requestSettings, context = std::move(context), dbState] (TPlainStatus status, TConnection serviceConnection, TEndpointKey endpoint) mutable -> void { if (!status.Ok()) { userResponseCb( @@ -271,14 +318,13 @@ class TGRpcConnectionsImpl } }; - serviceConnection->DoAdvancedRequest(std::move(request), std::move(responseCbLow), rpc, meta, - context.get()); + requestWrapper.DoRequest(serviceConnection, std::move(responseCbLow), rpc, meta, context.get()); }, dbState, requestSettings.PreferredEndpoint, requestSettings.EndpointPolicy); } template void RunDeferred( - TRequest&& request, + TRequestWrapper&& requestWrapper, TDeferredOperationCb&& userResponseCb, TSimpleRpc rpc, TDbDriverStatePtr dbState, @@ -321,7 +367,7 @@ class TGRpcConnectionsImpl }; Run( - std::move(request), + std::move(requestWrapper), responseCb, rpc, dbState, @@ -357,7 +403,7 @@ class TGRpcConnectionsImpl template void RunDeferred( - TRequest&& request, + TRequestWrapper&& requestWrapper, TDeferredResultCb&& userResponseCb, TSimpleRpc rpc, TDbDriverStatePtr dbState, @@ -375,7 +421,7 @@ class TGRpcConnectionsImpl }; RunDeferred( - std::move(request), + std::move(requestWrapper), operationCb, rpc, dbState, diff --git a/ydb/public/sdk/cpp/src/client/impl/ydb_internal/make_request/make.h b/ydb/public/sdk/cpp/src/client/impl/ydb_internal/make_request/make.h index d8dd35dbe6b7..742e6a8a53e0 100644 --- a/ydb/public/sdk/cpp/src/client/impl/ydb_internal/make_request/make.h +++ b/ydb/public/sdk/cpp/src/client/impl/ydb_internal/make_request/make.h @@ -46,4 +46,18 @@ TProtoRequest MakeOperationRequest(const TRequestSettings& settings) { return request; } + +template +TProtoRequest* MakeRequestOnArena(google::protobuf::Arena* arena) { + return google::protobuf::Arena::CreateMessage(arena); +} + +template +TProtoRequest* MakeOperationRequestOnArena(const TRequestSettings& settings, google::protobuf::Arena* arena) { + Y_ASSERT(arena != nullptr); + auto request = MakeRequestOnArena(arena); + FillOperationParams(settings, *request); + return request; +} + } // namespace NYdb diff --git a/ydb/public/sdk/cpp/src/client/table/impl/table_client.cpp b/ydb/public/sdk/cpp/src/client/table/impl/table_client.cpp index 163c7a477ee8..160c3cf0b225 100644 --- a/ydb/public/sdk/cpp/src/client/table/impl/table_client.cpp +++ b/ydb/public/sdk/cpp/src/client/table/impl/table_client.cpp @@ -990,33 +990,49 @@ void TTableClient::TImpl::SetStatCollector(const NSdkStats::TStatCollector::TCli } TAsyncBulkUpsertResult TTableClient::TImpl::BulkUpsert(const std::string& table, TValue&& rows, const TBulkUpsertSettings& settings, bool canMove) { - auto request = MakeOperationRequest(settings); - request.set_table(TStringType{table}); + Ydb::Table::BulkUpsertRequest* request = nullptr; + std::unique_ptr holder; + + if (settings.Arena_) { + request = MakeOperationRequestOnArena(settings, settings.Arena_); + } else { + holder = std::make_unique(MakeOperationRequest(settings)); + request = holder.get(); + } + + request->set_table(TStringType{table}); if (canMove) { - request.mutable_rows()->mutable_type()->Swap(&rows.GetType().GetProto()); - request.mutable_rows()->mutable_value()->Swap(&rows.GetProto()); + request->mutable_rows()->mutable_type()->Swap(&rows.GetType().GetProto()); + request->mutable_rows()->mutable_value()->Swap(&rows.GetProto()); } else { - *request.mutable_rows()->mutable_type() = TProtoAccessor::GetProto(rows.GetType()); - *request.mutable_rows()->mutable_value() = rows.GetProto(); + *request->mutable_rows()->mutable_type() = TProtoAccessor::GetProto(rows.GetType()); + *request->mutable_rows()->mutable_value() = rows.GetProto(); } auto promise = NewPromise(); + auto extractor = [promise](google::protobuf::Any* any, TPlainStatus status) mutable { + Y_UNUSED(any); + TBulkUpsertResult val(TStatus(std::move(status))); + promise.SetValue(std::move(val)); + }; - auto extractor = [promise] - (google::protobuf::Any* any, TPlainStatus status) mutable { - Y_UNUSED(any); - TBulkUpsertResult val(TStatus(std::move(status))); - promise.SetValue(std::move(val)); - }; - - Connections_->RunDeferred( - std::move(request), - extractor, - &Ydb::Table::V1::TableService::Stub::AsyncBulkUpsert, - DbDriverState_, - INITIAL_DEFERRED_CALL_DELAY, - TRpcRequestSettings::Make(settings)); - + if (settings.Arena_) { + Connections_->RunDeferred( + request, + extractor, + &Ydb::Table::V1::TableService::Stub::AsyncBulkUpsert, + DbDriverState_, + INITIAL_DEFERRED_CALL_DELAY, + TRpcRequestSettings::Make(settings)); + } else { + Connections_->RunDeferred( + std::move(*holder), + extractor, + &Ydb::Table::V1::TableService::Stub::AsyncBulkUpsert, + DbDriverState_, + INITIAL_DEFERRED_CALL_DELAY, + TRpcRequestSettings::Make(settings)); + } return promise.GetFuture(); } diff --git a/ydb/public/sdk/cpp/src/client/value/value.cpp b/ydb/public/sdk/cpp/src/client/value/value.cpp index 2516f3a42264..47ae325b35ff 100644 --- a/ydb/public/sdk/cpp/src/client/value/value.cpp +++ b/ydb/public/sdk/cpp/src/client/value/value.cpp @@ -1046,14 +1046,31 @@ class TValue::TImpl { public: TImpl(const TType& type, const Ydb::Value& valueProto) : Type_(type) - , ProtoValue_(valueProto) {} + , ProtoValue_(valueProto) + , ArenaAllocatedValueProto_(nullptr) {} TImpl(const TType& type, Ydb::Value&& valueProto) : Type_(type) - , ProtoValue_(std::move(valueProto)) {} + , ProtoValue_(std::move(valueProto)) + , ArenaAllocatedValueProto_(nullptr) {} + + TImpl(const TType& type, Ydb::Value* arenaAllocatedValueProto) + : Type_(type) + , ProtoValue_{} + , ArenaAllocatedValueProto_(arenaAllocatedValueProto) {} + + const Ydb::Value& GetProto() const { + return ArenaAllocatedValueProto_ ? *ArenaAllocatedValueProto_ : ProtoValue_; + } + + Ydb::Value& GetProto() { + return ArenaAllocatedValueProto_ ? *ArenaAllocatedValueProto_ : ProtoValue_; + } TType Type_; +private: Ydb::Value ProtoValue_; + Ydb::Value* ArenaAllocatedValueProto_; }; //////////////////////////////////////////////////////////////////////////////// @@ -1064,6 +1081,9 @@ TValue::TValue(const TType& type, const Ydb::Value& valueProto) TValue::TValue(const TType& type, Ydb::Value&& valueProto) : Impl_(new TImpl(type, std::move(valueProto))) {} +TValue::TValue(const TType& type, Ydb::Value* arenaAllocatedValueProto) + : Impl_(new TImpl(type, arenaAllocatedValueProto)) {} + const TType& TValue::GetType() const { return Impl_->Type_; } @@ -1073,11 +1093,11 @@ TType & TValue::GetType() { } const Ydb::Value& TValue::GetProto() const { - return Impl_->ProtoValue_; + return Impl_->GetProto(); } Ydb::Value& TValue::GetProto() { - return Impl_->ProtoValue_; + return Impl_->GetProto(); } //////////////////////////////////////////////////////////////////////////////// @@ -1104,7 +1124,7 @@ class TValueParser::TImpl { : Value_(value.Impl_) , TypeParser_(value.GetType()) { - Reset(Value_->ProtoValue_); + Reset(Value_->GetProto()); } TImpl(const TType& type) @@ -2781,7 +2801,6 @@ class TValueBuilderImpl { } private: - //TTypeBuilder TypeBuilder_; TTypeBuilder::TImpl TypeBuilder_; Ydb::Value ProtoValue_;