Skip to content

[Backport] Add OnBeforeRequest callback (#8541) #8543

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Elastic.Transport" Version="0.8.0" />
<PackageReference Include="Elastic.Transport" Version="0.8.1" />
<PackageReference Include="PolySharp" Version="1.15.0">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,9 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text.Json;
using System.Threading.Tasks;
using System.Threading;
using Elastic.Transport;
using Elastic.Transport.Diagnostics;

using Elastic.Clients.Elasticsearch.Requests;

Expand All @@ -28,18 +25,23 @@ public partial class ElasticsearchClient
private const string OpenTelemetrySchemaVersion = "https://opentelemetry.io/schemas/1.21.0";

private readonly ITransport<IElasticsearchClientSettings> _transport;
internal static ConditionalWeakTable<JsonSerializerOptions, IElasticsearchClientSettings> SettingsTable { get; } = new();

/// <summary>
/// Creates a client configured to connect to http://localhost:9200.
/// </summary>
public ElasticsearchClient() : this(new ElasticsearchClientSettings(new Uri("http://localhost:9200"))) { }
public ElasticsearchClient() :
this(new ElasticsearchClientSettings(new Uri("http://localhost:9200")))
{
}

/// <summary>
/// Creates a client configured to connect to a node reachable at the provided <paramref name="uri" />.
/// </summary>
/// <param name="uri">The <see cref="Uri" /> to connect to.</param>
public ElasticsearchClient(Uri uri) : this(new ElasticsearchClientSettings(uri)) { }
public ElasticsearchClient(Uri uri) :
this(new ElasticsearchClientSettings(uri))
{
}

/// <summary>
/// Creates a client configured to communicate with Elastic Cloud using the provided <paramref name="cloudId" />.
Expand All @@ -51,8 +53,8 @@ public ElasticsearchClient(Uri uri) : this(new ElasticsearchClientSettings(uri))
/// </summary>
/// <param name="cloudId">The Cloud ID of an Elastic Cloud deployment.</param>
/// <param name="credentials">The credentials to use for the connection.</param>
public ElasticsearchClient(string cloudId, AuthorizationHeader credentials) : this(
new ElasticsearchClientSettings(cloudId, credentials))
public ElasticsearchClient(string cloudId, AuthorizationHeader credentials) :
this(new ElasticsearchClientSettings(cloudId, credentials))
{
}

Expand All @@ -69,8 +71,7 @@ internal ElasticsearchClient(ITransport<IElasticsearchClientSettings> transport)
{
transport.ThrowIfNull(nameof(transport));
transport.Configuration.ThrowIfNull(nameof(transport.Configuration));
transport.Configuration.RequestResponseSerializer.ThrowIfNull(
nameof(transport.Configuration.RequestResponseSerializer));
transport.Configuration.RequestResponseSerializer.ThrowIfNull(nameof(transport.Configuration.RequestResponseSerializer));
transport.Configuration.Inferrer.ThrowIfNull(nameof(transport.Configuration.Inferrer));

_transport = transport;
Expand All @@ -96,47 +97,38 @@ private enum ProductCheckStatus

private partial void SetupNamespaces();

internal TResponse DoRequest<TRequest, TResponse, TRequestParameters>(TRequest request)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new() =>
DoRequest<TRequest, TResponse, TRequestParameters>(request, null);

internal TResponse DoRequest<TRequest, TResponse, TRequestParameters>(
TRequest request,
Action<IRequestConfiguration>? forceConfiguration)
TRequest request)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
=> DoRequestCoreAsync<TRequest, TResponse, TRequestParameters>(false, request, forceConfiguration).EnsureCompleted();

internal Task<TResponse> DoRequestAsync<TRequest, TResponse, TRequestParameters>(
TRequest request,
CancellationToken cancellationToken = default)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
=> DoRequestAsync<TRequest, TResponse, TRequestParameters>(request, null, cancellationToken);
{
return DoRequestCoreAsync<TRequest, TResponse, TRequestParameters>(false, request).EnsureCompleted();
}

internal Task<TResponse> DoRequestAsync<TRequest, TResponse, TRequestParameters>(
TRequest request,
Action<IRequestConfiguration>? forceConfiguration,
CancellationToken cancellationToken = default)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
=> DoRequestCoreAsync<TRequest, TResponse, TRequestParameters>(true, request, forceConfiguration, cancellationToken).AsTask();
{
return DoRequestCoreAsync<TRequest, TResponse, TRequestParameters>(true, request, cancellationToken).AsTask();
}

private ValueTask<TResponse> DoRequestCoreAsync<TRequest, TResponse, TRequestParameters>(
bool isAsync,
TRequest request,
Action<IRequestConfiguration>? forceConfiguration,
CancellationToken cancellationToken = default)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
{
// The product check modifies request parameters and therefore must not be executed concurrently.
if (request is null)
{
throw new ArgumentNullException(nameof(request));
}

// We use a lockless CAS approach to make sure that only a single product check request is executed at a time.
// We do not guarantee that the product check is always performed on the first request.

Expand All @@ -157,12 +149,12 @@ private ValueTask<TResponse> DoRequestCoreAsync<TRequest, TResponse, TRequestPar

ValueTask<TResponse> SendRequest()
{
var (endpointPath, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request);
var openTelemetryDataMutator = GetOpenTelemetryDataMutator<TRequest, TRequestParameters>(request, resolvedRouteValues);
PrepareRequest<TRequest, TRequestParameters>(request, out var endpointPath, out var postData, out var requestConfiguration, out var routeValues);
var openTelemetryDataMutator = GetOpenTelemetryDataMutator<TRequest, TRequestParameters>(request, routeValues);

return isAsync
? new ValueTask<TResponse>(_transport.RequestAsync<TResponse>(endpointPath, postData, openTelemetryDataMutator, request.RequestConfiguration, cancellationToken))
: new ValueTask<TResponse>(_transport.Request<TResponse>(endpointPath, postData, openTelemetryDataMutator, request.RequestConfiguration));
? new ValueTask<TResponse>(_transport.RequestAsync<TResponse>(endpointPath, postData, openTelemetryDataMutator, requestConfiguration, cancellationToken))
: new ValueTask<TResponse>(_transport.Request<TResponse>(endpointPath, postData, openTelemetryDataMutator, requestConfiguration));
}

async ValueTask<TResponse> SendRequestWithProductCheck()
Expand All @@ -178,34 +170,35 @@ async ValueTask<TResponse> SendRequestWithProductCheck()
// 32-bit read/write operations are atomic and due to the initial memory barrier, we can be sure that
// no other thread executes the product check at the same time. Locked access is not required here.
if (_productCheckStatus is (int)ProductCheckStatus.InProgress)
{
_productCheckStatus = (int)ProductCheckStatus.NotChecked;
}

throw;
}
}

async ValueTask<TResponse> SendRequestWithProductCheckCore()
{
PrepareRequest<TRequest, TRequestParameters>(request, out var endpointPath, out var postData, out var requestConfiguration, out var routeValues);
var openTelemetryDataMutator = GetOpenTelemetryDataMutator<TRequest, TRequestParameters>(request, routeValues);

// Attach product check header

// TODO: The copy constructor should accept null values
var requestConfig = (request.RequestConfiguration is null)
? new RequestConfiguration()
var requestConfig = (requestConfiguration is null)
? new RequestConfiguration
{
ResponseHeadersToParse = new HeadersList("x-elastic-product")
}
: new RequestConfiguration(request.RequestConfiguration)
: new RequestConfiguration(requestConfiguration)
{
ResponseHeadersToParse = (request.RequestConfiguration.ResponseHeadersToParse is { Count: > 0 })
? new HeadersList(request.RequestConfiguration.ResponseHeadersToParse, "x-elastic-product")
ResponseHeadersToParse = (requestConfiguration.ResponseHeadersToParse is { Count: > 0 })
? new HeadersList(requestConfiguration.ResponseHeadersToParse, "x-elastic-product")
: new HeadersList("x-elastic-product")
};

// Send request

var (endpointPath, resolvedRouteValues, postData) = PrepareRequest<TRequest, TRequestParameters>(request);
var openTelemetryDataMutator = GetOpenTelemetryDataMutator<TRequest, TRequestParameters>(request, resolvedRouteValues);

TResponse response;

if (isAsync)
Expand Down Expand Up @@ -239,7 +232,9 @@ async ValueTask<TResponse> SendRequestWithProductCheckCore()
: (int)ProductCheckStatus.Failed;

if (_productCheckStatus == (int)ProductCheckStatus.Failed)
{
throw new UnsupportedProductException(UnsupportedProductException.InvalidProductError);
}

return response;
}
Expand All @@ -249,15 +244,17 @@ async ValueTask<TResponse> SendRequestWithProductCheckCore()
where TRequest : Request<TRequestParameters>
where TRequestParameters : RequestParameters, new()
{
// If there are no subscribed listeners, we avoid some work and allocations
// If there are no subscribed listeners, we avoid some work and allocations.
if (!Elastic.Transport.Diagnostics.OpenTelemetry.ElasticTransportActivitySourceHasListeners)
{
return null;
}

return OpenTelemetryDataMutator;

void OpenTelemetryDataMutator(Activity activity)
{
// We fall back to a general operation name in cases where the derived request fails to override the property
// We fall back to a general operation name in cases where the derived request fails to override the property.
var operationName = !string.IsNullOrEmpty(request.OperationName) ? request.OperationName : request.HttpMethod.GetStringValue();

// TODO: Optimisation: We should consider caching these, either for cases where resolvedRouteValues is null, or
Expand All @@ -267,7 +264,7 @@ void OpenTelemetryDataMutator(Activity activity)
// The latter may bloat the cache as some combinations of path parts may rarely re-occur.

activity.DisplayName = operationName;

activity.SetTag(OpenTelemetry.SemanticConventions.DbOperation, !string.IsNullOrEmpty(request.OperationName) ? request.OperationName : "unknown");
activity.SetTag($"{OpenTelemetrySpanAttributePrefix}schema_url", OpenTelemetrySchemaVersion);

Expand All @@ -282,21 +279,26 @@ void OpenTelemetryDataMutator(Activity activity)
}
}

private (EndpointPath endpointPath, Dictionary<string, string>? resolvedRouteValues, PostData data) PrepareRequest<TRequest, TRequestParameters>(TRequest request)
private void PrepareRequest<TRequest, TRequestParameters>(
TRequest request,
out EndpointPath endpointPath,
out PostData? postData,
out IRequestConfiguration? requestConfiguration,
out Dictionary<string, string>? routeValues)
where TRequest : Request<TRequestParameters>
where TRequestParameters : RequestParameters, new()
{
request.ThrowIfNull(nameof(request), "A request is required.");

var (resolvedUrl, _, routeValues) = request.GetUrl(ElasticsearchClientSettings);
var (resolvedUrl, _, resolvedRouteValues) = request.GetUrl(ElasticsearchClientSettings);
var pathAndQuery = request.RequestParameters.CreatePathWithQueryStrings(resolvedUrl, ElasticsearchClientSettings);

var postData =
request.HttpMethod == HttpMethod.GET ||
request.HttpMethod == HttpMethod.HEAD || !request.SupportsBody
routeValues = resolvedRouteValues;
endpointPath = new EndpointPath(request.HttpMethod, pathAndQuery);
postData =
request.HttpMethod is HttpMethod.GET or HttpMethod.HEAD || !request.SupportsBody
? null
: PostData.Serializable(request);

return (new EndpointPath(request.HttpMethod, pathAndQuery), routeValues, postData);
requestConfiguration = request.RequestConfiguration;
ElasticsearchClientSettings.OnBeforeRequest?.Invoke(this, request, endpointPath, ref postData, ref requestConfiguration);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,64 +5,55 @@
using System;
using System.Threading;
using System.Threading.Tasks;

using Elastic.Clients.Elasticsearch.Requests;
using Elastic.Transport;
using Elastic.Transport.Products.Elasticsearch;

namespace Elastic.Clients.Elasticsearch;

public abstract class NamespacedClientProxy
{
private const string InvalidOperation = "The client has not been initialised for proper usage as may have been partially mocked. Ensure you are using a " +
private const string InvalidOperation =
"The client has not been initialised for proper usage as may have been partially mocked. Ensure you are using a " +
"new instance of ElasticsearchClient to perform requests over a network to Elasticsearch.";

protected ElasticsearchClient Client { get; }

/// <summary>
/// Initializes a new instance for mocking.
/// </summary>
protected NamespacedClientProxy() { }
protected NamespacedClientProxy()
{
}

internal NamespacedClientProxy(ElasticsearchClient client) => Client = client;

internal TResponse DoRequest<TRequest, TResponse, TRequestParameters>(TRequest request)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
=> DoRequest<TRequest, TResponse, TRequestParameters>(request, null);

internal TResponse DoRequest<TRequest, TResponse, TRequestParameters>(
TRequest request,
Action<IRequestConfiguration>? forceConfiguration)
TRequest request)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
{
if (Client is null)
ThrowHelper.ThrowInvalidOperationException(InvalidOperation);
{
throw new InvalidOperationException(InvalidOperation);
}

return Client.DoRequest<TRequest, TResponse, TRequestParameters>(request, forceConfiguration);
return Client.DoRequest<TRequest, TResponse, TRequestParameters>(request);
}

internal Task<TResponse> DoRequestAsync<TRequest, TResponse, TRequestParameters>(
TRequest request,
CancellationToken cancellationToken = default)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
=> DoRequestAsync<TRequest, TResponse, TRequestParameters>(request, null, cancellationToken);

internal Task<TResponse> DoRequestAsync<TRequest, TResponse, TRequestParameters>(
TRequest request,
Action<IRequestConfiguration>? forceConfiguration,
CancellationToken cancellationToken = default)
where TRequest : Request<TRequestParameters>
where TResponse : TransportResponse, new()
where TRequestParameters : RequestParameters, new()
{
if (Client is null)
ThrowHelper.ThrowInvalidOperationException(InvalidOperation);
{
throw new InvalidOperationException(InvalidOperation);
}

return Client.DoRequestAsync<TRequest, TResponse, TRequestParameters>(request, forceConfiguration, cancellationToken);
return Client.DoRequestAsync<TRequest, TResponse, TRequestParameters>(request, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Runtime.InteropServices;

using Elastic.Clients.Elasticsearch.Esql;

using Elastic.Clients.Elasticsearch.Requests;
using Elastic.Clients.Elasticsearch.Serialization;

using Elastic.Transport;
Expand Down Expand Up @@ -110,6 +111,7 @@ public abstract class ElasticsearchClientSettingsBase<TConnectionSettings> :
private readonly FluentDictionary<MemberInfo, PropertyMapping> _propertyMappings = new();
private readonly FluentDictionary<Type, string> _routeProperties = new();
private readonly Serializer _sourceSerializer;
private BeforeRequestEvent? _onBeforeRequest;
private bool _experimentalEnableSerializeNullInferredValues;
private ExperimentalSettings _experimentalSettings = new();

Expand Down Expand Up @@ -158,7 +160,7 @@ protected ElasticsearchClientSettingsBase(

FluentDictionary<Type, string> IElasticsearchClientSettings.RouteProperties => _routeProperties;
Serializer IElasticsearchClientSettings.SourceSerializer => _sourceSerializer;

BeforeRequestEvent? IElasticsearchClientSettings.OnBeforeRequest => _onBeforeRequest;
ExperimentalSettings IElasticsearchClientSettings.Experimental => _experimentalSettings;

bool IElasticsearchClientSettings.ExperimentalEnableSerializeNullInferredValues => _experimentalEnableSerializeNullInferredValues;
Expand Down Expand Up @@ -322,6 +324,20 @@ public TConnectionSettings DefaultMappingFor(IEnumerable<ClrTypeMapping> typeMap

return (TConnectionSettings)this;
}

/// <inheritdoc cref="IElasticsearchClientSettings.OnBeforeRequest"/>
public TConnectionSettings OnBeforeRequest(BeforeRequestEvent handler)
{
return Assign(handler, static (a, v) => a._onBeforeRequest += v ?? DefaultBeforeRequestHandler);
}

private static void DefaultBeforeRequestHandler(ElasticsearchClient client,
Request request,
EndpointPath endpointPath,
ref PostData? postData,
ref IRequestConfiguration? requestConfiguration)
{
}
}

/// <inheritdoc cref="TransportClientConfigurationValues" />
Expand Down
Loading
Loading