diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 2d3fd78fbd..334529b6dc 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -73,6 +73,7 @@ "seekable", "servicebus", "stylesheet", + "subclient", "telemetered", "typespec", "undelete", @@ -208,4 +209,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/Cargo.lock b/Cargo.lock index ece05b121b..15ef7291eb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -160,6 +160,7 @@ version = "0.27.0" dependencies = [ "async-lock", "async-trait", + "azure_core_macros", "azure_core_test", "azure_identity", "azure_security_keyvault_certificates", @@ -203,18 +204,36 @@ dependencies = [ "typespec_macros", ] +[[package]] +name = "azure_core_macros" +version = "0.1.0" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "tokio", + "tracing", + "tracing-subscriber", + "typespec_client_core", +] + [[package]] name = "azure_core_opentelemetry" version = "0.1.0" dependencies = [ "azure_core", - "log", + "azure_core_test", + "azure_core_test_macros", + "azure_identity", "opentelemetry", + "opentelemetry-http", "opentelemetry_sdk", + "reqwest", "tokio", "tracing", "tracing-subscriber", "typespec_client_core", + "url", ] [[package]] @@ -1718,6 +1737,18 @@ dependencies = [ "tracing", ] +[[package]] +name = "opentelemetry-http" +version = "0.30.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50f6639e842a97dbea8886e3439710ae463120091e2e064518ba8e716e6ac36d" +dependencies = [ + "async-trait", + "bytes", + "http", + "opentelemetry", +] + [[package]] name = "opentelemetry_sdk" version = "0.30.0" diff --git a/Cargo.toml b/Cargo.toml index 7b4a2b7d91..19258d493e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "sdk/typespec/typespec_macros", "sdk/core/azure_core", "sdk/core/azure_core_amqp", + "sdk/core/azure_core_macros", "sdk/core/azure_core_test", "sdk/core/azure_core_test_macros", "sdk/core/azure_core_opentelemetry", @@ -49,6 +50,11 @@ path = "sdk/typespec/typespec_macros" version = "0.27.0" path = "sdk/core/azure_core" +[workspace.dependencies.azure_core_macros] +version = "0.1.0" +path = "sdk/core/azure_core_macros" + + [workspace.dependencies.azure_core_amqp] version = "0.6.0" path = "sdk/core/azure_core_amqp" diff --git a/doc/distributed-tracing-for-rust-service-clients.md b/doc/distributed-tracing-for-rust-service-clients.md new file mode 100644 index 0000000000..80093faa4d --- /dev/null +++ b/doc/distributed-tracing-for-rust-service-clients.md @@ -0,0 +1,414 @@ + + + +# Distributed tracing options in Rust service clients + +## Distributed tracing fundamentals + +There are three core constructs which are used in distributed tracing: + +* Tracer Providers +* Tracers +* Spans + +### Tracer Provider + +The job of a "Tracer Provider" is to be a factory for tracers. It is the "gateway" construct for distributed tracing. + +### Tracer + +A "tracer" is a factory for "Spans". A `Tracer` is configured with three parameters: + +* `namespace` - the "namespace" for the service client. The namespace for all azure services are listed [on this page](https://learn.microsoft.com/azure/azure-resource-manager/management/azure-services-resource-providers). +* `package name` - this is typically the Cargo package name for the service client (`env!("CARGO_PKG_NAME")`) +* `package version` - this is typically the version of the Cargo package for the service client (`env!("CARGO_PKG_VERSION")`) +* `Schema Url` - this is typically the OpenTelemetry schema version - if not provided, a default schema version is used. + +#### Note + +Custom Schema Url support is not currently implemented. + +Tracers have two mechanisms for creating spans: + +* Create a new child span from a parent span. +* Create a new child span from the "current" span (where the "current" span is tracer implementation specific). + +### Span + +A "Span" is a unit of tracing. Each span has the following attributes: + +* "name" - the "name" of the span. For public APIs, this is typically the name of the public API, for HTTP request, it is typically the HTTP verb. +* "kind" - HTTP spans come in several flavours: + * Internal - the "default" for span kinds. + * Client - represents a client application - HTTP request spans are "Client" spans. + * Server - represents a server - Azure SDK clients will never use these. + * Producer - represents a messaging (Event Hubs and Service Bus) message producer. + * Consumer - represents a message consumer. +* "status" - A span status is either "Unset" or "Error" - OpenTelemetry defines a status of "Ok" in addition to these, but it is reserved for client applications. +* "attributes" - the attributes on a span describe the span. Attributes include: + * "az.namespace" - the namespace of a request. + * "url.full" - the full (sanitized) URL for an HTTP request + * "server.address" - the DNS address of the HTTP server + * "http.request.method" - the HTTP method used for the request ("GET", "PUT" etc). + +## Azure Distributed Tracing requirements + +Azure's distributed tracing requirements are laid out in a number of documents: + +* [Azure Distributed Tracing Conventions](https://github.com/Azure/azure-sdk/blob/main/docs/tracing/distributed-tracing-conventions.md) +* [Azure Distributed Tracing Implementation](https://github.com/Azure/azure-sdk/blob/main/docs/general/implementation.md#distributed-tracing) +* [Open Telemetry HTTP Span Conventions](https://opentelemetry.io/docs/specs/semconv/http/http-spans/) + +Looking at each document, the following two requirements for distributed tracing clients fall out: + +1) Each public API (service client function) needs to have a span with the `az.namespace` attribute added - the az.attribute (as defined above). [See this for more information](https://github.com/Azure/azure-sdk/blob/main/docs/tracing/distributed-tracing-conventions.md#public-api-calls). +1) Each HTTP request needs to have a span with the same `az.namespace` attribute and a number of other attributes derived from the HTTP operation. [See this for more information](https://github.com/Azure/azure-sdk/blob/main/docs/tracing/distributed-tracing-conventions.md#http-client-spans). The HTTP request span should be a child of a public API span if possible. + +Implementations are allowed to skip adding the `az.namespace` attribute but it is strongly discouraged. + +It turns out that in OpenTelemetry, an `OpenTelemetry::Tracer` is constructed with an `InstrumentationScope` which allows arbitrary attributes to be attached to the tracer, which is also attached to each span constructed from the tracer. As such, it makes sense for each service client to have a `Tracer` attached to the service client, and this `Tracer` can be used to hold the namespace attribute. This architecture is reflected in the distributed tracing wrapper API, the `Tracer` trait contains a `namespace()` function. + +## Additional requirements + +For public APIs, the rule of thumb is: "If the operation does not take time and cannot fail, it doesn't get a span". For most public APIs, this isn't a huge deal, but for pageable and long running operations, it changes how the code is generated. Specifically, for pageables, the actual service client does not actually interact with the network and cannot fail, but the individual pager returned does interact with the network and can fail - thus the pager's interactions with the service need to be instrumented with a span. Long Running Operations behave similarly - while the original API is instrumented with a span, the same is true for the status monitor - it also needs to be instrumented with a span whose name matches the name of the original API. + +In addition, [certain service clients](https://github.com/Azure/azure-sdk/blob/main/docs/tracing/distributed-tracing-conventions.md#library-specific-attributes) (Cosmos DB, KeyVault, etc) have additional client-specific attributes which need to be added to the span. + +## Core API design + +Given this architecture, it implies that each service client needs the following: + +1) A struct field named `tracer` which is an `Arc` which holds the tracing implementation specific tracer. +2) Code in the service client's `new` function which instantiates a `tracer` from the `TracerProvider` configured in the service client options. The primary function of this tracer is to provide the value for the `az.namespace` attribute for both the public API spans and the HTTP request spans. +3) Code in each service client public API which instantiates a public API span. + +For the Rust implementation, if a tracer provider is configured, ALL http operations will have HTTP request spans created regardless of whether the public API spans are created. + +## Implementation details + +To provide for requirement #1, if a customer provides a value for the `azure_core::ClientOptions::request_instrumentation` structure, the Azure Core HTTP pipeline will add a `PublicApiInstrumentationPolicy` to the pipeline which is responsible for creating the public API outer span. + +To provide for requirement #2, if a customer provides a `azure_core::ClientOptions::request_instrumentation` the `azure_core` HTTP pipeline will add a `RequestInstrumentationPolicy` to the pipeline which is responsible for creating the required HTTP request span to the pipeline. + +This implementation means that operations like Long Running Operations (Pollers) and Pageable Operations (Pagers) will all have a Public API span created by the `PublicApiInstrumentationPolicy` and a HTTP Request span created by the `RequestInstrumentationPolicy`. + +### Pipeline Construction + +When an `azure_core::http::Pipeline` is constructed, if the client options include a tracing provider, then the pipeline will create a tracer from that tracer provider with the crate name and crate version (which are both parameters to the pipeline constructor). This tracer will have a namespace of "None" and acts as a fallback in case the public APIs don't provide a `Tracer` implementation (if, for example public APIs are instrumented, but the service client itself is not instrumented). This tracer will be passed to both of the tracing policies. + +### PublicApiInstrumentationPolicy + +1) If the pipeline context has a `Arc` attached to the context, then the public API policy will simply call the next policy in the pipeline, because a span in the pipeline indicates that this API call is a nested API call. +1) If the context does not have a `PublicApiInstrumentationInformation` attached to it, the policy will call the next policy in the pipeline, otherwise the policy will: + 1) Look for an `Arc` attached to the context. If one is found, it uses that tracer, otherwise it uses a tracer attached to the policy. + 1) Create a span with a name matching the `name` in the [`PublicApiInstrumentationInformation`] structure and attributes from the attributes attached to the `PublicApiInstrumentationInformation`. It will also add the `az.namespace` attribute to the span if the tracer has a namespace associated with it (this will typically only be the case for tracers attached to the context). + 1) Once the span has been created, the policy will attach the newly created span to the context so other policies in the pipeline (specifically the `RequestInstrumentationPolicy` can use this span). +1) Once the span has been created, the policy calls the next policy in the pipeline. +1) After the remaining policies in the pipeline have run, the policy inspects the `Result` of the pipeline execution and sets the `error.type` attribute and the span status based on the `Result` of the operation. + +### RequestInstrumentationPolicy + +The `RequestInstrumentationPolicy` will do the following: + +1) If the `Context` parameter for the `RequestInstrumentationPolicy` contains a `Tracer` value, then the `RequestInstrumentationPolicy` will use that `Tracer` value to create the span, otherwise it will use the pre-configured tracer from when the policy was created. +2) If the `Context` parameter for the `RequestInstrumentationPolicy` contains a `Span` value, then the policy will use that span as the parent span for the newly created HTTP request span, otherwise it will create a new span. + +This design means that even if a service public API is not fully instrumented with a `Tracer` or a `Span`, it will still generate some HTTP request traces. + +Since the namespace attribute is service-client wide, it makes sense to capture that inside a per-service client field, that way it can be easily accessed from service clients. + +## Convenience Macros + +To facilitate the implementation of the three core requirements above, three attribute-like macros are defined for the use of each service client. + +NOTE: These attributes are only for client library development and are not intended for external customers use - they depend heavily on code which follows the Rust API design guidelines. + +Those macros are: + +* `#[tracing::client]` - applied to each service client `struct` declaration. +* `#[tracing::new]` - applied to each service client "constructor" function. +* `#[tracing::function]` - applied to each service client "public API". +* `#[tracing::subclient]` - applied to a subclient "constructor" function. + +### `#[tracing::client]` + +The `tracing::client` attribute macro does one thing and one thing only: It defines a field named `tracer` which is added to the list of fields in the service client structure. + +#### Modification introduced by this macro + +From: + +```rust +pub struct MyServiceClient { + endpoint: Url, +} +``` + +to + +```diff +pub struct MyServiceClient { + endpoint: Url, ++ tracer: std::sync::Arc, +} +``` + +Arguably this attribute is unnecessary, but it can be incredibly helpful especially if we need to add more elements to each service client in the future. + +### `#[tracing::new()]` + +Annotates a `new` service client function to initialize the `tracer` field in the structure. + +#### Modification introduced by this macro + +from: + +```rust +pub fn new( + endpoint: &str, + credential: Arc, + options: Option, +) -> Result { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + let auth_policy: Arc = Arc::new(BearerTokenCredentialPolicy::new( + credential, + vec!["https://vault.azure.net/.default"], + )); + Ok(Self { + endpoint, + api_version: options.api_version, + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + vec![auth_policy], + ), + }) +} +``` + +to: + +```diff +pub fn new( + endpoint: &str, + credential: Arc, + options: Option +) -> Result { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + let auth_policy: Arc = Arc::new(BearerTokenCredentialPolicy::new( + credential, + vec!["https://vault.azure.net/.default"], + )); ++ let tracer = ++ if let Some(tracer_options) = &options.client_options.request_instrumentation { ++ tracer_options ++ .tracer_provider ++ .as_ref() ++ .map(|tracer_provider| { ++ tracer_provider.get_tracer( ++ Some(#client_namespace), ++ option_env!("CARGO_PKG_NAME").unwrap_or("UNKNOWN"), ++ option_env!("CARGO_PKG_VERSION").unwrap_or("UNKNOWN"), ++ ) ++ }) ++ } else { ++ None ++ }; + Ok(Self { ++ tracer, + endpoint, + api_version: options.api_version, + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + vec![auth_policy], + ), + }) +} +``` + +Note that if the service client uses the `builder` pattern, it cannot use the `tracing::new` attribute. + +### `#[tracing::function(.)]` + +Applied to all public functions in the service client ("public APIs" in distributed tracing terms). This macro creates the client span for each service client method, and updates the client span if appropriate. + +Note that the `` and `` should be the values from the client typespec - that way the public API span names align for all client languages. + +#### Modification introduced by this macro + +From: + +```rust +pub async fn get( + &self, + path: &str, + options: Option>, +) -> Result { + let options = options.unwrap_or_default(); + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + + let mut request = Request::new(url, azure_core::http::Method::Get); + + let response = self + .pipeline + .send(&options.method_options.context, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()))); + } + Ok(response) +} +``` + +To: + +```diff +pub async fn get( + &self, + path: &str, + options: Option>, +) -> Result { ++ let options = { ++ let mut options = options.unwrap_or_default(); ++ let public_api_info = azure_core::tracing::PublicApiInstrumentationInformation { ++ api_name: "TestFunction", ++ attributes: Vec::new(), ++ }; ++ let mut ctx = options.method_options.context.with_value(public_api_info); ++ if let Some(tracer) = &self.tracer { ++ ctx = ctx.with_value(tracer.clone()); ++ } ++ options.method_options.context = ctx; ++ Some(options) ++ }; + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + + let mut request = Request::new(url, azure_core::http::Method::Get); + + let response = self + .pipeline +- .send(&options.method_options.context, &mut request) ++ .send(&ctx, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()))); + } + response +} +``` + +The `tracing::function` has a separate form which can be used for service clients which +implement per-service-client distributed tracing attributes. + +The parameters for the `tracing::function` roughly follow the following BNF: + +```bnf +tracing_parameters = quoted_string [ ',' '( attribute_list ')'] +quoted_string = `"` `"` +attribute_list = attribute | attribute [`,`] attribute_list +attribute = key '=' value +key = identifier | quoted_string +identifier = rust-identifier | rust-identifier '.' identifier +rust-identifier = +value = +``` + +This means that the following are valid parameters for `tracing::function`: + +* `#[tracing::function("MyServiceClient.MyApi")]` - specifies a public API name. +* `#[tracing::function("Name", (az.namespace="namespace"))]` - specifies a public API name and creates a span with an attribute named "az.namespace" and a value of "namespace". +* `#[tracing::function("Name", (api_count=23, "my_attribute" = "Abc"))]` - specifies a public API name and creates a span with two attributes, one named "api_count" with a value of "23" and the other with the name "my_attribute" and a value of "Abc" +* `#[tracing::function("Name", ("API path"=path))]` - specifies a public API name and creates a span with an attribute named "API path" and the value of the parameter named "path". + +This allows a generator to pass in simple attribute annotations to the public API spans created by the pipeline. + +## Special cases + +For the most part, the three tracing attribute macros provide functionality that should allow most generated and wrapped clients to create all the required distributed tracing span attributes. + +However there are some cases where having additional control over the traces is needed. This functionality is primarily intended for wrapped service clients to handle span attributes which cannot be easily determined from the operation. + +### Service Client needs to add attributes *before* the pipeline + +If your service client needs to define attributes in the client span before the pipeline and the attributes cannot be determined by reflecting the contents of service parameters, then the service client can create its own `PublicApiInstrumentationInformation` structure and add the desired attributes manually. If this `PublicApiInstrumentationInformation` is added to the request Context, it will be reflected in the spans created by the `PublicApiInstrumentationPolicy`. + +### Service Client needs to add attributes before and after the pipeline + +For some operations, a service client cannot easily express the information needed to generate the span (or needs to add attributes based on the response to the public API). In that case, the service client should create its own span. + +The `PublicApiInstrumentationPolicy` struct has a convenience method `create_public_api_span` which allows a service client to create a span based on the current context. The function signature for `create_public_api_span` is `create_public_api_span(ctx: &Context, tracer: Option>) -> Option>`. It assumes that the `Context`in `ctx` has already had a `PublicApiInstrumentationInformation` attribute added and returns an optional span - if the span has the value of Some, it is a tracer which can be used for the public API, if it has the value of None, then there is no public API span created (this can happen if there is no `PublicApiInstrumentationInformation` provided, or if the `Context` already contains a public API span). + +The client can then add whatever attributes to the span it needs, and after the pipeline has run, can add any additional attributes to the span. + +Note that in this model, the client is responsible for ending the span. + +### Service implementations with "subclients" + +Service clients can sometimes contain "subclients" - clients which have their own pipelines and endpoint which contain subclient specific functionality. + +Such subclients often have an accessor function to create a new subclient instance which looks like this: + +```rust + +pub fn get_operation_templates_lro_client(&self) -> OperationTemplatesLroClient { + OperationTemplatesLroClient { + api_version: self.api_version.clone(), + endpoint: self.endpoint.clone(), + pipeline: self.pipeline.clone(), + subscription_id: self.subscription_id.clone(), + } +} +``` + +To support subclient instantiation, the `azure_core` crate defines an attribute macro named `tracing::subclient` to support subclient instantiation. + +```rust +#[tracing::subclient] +pub fn get_operation_templates_lro_client(&self) -> OperationTemplatesLroClient { + OperationTemplatesLroClient { + api_version: self.api_version.clone(), + endpoint: self.endpoint.clone(), + pipeline: self.pipeline.clone(), + subscription_id: self.subscription_id.clone(), + } +} +``` + +This adds a clone of the parent client's `tracer` to the subclient - it functions similarly to `#[tracing::new]` but for subclient instantiation. diff --git a/sdk/core/azure_core/Cargo.toml b/sdk/core/azure_core/Cargo.toml index 4a9cc21e6a..ae635176e9 100644 --- a/sdk/core/azure_core/Cargo.toml +++ b/sdk/core/azure_core/Cargo.toml @@ -16,6 +16,7 @@ rust-version.workspace = true [dependencies] async-lock.workspace = true async-trait.workspace = true +azure_core_macros.workspace = true bytes.workspace = true futures.workspace = true hmac = { workspace = true, optional = true } @@ -37,6 +38,7 @@ typespec_client_core = { workspace = true, features = [ rustc_version.workspace = true [dev-dependencies] +azure_core_macros.path = "../azure_core_macros" azure_core_test.workspace = true azure_identity.workspace = true azure_security_keyvault_certificates.path = "../../keyvault/azure_security_keyvault_certificates" diff --git a/sdk/core/azure_core/src/http/options/mod.rs b/sdk/core/azure_core/src/http/options/mod.rs index a9593c3563..e337fa1873 100644 --- a/sdk/core/azure_core/src/http/options/mod.rs +++ b/sdk/core/azure_core/src/http/options/mod.rs @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +mod request_instrumentation; mod user_agent; +pub use request_instrumentation::*; use std::sync::Arc; use typespec_client_core::http::policies::Policy; pub use typespec_client_core::http::{ @@ -27,6 +29,17 @@ pub struct ClientOptions { /// User-Agent telemetry options. pub user_agent: Option, + + /// Options for request instrumentation, such as distributed tracing. + /// + /// If not specified, defaults to no instrumentation. + /// + pub request_instrumentation: Option, +} + +pub(crate) struct CoreClientOptions { + pub(crate) user_agent: UserAgentOptions, + pub(crate) request_instrumentation: RequestInstrumentationOptions, } impl ClientOptions { @@ -35,7 +48,7 @@ impl ClientOptions { /// If instead we implemented [`Into`], we'd have to clone Azure-specific options instead of moving memory of [`Some`] values. pub(in crate::http) fn deconstruct( self, - ) -> (UserAgentOptions, typespec_client_core::http::ClientOptions) { + ) -> (CoreClientOptions, typespec_client_core::http::ClientOptions) { let options = typespec_client_core::http::ClientOptions { per_call_policies: self.per_call_policies, per_try_policies: self.per_try_policies, @@ -43,6 +56,12 @@ impl ClientOptions { transport: self.transport, }; - (self.user_agent.unwrap_or_default(), options) + ( + CoreClientOptions { + user_agent: self.user_agent.unwrap_or_default(), + request_instrumentation: self.request_instrumentation.unwrap_or_default(), + }, + options, + ) } } diff --git a/sdk/core/azure_core/src/http/options/request_instrumentation.rs b/sdk/core/azure_core/src/http/options/request_instrumentation.rs new file mode 100644 index 0000000000..0bd253224c --- /dev/null +++ b/sdk/core/azure_core/src/http/options/request_instrumentation.rs @@ -0,0 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use std::sync::Arc; + +/// Policy options to enable distributed tracing. +#[derive(Clone, Debug, Default)] +pub struct RequestInstrumentationOptions { + /// Set the tracer provider for distributed tracing. + pub tracer_provider: Option>, +} diff --git a/sdk/core/azure_core/src/http/pipeline.rs b/sdk/core/azure_core/src/http/pipeline.rs index ce69465631..6a1c777068 100644 --- a/sdk/core/azure_core/src/http/pipeline.rs +++ b/sdk/core/azure_core/src/http/pipeline.rs @@ -3,7 +3,9 @@ use super::policies::ClientRequestIdPolicy; use crate::http::{ - policies::{Policy, UserAgentPolicy}, + policies::{ + Policy, PublicApiInstrumentationPolicy, RequestInstrumentationPolicy, UserAgentPolicy, + }, ClientOptions, }; use std::{ @@ -48,12 +50,55 @@ impl Pipeline { per_call_policies: Vec>, per_try_policies: Vec>, ) -> Self { + let (core_client_options, options) = options.deconstruct(); + + let install_instrumentation_policies = core_client_options + .request_instrumentation + .tracer_provider + .is_some(); + + // Create a fallback tracer if no tracer provider is set. + // This is useful for service clients that have not yet been instrumented. + let tracer = if install_instrumentation_policies { + core_client_options + .request_instrumentation + .tracer_provider + .as_ref() + .map(|tracer_provider| { + tracer_provider.get_tracer(None, crate_name.unwrap_or("Unknown"), crate_version) + }) + } else { + None + }; + let mut per_call_policies = per_call_policies.clone(); push_unique(&mut per_call_policies, ClientRequestIdPolicy::default()); - - let (user_agent, options) = options.deconstruct(); - let telemetry_policy = UserAgentPolicy::new(crate_name, crate_version, &user_agent); - push_unique(&mut per_call_policies, telemetry_policy); + if install_instrumentation_policies { + let public_api_policy = PublicApiInstrumentationPolicy::new(tracer.clone()); + push_unique(&mut per_call_policies, public_api_policy); + } + + let user_agent_policy = + UserAgentPolicy::new(crate_name, crate_version, &core_client_options.user_agent); + push_unique(&mut per_call_policies, user_agent_policy); + + let mut per_try_policies = per_try_policies.clone(); + if install_instrumentation_policies { + // Note that the choice to use "None" as the namespace here + // is intentional. + // The `azure_namespace` parameter is used to populate the `az.namespace` + // span attribute, however that information is only known by the author of the + // client library, not the core library. + // It is also *not* a constant that can be derived from the crate information - + // it is a value that is determined from the list of resource providers + // listed [here](https://learn.microsoft.com/azure/azure-resource-manager/management/azure-services-resource-providers). + // + // This information can only come from the package owner. It doesn't make sense + // to burden all users of the azure_core pipeline with determining this + // information, so we use `None` here. + let request_instrumentation_policy = RequestInstrumentationPolicy::new(tracer); + push_unique(&mut per_try_policies, request_instrumentation_policy); + } Self(http::Pipeline::new( options, diff --git a/sdk/core/azure_core/src/http/policies/instrumentation/mod.rs b/sdk/core/azure_core/src/http/policies/instrumentation/mod.rs new file mode 100644 index 0000000000..2cd1f71672 --- /dev/null +++ b/sdk/core/azure_core/src/http/policies/instrumentation/mod.rs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +//! Instrumentation pipeline policies. + +mod public_api_instrumentation; +mod request_instrumentation; + +// Distributed tracing span attribute names. Defined in +// [OpenTelemetrySpans](https://github.com/open-telemetry/semantic-conventions/blob/main/docs/http/http-spans.md) +// and [Azure conventions for open telemetry spans](https://github.com/Azure/azure-sdk/blob/main/docs/tracing/distributed-tracing-conventions.md) +const AZ_NAMESPACE_ATTRIBUTE: &str = "az.namespace"; +const AZ_CLIENT_REQUEST_ID_ATTRIBUTE: &str = "az.client_request_id"; +const ERROR_TYPE_ATTRIBUTE: &str = "error.type"; +const AZ_SERVICE_REQUEST_ID_ATTRIBUTE: &str = "az.service_request.id"; +const HTTP_REQUEST_RESEND_COUNT_ATTRIBUTE: &str = "http.request.resend_count"; +const HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE: &str = "http.response.status_code"; +const HTTP_REQUEST_METHOD_ATTRIBUTE: &str = "http.request.method"; +const SERVER_ADDRESS_ATTRIBUTE: &str = "server.address"; +const SERVER_PORT_ATTRIBUTE: &str = "server.port"; +const URL_FULL_ATTRIBUTE: &str = "url.full"; + +pub(crate) use public_api_instrumentation::PublicApiInstrumentationPolicy; +pub use public_api_instrumentation::{create_public_api_span, PublicApiInstrumentationInformation}; +pub(crate) use request_instrumentation::*; diff --git a/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs b/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs new file mode 100644 index 0000000000..8a36d28a79 --- /dev/null +++ b/sdk/core/azure_core/src/http/policies/instrumentation/public_api_instrumentation.rs @@ -0,0 +1,699 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use super::{AZ_NAMESPACE_ATTRIBUTE, ERROR_TYPE_ATTRIBUTE}; +use crate::{ + http::{Context, Request}, + tracing::{Span, SpanKind, Tracer}, +}; +use ::tracing::trace; +use std::sync::Arc; +use typespec_client_core::{ + fmt::SafeDebug, + http::policies::{Policy, PolicyResult}, + tracing::Attribute, +}; + +/// Information about the public API being instrumented. +/// +/// This struct is used to pass information about the public API being instrumented +/// to the `PublicApiInstrumentationPolicy`. +/// +/// It contains the name of the API, which is used to create a span for distributed tracing +/// and any additional per-API attributes that might be needed for instrumentation. +/// +/// If the `PublicApiInstrumentationPolicy` policy detects a `PublicApiInstrumentationInformation` in the context, +/// it will create a span with the API name and any additional attributes. +#[derive(SafeDebug, Clone)] +pub struct PublicApiInstrumentationInformation { + /// The name of the API being instrumented. + /// + /// The API name should be in the form of `.`, where + /// `` is the name of the service client and `` is the name of the API. + /// + /// For example, if the service client is `MyClient` and the API is `my_api`, + /// the API name should be `MyClient.my_api`. + #[safe(true)] + pub api_name: &'static str, + + /// Additional attributes to be added to the span for this API. + /// + /// These attributes can provide additional information about the API being instrumented. + /// See [Library-specific attributes](https://github.com/Azure/azure-sdk/blob/main/docs/tracing/distributed-tracing-conventions.md#library-specific-attributes) + /// for more information. + /// + pub attributes: Vec, +} + +/// Sets distributed tracing information for HTTP requests. +#[derive(Clone, Debug)] +pub(crate) struct PublicApiInstrumentationPolicy { + tracer: Option>, +} + +impl PublicApiInstrumentationPolicy { + /// Creates a new `PublicApiInstrumentationPolicy`. + /// + /// + /// # Returns + /// A new instance of `PublicApiInstrumentationPolicy`. + /// + /// # Note + /// This policy will only create a tracer if a tracing provider is provided in the options. + /// + /// This policy will create a tracer that can be used to instrument HTTP requests. + /// However this tracer is only used when the client method is NOT instrumented. + /// A part of the client method instrumentation sets a client-specific tracer into the + /// request `[Context]` which will be used instead of the tracer from this policy. + /// + pub fn new(tracer: Option>) -> Self { + Self { tracer } + } +} + +/// Creates a span for the public API instrumentation policy. +/// +/// This function creates a span for the public API instrumentation policy based on the +/// public API information in the context. +/// +/// If no PublicApiInstrumentationInformation is provided, then this function will look in the `Context` +/// for a `PublicApiInstrumentationInformation` value, if it is not present, it will return `None`. +/// +/// # Arguments +/// - `ctx`: The context containing the public API information. +/// - `tracer`: An optional tracer to use for creating the span. +/// - `public_api_instrumentation`: Optional public API instrumentation information. +/// +/// # Returns +/// An optional span if the public API information is present and a tracer is available. +/// +/// If the context already has a span, it will return `None` to avoid nested spans. +/// If the context does not have a tracer it will use the value of the `tracer` argument. +/// If no tracer can be determined, it will return `None`. +/// +pub fn create_public_api_span( + ctx: &Context, + tracer: Option>, + public_api_instrumentation: Option, +) -> Option> { + // If there is a span in the context, we're a nested call, so we just want to forward the request. + if ctx.value::>().is_some() { + trace!( + "PublicApiPolicy: Nested call detected, forwarding request without instrumentation." + ); + return None; + } + + // We next confirm if the context has public API instrumentation information. + // Without a public API information, we skip instrumentation. + let info = public_api_instrumentation + .or_else(|| ctx.value::().cloned())?; + + // Get the tracer from either the context or the policy. + let tracer = match ctx.value::>() { + Some(t) => t.clone(), + None => tracer?, + }; + + // We now have public API information and a tracer. + // Calculate the span attributes based on the public API information and + // tracer. + let mut span_attributes = info + .attributes + .iter() + .map(|attr| { + // Convert the attribute to a span attribute. + Attribute { + key: attr.key.clone(), + value: attr.value.clone(), + } + }) + .collect::>(); + + if let Some(namespace) = tracer.namespace() { + // If the tracer has a namespace, we set it as an attribute. + span_attributes.push(Attribute { + key: AZ_NAMESPACE_ATTRIBUTE.into(), + value: namespace.into(), + }); + } + + // Create a span with the public API information and attributes. + Some(tracer.start_span(info.api_name, SpanKind::Internal, span_attributes)) +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl Policy for PublicApiInstrumentationPolicy { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + let Some(span) = create_public_api_span(ctx, self.tracer.clone(), None) else { + return next[0].send(ctx, request, &next[1..]).await; + }; + + // Now add the span to the context, so that it can be used by the next policies. + let ctx = ctx.clone().with_value(span.clone()); + + let result = next[0].send(&ctx, request, &next[1..]).await; + + // Don't bother setting attributes if the span isn't recording. + if span.is_recording() { + match &result { + Err(e) => { + // If the request failed, we set the error type on the span. + match e.kind() { + crate::error::ErrorKind::HttpResponse { status, .. } => { + span.set_attribute(ERROR_TYPE_ATTRIBUTE, status.to_string().into()); + + // 5xx status codes SHOULD set status to Error. + // The description should not be set because it can be inferred from "http.response.status_code". + if status.is_server_error() { + span.set_status(crate::tracing::SpanStatus::Error { + description: "".to_string(), + }); + } + } + _ => { + span.set_attribute(ERROR_TYPE_ATTRIBUTE, e.kind().to_string().into()); + span.set_status(crate::tracing::SpanStatus::Error { + description: e.kind().to_string(), + }); + } + } + } + Ok(response) => { + // 5xx status codes SHOULD set status to Error. + // The description should not be set because it can be inferred from "http.response.status_code". + if response.status().is_server_error() { + span.set_status(crate::tracing::SpanStatus::Error { + description: "".to_string(), + }); + } + if response.status().is_client_error() || response.status().is_server_error() { + span.set_attribute( + ERROR_TYPE_ATTRIBUTE, + response.status().to_string().into(), + ); + } + } + } + } + span.end(); + result + } +} + +#[cfg(test)] +mod tests { + // cspell: ignore traceparent + use super::*; + use crate::{ + http::{ + headers::Headers, + policies::{create_public_api_span, RequestInstrumentationPolicy, TransportPolicy}, + Method, RawResponse, StatusCode, TransportOptions, + }, + tracing::{SpanStatus, TracerProvider}, + Result, + }; + use azure_core_test::{ + http::MockHttpClient, + tracing::{ + check_instrumentation_result, ExpectedSpanInformation, ExpectedTracerInformation, + MockTracingProvider, + }, + }; + use futures::future::BoxFuture; + use std::sync::Arc; + + // Test just the public API instrumentation policy without request instrumentation. + async fn run_public_api_instrumentation_test( + api_information: Option, + create_tracer: bool, + add_tracer_to_context: bool, + request: &mut Request, + callback: C, + ) -> Arc + where + C: FnMut(&Request) -> BoxFuture<'_, Result> + Send + Sync + 'static, + { + // Add the public API information and tracer to the context so that it can be used by the policy. + let mock_tracer_provider = Arc::new(MockTracingProvider::new()); + + let tracer = if create_tracer { + Some(mock_tracer_provider.get_tracer( + add_tracer_to_context.then_some("test namespace"), + "test_crate", + Some("1.0.0"), + )) + } else { + None + }; + + let public_api_policy = { + let policy_tracer = tracer.clone(); + Arc::new(PublicApiInstrumentationPolicy::new(policy_tracer)) + }; + + let transport = TransportPolicy::new(TransportOptions::new(Arc::new(MockHttpClient::new( + callback, + )))); + + let next: Vec> = vec![Arc::new(transport)]; + + let mut ctx = Context::default(); + if let Some(t) = tracer { + if add_tracer_to_context { + // If we have a tracer, add it to the context. + ctx = ctx.with_value(t.clone()); + } + } + + if api_information.is_some() { + // If we have public API information, add it to the context. + ctx = ctx.with_value(api_information.unwrap()); + } + let _result = public_api_policy.send(&ctx, request, &next).await; + + mock_tracer_provider + } + + async fn run_public_api_instrumentation_test_with_request_instrumentation( + api_name: Option<&'static str>, + namespace: Option<&'static str>, + crate_name: Option<&'static str>, + version: Option<&'static str>, + request: &mut Request, + callback: C, + ) -> Arc + where + C: FnMut(&Request) -> BoxFuture<'_, Result> + Send + Sync + 'static, + { + let mock_tracer_provider = Arc::new(MockTracingProvider::new()); + let mock_tracer = + mock_tracer_provider.get_tracer(namespace, crate_name.unwrap_or("unknown"), version); + + let public_api_policy = Arc::new(PublicApiInstrumentationPolicy::new(Some( + mock_tracer.clone(), + ))); + + let transport = TransportPolicy::new(TransportOptions::new(Arc::new(MockHttpClient::new( + callback, + )))); + + let request_instrumentation_policy = + RequestInstrumentationPolicy::new(Some(mock_tracer.clone())); + + let next: Vec> = vec![ + Arc::new(request_instrumentation_policy), + Arc::new(transport), + ]; + let public_api_information = PublicApiInstrumentationInformation { + api_name: api_name.unwrap_or("unknown"), + attributes: vec![Attribute { + key: "az.fake_attribute".into(), + value: "attribute value".into(), + }], + }; + + // Add the public API information and tracer to the context so that it can be used by the policy. + let ctx = Context::default() + .with_value(public_api_information) + .with_value(mock_tracer.clone()); + let _result = public_api_policy.send(&ctx, request, &next).await; + + mock_tracer_provider + } + + // Tests for the create_public_api_span function. + #[test] + fn create_public_api_span_tests() { + let tracer = + Arc::new(MockTracingProvider::new()).get_tracer(Some("test"), "test", Some("1.0.0")); + + // Test when context has no PublicApiInstrumentationInformation + { + let ctx = Context::default(); + let span = create_public_api_span(&ctx, Some(tracer.clone()), None); + assert!(span.is_none(), "Should return None when no API info exists"); + } + } + + // Test when context already has a span + #[test] + fn create_public_api_span_tests_context_has_span() { + let tracer = + Arc::new(MockTracingProvider::new()).get_tracer(Some("test"), "test", Some("1.0.0")); + { + let existing_span = tracer.start_span("existing", SpanKind::Internal, vec![]); + let ctx = Context::default().with_value(existing_span.clone()); + let span = create_public_api_span(&ctx, Some(tracer.clone()), None); + assert!( + span.is_none(), + "Should return None when context already has a span" + ); + } + } + + // Tests for the create_public_api_span function. + #[test] + fn create_public_api_span_tests_public_api_information_from_param() { + let tracer = + Arc::new(MockTracingProvider::new()).get_tracer(Some("test"), "test", Some("1.0.0")); + + // Test when context has no PublicApiInstrumentationInformation + { + let ctx = Context::default(); + let span = create_public_api_span( + &ctx, + Some(tracer.clone()), + Some(PublicApiInstrumentationInformation { + api_name: "TestClient.test_api", + attributes: vec![], + }), + ); + assert!( + span.is_some(), + "Should return Some when info exists as param" + ); + } + } + + // Test with API info but no tracer + #[test] + fn create_public_api_span_tests_public_api_info_no_tracer() { + { + let api_info = PublicApiInstrumentationInformation { + api_name: "TestClient.test_api", + attributes: vec![], + }; + let ctx = Context::default().with_value(api_info); + let span = create_public_api_span(&ctx, None, None); + assert!( + span.is_none(), + "Should return None when no tracer is available" + ); + } + } + // Test with API info and tracer from context + #[test] + fn create_public_api_span_tests_api_info_and_tracer_from_context() { + let tracer = + Arc::new(MockTracingProvider::new()).get_tracer(Some("test"), "test", Some("1.0.0")); + { + let api_info = PublicApiInstrumentationInformation { + api_name: "TestClient.test_api", + attributes: vec![], + }; + let ctx = Context::default() + .with_value(api_info) + .with_value(tracer.clone()); + let span = create_public_api_span(&ctx, None, None); + assert!( + span.is_some(), + "Should create span when API info and tracer are available" + ); + } + } + // Test with API info, tracer from parameter, and attributes + #[test] + fn create_public_api_span_tests_tracer_from_parameter() { + let tracer = + Arc::new(MockTracingProvider::new()).get_tracer(Some("test"), "test", Some("1.0.0")); + { + let api_info = PublicApiInstrumentationInformation { + api_name: "TestClient.test_api", + attributes: vec![Attribute { + key: "test.attribute".into(), + value: "test_value".into(), + }], + }; + let ctx = Context::default().with_value(api_info); + let span = create_public_api_span(&ctx, Some(tracer.clone()), None); + assert!(span.is_some(), "Should create span with attributes"); + } + } + + #[tokio::test] + async fn public_api_instrumentation_no_public_api_info() { + let url = "http://example.com/path"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer = run_public_api_instrumentation_test( + None, // No public API information. + true, // Create tracer. + true, + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer, + vec![ExpectedTracerInformation { + name: "test_crate", + version: Some("1.0.0"), + namespace: Some("test namespace"), + spans: vec![], + }], + ); + } + + #[tokio::test] + async fn public_api_instrumentation_no_tracer() { + let url = "http://example.com/path"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer = run_public_api_instrumentation_test( + Some(PublicApiInstrumentationInformation { + api_name: "MyClient.MyApi", + attributes: vec![], + }), + false, // Create tracer. + false, // Add tracer to context. + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + // No tracer should be created, so we expect no spans. + check_instrumentation_result(mock_tracer, vec![]); + } + + #[tokio::test] + async fn public_api_instrumentation_tracer_not_in_context() { + let url = "http://example.com/path"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer = run_public_api_instrumentation_test( + Some(PublicApiInstrumentationInformation { + api_name: "MyClient.MyApi", + attributes: vec![], + }), + true, // Create tracer. + false, // Add tracer to context. + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer, + vec![ExpectedTracerInformation { + name: "test_crate", + version: Some("1.0.0"), + namespace: None, + spans: vec![ExpectedSpanInformation { + span_name: "MyClient.MyApi", + status: SpanStatus::Unset, + kind: SpanKind::Internal, + attributes: vec![], + }], + }], + ) + } + + #[tokio::test] + async fn simple_public_api_instrumentation_policy() { + let url = "http://example.com/path"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer = run_public_api_instrumentation_test( + Some(PublicApiInstrumentationInformation { + api_name: "MyClient.MyApi", + attributes: vec![], + }), + true, // Create tracer. + true, + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer, + vec![ExpectedTracerInformation { + name: "test_crate", + version: Some("1.0.0"), + namespace: Some("test namespace"), + spans: vec![ExpectedSpanInformation { + span_name: "MyClient.MyApi", + status: SpanStatus::Unset, + kind: SpanKind::Internal, + attributes: vec![(AZ_NAMESPACE_ATTRIBUTE, "test namespace".into())], + }], + }], + ); + } + + #[tokio::test] + async fn public_api_instrumentation_policy_with_error() { + let url = "http://example.com/path"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer = run_public_api_instrumentation_test( + Some(PublicApiInstrumentationInformation { + api_name: "MyClient.MyApi", + attributes: vec![], + }), + true, + true, + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::InternalServerError, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer.clone(), + vec![ExpectedTracerInformation { + name: "test_crate", + version: Some("1.0.0"), + namespace: Some("test namespace"), + spans: vec![ExpectedSpanInformation { + span_name: "MyClient.MyApi", + status: SpanStatus::Error { + description: "".to_string(), + }, + kind: SpanKind::Internal, + attributes: vec![ + (AZ_NAMESPACE_ATTRIBUTE, "test namespace".into()), + (ERROR_TYPE_ATTRIBUTE, "500".into()), + ], + }], + }], + ); + } + + #[tokio::test] + async fn public_api_instrumentation_policy_with_request_instrumentation() { + let url = "http://example.com/path_with_request"; + let mut request = Request::new(url.parse().unwrap(), Method::Put); + + let mock_tracer = run_public_api_instrumentation_test_with_request_instrumentation( + Some("MyClient.MyApi"), + Some("test.namespace"), + Some("test_crate"), + Some("1.0.0"), + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Put); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer.clone(), + vec![ExpectedTracerInformation { + name: "test_crate", + version: Some("1.0.0"), + namespace: Some("test.namespace"), + spans: vec![ + ExpectedSpanInformation { + span_name: "MyClient.MyApi", + status: SpanStatus::Unset, + kind: SpanKind::Internal, + attributes: vec![ + (AZ_NAMESPACE_ATTRIBUTE, "test.namespace".into()), + ("az.fake_attribute", "attribute value".into()), + ], + }, + ExpectedSpanInformation { + span_name: "PUT", + status: SpanStatus::Unset, + kind: SpanKind::Client, + attributes: vec![ + (AZ_NAMESPACE_ATTRIBUTE, "test.namespace".into()), + ("http.request.method", "PUT".into()), + ("url.full", "http://example.com/path_with_request".into()), + ("server.address", "example.com".into()), + ("server.port", 80.into()), + ("http.response.status_code", 200.into()), + ], + }, + ], + }], + ); + } +} diff --git a/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs b/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs new file mode 100644 index 0000000000..034b9a18d7 --- /dev/null +++ b/sdk/core/azure_core/src/http/policies/instrumentation/request_instrumentation.rs @@ -0,0 +1,477 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use super::{ + AZ_CLIENT_REQUEST_ID_ATTRIBUTE, AZ_NAMESPACE_ATTRIBUTE, AZ_SERVICE_REQUEST_ID_ATTRIBUTE, + ERROR_TYPE_ATTRIBUTE, HTTP_REQUEST_METHOD_ATTRIBUTE, HTTP_REQUEST_RESEND_COUNT_ATTRIBUTE, + HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, SERVER_ADDRESS_ATTRIBUTE, SERVER_PORT_ATTRIBUTE, + URL_FULL_ATTRIBUTE, +}; +use crate::{ + http::{headers, Context, Request}, + tracing::{Span, SpanKind}, +}; +use std::sync::Arc; +use typespec_client_core::{ + http::policies::{Policy, PolicyResult, RetryPolicyCount}, + tracing::Attribute, +}; + +/// Sets distributed tracing information for HTTP requests. +#[derive(Clone, Debug)] +pub(crate) struct RequestInstrumentationPolicy { + tracer: Option>, +} + +impl RequestInstrumentationPolicy { + /// Creates a new `RequestInstrumentationPolicy`. + /// + /// # Arguments + /// - `tracer`: Pre-configured tracer to use for instrumentation. + /// + /// # Returns + /// A new instance of `RequestInstrumentationPolicy`. + /// + /// # Note + /// + /// The tracer provided is a "fallback" tracer which is used if the `ctx` parameter + /// to the `send` method does not contain a tracer. + /// + pub fn new(tracer: Option>) -> Self { + Self { tracer } + } +} + +#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)] +impl Policy for RequestInstrumentationPolicy { + async fn send( + &self, + ctx: &Context, + request: &mut Request, + next: &[Arc], + ) -> PolicyResult { + // If the context has a tracer (which happens when called from an instrumented method), + // we prefer the tracer from the context. + // Otherwise, we use the tracer from the policy itself. + // This allows for flexibility in using different tracers in different contexts. + + // We use `.or_else` here instead of `.or` because `.or` eagerly evaluates the right-hand side, + // which can lead to unnecessary overhead if the tracer is not needed. + #[allow(clippy::unnecessary_lazy_evaluations)] + let tracer = ctx + .value::>() + .or_else(|| self.tracer.as_ref()); + + let Some(tracer) = tracer else { + return next[0].send(ctx, request, &next[1..]).await; + }; + + let mut span_attributes = vec![Attribute { + key: HTTP_REQUEST_METHOD_ATTRIBUTE.into(), + value: request.method().to_string().into(), + }]; + + if let Some(namespace) = tracer.namespace() { + // If the tracer has a namespace, we set it as an attribute. + span_attributes.push(Attribute { + key: AZ_NAMESPACE_ATTRIBUTE.into(), + value: namespace.into(), + }); + } + + // OpenTelemetry requires that we sanitize the URL if it contains a username or password. + // Since a valid Azure SDK endpoint should never contain a username or password, if + // the url contains a username or password, we simply omit the URL_FULL_ATTRIBUTE. + if request.url().username().is_empty() && request.url().password().is_none() { + span_attributes.push(Attribute { + key: URL_FULL_ATTRIBUTE.into(), + value: request.url().to_string().into(), + }); + } + + if let Some(host) = request.url().host() { + span_attributes.push(Attribute { + key: SERVER_ADDRESS_ATTRIBUTE.into(), + value: host.to_string().into(), + }); + } + if let Some(port) = request.url().port_or_known_default() { + span_attributes.push(Attribute { + key: SERVER_PORT_ATTRIBUTE.into(), + value: port.into(), + }); + } + // Get the method as a string to avoid lifetime issues + let method_str = request.method().as_str(); + let span = if let Some(parent_span) = ctx.value::>() { + // If a parent span exists, start a new span with the parent. + tracer.start_span_with_parent( + method_str, + SpanKind::Client, + span_attributes, + parent_span.clone(), + ) + } else { + // If no parent span exists, start a new span with the "current" span (if any). + // It is up to the tracer implementation to determine what "current" means. + tracer.start_span(method_str, SpanKind::Client, span_attributes) + }; + + if span.is_recording() { + if let Some(client_request_id) = request + .headers() + .get_optional_str(&headers::CLIENT_REQUEST_ID) + { + span.set_attribute(AZ_CLIENT_REQUEST_ID_ATTRIBUTE, client_request_id.into()); + } + + if let Some(service_request_id) = + request.headers().get_optional_str(&headers::REQUEST_ID) + { + span.set_attribute(AZ_SERVICE_REQUEST_ID_ATTRIBUTE, service_request_id.into()); + } + + if let Some(retry_count) = ctx.value::() { + if **retry_count > 0 { + span.set_attribute(HTTP_REQUEST_RESEND_COUNT_ATTRIBUTE, (**retry_count).into()); + } + } + } + + // Propagate the headers for distributed tracing into the request. + span.propagate_headers(request); + + let result = next[0].send(ctx, request, &next[1..]).await; + + if span.is_recording() { + if let Some(err) = result.as_ref().err() { + // If the request failed, set an error type attribute. + span.set_attribute(ERROR_TYPE_ATTRIBUTE, err.kind().to_string().into()); + } + if let Ok(response) = result.as_ref() { + // If the request was successful, set the HTTP response status code. + span.set_attribute( + HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, + u16::from(response.status()).into(), + ); + + if response.status().is_server_error() || response.status().is_client_error() { + // If the response status indicates an error, set the span status to error. + // Since the reason can be inferred from the status code, description is left empty. + span.set_status(crate::tracing::SpanStatus::Error { + description: "".to_string(), + }); + // Set the error type attribute for all HTTP 4XX or 5XX errors. + span.set_attribute(ERROR_TYPE_ATTRIBUTE, response.status().to_string().into()); + } + } + } + span.end(); + return result; + } +} +#[cfg(test)] +pub(crate) mod tests { + // cspell: ignore traceparent + use super::*; + use crate::{ + http::{ + headers::Headers, policies::TransportPolicy, Method, RawResponse, StatusCode, + TransportOptions, + }, + tracing::{AttributeValue, SpanStatus, TracerProvider}, + Result, + }; + use azure_core_test::{ + http::MockHttpClient, + tracing::{ + check_instrumentation_result, ExpectedSpanInformation, ExpectedTracerInformation, + MockTracingProvider, + }, + }; + use futures::future::BoxFuture; + use std::sync::Arc; + use typespec_client_core::http::headers::HeaderName; + + async fn run_instrumentation_test( + test_namespace: Option<&'static str>, + crate_name: Option<&'static str>, + version: Option<&'static str>, + request: &mut Request, + callback: C, + ) -> Arc + where + C: FnMut(&Request) -> BoxFuture<'_, Result> + Send + Sync + 'static, + { + let mock_tracer_provider = Arc::new(MockTracingProvider::new()); + let tracer = mock_tracer_provider.get_tracer( + test_namespace, + crate_name.unwrap_or("unknown"), + version, + ); + let policy = Arc::new(RequestInstrumentationPolicy::new(Some(tracer.clone()))); + + let transport = TransportPolicy::new(TransportOptions::new(Arc::new(MockHttpClient::new( + callback, + )))); + + let ctx = Context::default(); + let next: Vec> = vec![Arc::new(transport)]; + let _result = policy.send(&ctx, request, &next).await; + + mock_tracer_provider + } + + #[tokio::test] + async fn simple_instrumentation_policy() { + let url = "http://example.com/path"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer = run_instrumentation_test( + Some("test namespace"), + Some("test_crate"), + Some("1.0.0"), + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer, + vec![ExpectedTracerInformation { + namespace: Some("test namespace"), + name: "test_crate", + version: Some("1.0.0"), + spans: vec![ExpectedSpanInformation { + span_name: "GET", + status: SpanStatus::Unset, + kind: SpanKind::Client, + attributes: vec![ + ( + AZ_NAMESPACE_ATTRIBUTE, + AttributeValue::from("test namespace"), + ), + ( + HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, + AttributeValue::from(200), + ), + (HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("GET")), + ( + SERVER_ADDRESS_ATTRIBUTE, + AttributeValue::from("example.com"), + ), + (SERVER_PORT_ATTRIBUTE, AttributeValue::from(80)), + ( + URL_FULL_ATTRIBUTE, + AttributeValue::from("http://example.com/path"), + ), + ], + }], + }], + ); + } + + #[test] + fn test_request_instrumentation_policy_creation() { + let policy = RequestInstrumentationPolicy::new(None); + assert!(policy.tracer.is_none()); + + let mock_tracer_provider = Arc::new(MockTracingProvider::new()); + let tracer = + mock_tracer_provider.get_tracer(Some("test namespace"), "test_crate", Some("1.0.0")); + let policy_with_tracer = RequestInstrumentationPolicy::new(Some(tracer)); + assert!(policy_with_tracer.tracer.is_some()); + } + + #[test] + fn test_request_instrumentation_policy_without_tracer() { + let policy = RequestInstrumentationPolicy::new(None); + assert!(policy.tracer.is_none()); + } + + #[tokio::test] + async fn client_request_id() { + let url = "https://example.com/client_request_id"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + request.insert_header(headers::CLIENT_REQUEST_ID, "test-client-request-id"); + + let mock_tracer = run_instrumentation_test( + None, + Some("test_crate"), + Some("1.0.0"), + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("example.com")); + assert_eq!(req.method(), Method::Get); + assert_eq!( + req.headers() + .get_optional_str(&HeaderName::from_static("traceparent")), + Some("00---01") + ); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer, + vec![ExpectedTracerInformation { + namespace: None, + name: "test_crate", + version: Some("1.0.0"), + spans: vec![ExpectedSpanInformation { + span_name: "GET", + status: SpanStatus::Unset, + kind: SpanKind::Client, + attributes: vec![ + ( + AZ_CLIENT_REQUEST_ID_ATTRIBUTE, + AttributeValue::from("test-client-request-id"), + ), + ( + HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, + AttributeValue::from(200), + ), + (HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("GET")), + ( + SERVER_ADDRESS_ATTRIBUTE, + AttributeValue::from("example.com"), + ), + (SERVER_PORT_ATTRIBUTE, AttributeValue::from(443)), + ( + URL_FULL_ATTRIBUTE, + AttributeValue::from("https://example.com/client_request_id"), + ), + ], + }], + }], + ); + } + + #[tokio::test] + async fn test_url_with_password() { + let url = "https://user:password@host:8080/path?query=value#fragment"; + let mut request = Request::new(url.parse().unwrap(), Method::Get); + + let mock_tracer_provider = + run_instrumentation_test(None, None, None, &mut request, |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("host")); + assert_eq!(req.method(), Method::Get); + Ok(RawResponse::from_bytes( + StatusCode::Ok, + Headers::new(), + vec![], + )) + }) + }) + .await; + // Because the URL contains a username and password, we do not set the URL_FULL_ATTRIBUTE. + check_instrumentation_result( + mock_tracer_provider, + vec![ExpectedTracerInformation { + namespace: None, + name: "unknown", + version: None, + spans: vec![ExpectedSpanInformation { + span_name: "GET", + status: SpanStatus::Unset, + kind: SpanKind::Client, + attributes: vec![ + ( + HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, + AttributeValue::from(200), + ), + (HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("GET")), + (SERVER_ADDRESS_ATTRIBUTE, AttributeValue::from("host")), + (SERVER_PORT_ATTRIBUTE, AttributeValue::from(8080)), + ], + }], + }], + ); + } + + #[tokio::test] + async fn request_failed() { + let url = "https://microsoft.com/request_failed.htm"; + let mut request = Request::new(url.parse().unwrap(), Method::Put); + request.insert_header(headers::REQUEST_ID, "test-service-request-id"); + + let mock_tracer = run_instrumentation_test( + Some("test namespace"), + Some("test_crate"), + Some("1.0.0"), + &mut request, + |req| { + Box::pin(async move { + assert_eq!(req.url().host_str(), Some("microsoft.com")); + assert_eq!(req.method(), Method::Put); + Ok(RawResponse::from_bytes( + StatusCode::NotFound, + Headers::new(), + vec![], + )) + }) + }, + ) + .await; + + check_instrumentation_result( + mock_tracer, + vec![ExpectedTracerInformation { + namespace: Some("test namespace"), + name: "test_crate", + version: Some("1.0.0"), + spans: vec![ExpectedSpanInformation { + span_name: "PUT", + status: SpanStatus::Error { + description: "".to_string(), + }, + kind: SpanKind::Client, + attributes: vec![ + (ERROR_TYPE_ATTRIBUTE, AttributeValue::from("404")), + ( + AZ_SERVICE_REQUEST_ID_ATTRIBUTE, + AttributeValue::from("test-service-request-id"), + ), + ( + AZ_NAMESPACE_ATTRIBUTE, + AttributeValue::from("test namespace"), + ), + ( + HTTP_RESPONSE_STATUS_CODE_ATTRIBUTE, + AttributeValue::from(404), + ), + (HTTP_REQUEST_METHOD_ATTRIBUTE, AttributeValue::from("PUT")), + ( + SERVER_ADDRESS_ATTRIBUTE, + AttributeValue::from("microsoft.com"), + ), + (SERVER_PORT_ATTRIBUTE, AttributeValue::from(443)), + ( + URL_FULL_ATTRIBUTE, + AttributeValue::from("https://microsoft.com/request_failed.htm"), + ), + ], + }], + }], + ); + } +} diff --git a/sdk/core/azure_core/src/http/policies/mod.rs b/sdk/core/azure_core/src/http/policies/mod.rs index 81c8e3769a..8625b69946 100644 --- a/sdk/core/azure_core/src/http/policies/mod.rs +++ b/sdk/core/azure_core/src/http/policies/mod.rs @@ -5,9 +5,11 @@ mod bearer_token_policy; mod client_request_id; +mod instrumentation; mod user_agent; pub use bearer_token_policy::BearerTokenCredentialPolicy; pub use client_request_id::*; +pub use instrumentation::*; pub use typespec_client_core::http::policies::*; pub use user_agent::*; diff --git a/sdk/core/azure_core/src/lib.rs b/sdk/core/azure_core/src/lib.rs index 97182941cb..6d7f5fd2de 100644 --- a/sdk/core/azure_core/src/lib.rs +++ b/sdk/core/azure_core/src/lib.rs @@ -27,7 +27,10 @@ pub use typespec_client_core::{ fmt, json, sleep, stream, time, Bytes, Uuid, }; +/// Abstractions for distributed tracing and telemetry. pub mod tracing { + pub use crate::http::policies::PublicApiInstrumentationInformation; + pub use azure_core_macros::*; pub use typespec_client_core::tracing::*; } diff --git a/sdk/core/azure_core_macros/Cargo.toml b/sdk/core/azure_core_macros/Cargo.toml new file mode 100644 index 0000000000..382853ea21 --- /dev/null +++ b/sdk/core/azure_core_macros/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "azure_core_macros" +version = "0.1.0" +description = "Procedural macros for client libraries built on azure_core." +readme = "README.md" +authors.workspace = true +license.workspace = true +repository.workspace = true +homepage = "https://github.com/azure/azure-sdk-for-rust" +documentation = "https://docs.rs/azure_core" +keywords = ["azure", "cloud", "iot", "rest", "sdk"] +categories = ["development-tools"] +edition.workspace = true +rust-version.workspace = true + +[lib] +proc-macro = true + +[dependencies] +proc-macro2.workspace = true +quote.workspace = true +syn.workspace = true +typespec_client_core = { workspace = true, features = ["http", "json"] } +tracing.workspace = true + +[dev-dependencies] +tokio.workspace = true +tracing-subscriber = { workspace = true, features = ["env-filter", "fmt"] } diff --git a/sdk/core/azure_core_macros/README.md b/sdk/core/azure_core_macros/README.md new file mode 100644 index 0000000000..cbe67c0b26 --- /dev/null +++ b/sdk/core/azure_core_macros/README.md @@ -0,0 +1,3 @@ +# Azure client library macros + +Macros for client libraries built on `azure_core`. diff --git a/sdk/core/azure_core_macros/src/lib.rs b/sdk/core/azure_core_macros/src/lib.rs new file mode 100644 index 0000000000..b66104c637 --- /dev/null +++ b/sdk/core/azure_core_macros/src/lib.rs @@ -0,0 +1,56 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#![doc = include_str!("../README.md")] + +mod tracing; +mod tracing_client; +mod tracing_function; +mod tracing_new; +mod tracing_subclient; + +use proc_macro::TokenStream; + +/// Attribute client struct declarations to enable distributed tracing. +/// +/// To declare a client that will be traced, you should use the `#[tracing::client]` attribute +/// exported from azure_core. +/// +#[proc_macro_attribute] +pub fn client(attr: TokenStream, item: TokenStream) -> TokenStream { + tracing_client::parse_client(attr.into(), item.into()) + .map_or_else(|e| e.into_compile_error().into(), |v| v.into()) +} + +/// Attribute client struct instantiation to enable distributed tracing. +/// +/// To enable tracing for a client instantiation, you should use the `#[tracing::new]` attribute +/// exported from azure_core. +/// +/// This macro will automatically instrument the client instantiation with tracing information. +/// It will also ensure that the client is created with the necessary tracing context. +/// +/// The `#[tracing::new]` attribute takes a single argument, which is a string +/// representing the Azure Namespace name for the service being traced. +/// +/// The list of Azure Namespaces can be found [on this page](https://learn.microsoft.com/azure/azure-resource-manager/management/azure-services-resource-providers) +/// +#[proc_macro_attribute] +pub fn new(attr: TokenStream, item: TokenStream) -> TokenStream { + tracing_new::parse_new(attr.into(), item.into()) + .map_or_else(|e| e.into_compile_error().into(), |v| v.into()) +} + +#[proc_macro_attribute] +pub fn subclient(attr: TokenStream, item: TokenStream) -> TokenStream { + tracing_subclient::parse_subclient(attr.into(), item.into()) + .map_or_else(|e| e.into_compile_error().into(), |v| v.into()) +} + +/// Attribute client public APIs to enable distributed tracing. +/// +#[proc_macro_attribute] +pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream { + tracing_function::parse_function(attr.into(), item.into()) + .map_or_else(|e| e.into_compile_error().into(), |v| v.into()) +} diff --git a/sdk/core/azure_core_macros/src/tracing.rs b/sdk/core/azure_core_macros/src/tracing.rs new file mode 100644 index 0000000000..8f07bd4c76 --- /dev/null +++ b/sdk/core/azure_core_macros/src/tracing.rs @@ -0,0 +1,77 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#[cfg(test)] +pub(crate) mod tests { + use ::tracing::{error, trace}; + use proc_macro2::{TokenStream, TokenTree}; + static INIT_LOGGING: std::sync::Once = std::sync::Once::new(); + + pub(crate) fn setup_tracing() { + INIT_LOGGING.call_once(|| { + println!("Setting up test logger..."); + + use tracing_subscriber::{fmt::format::FmtSpan, EnvFilter}; + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) + .with_ansi(std::env::var("NO_COLOR").map_or(true, |v| v.is_empty())) + .with_writer(std::io::stderr) + .init(); + }); + } + + // cspell: ignore punct + + pub(crate) fn compare_token_tree(token: &TokenTree, expected_token: &TokenTree) -> bool { + match (token, expected_token) { + (TokenTree::Group(group), TokenTree::Group(expected_group)) => { + compare_token_stream(group.stream(), expected_group.stream()) + } + + (TokenTree::Ident(ident), TokenTree::Ident(expected_ident)) => { + *expected_ident == *ident + } + (TokenTree::Punct(punct), TokenTree::Punct(expected_punct)) => { + punct.as_char() == expected_punct.as_char() + } + (TokenTree::Literal(literal), TokenTree::Literal(expected_literal)) => { + literal.to_string() == expected_literal.to_string() + } + _ => { + error!("Unexpected token: {expected_token:?}"); + false + } + } + } + + pub(crate) fn compare_token_stream(actual: TokenStream, expected: TokenStream) -> bool { + let actual_tokens = Vec::from_iter(actual); + let expected_tokens = Vec::from_iter(expected); + + if actual_tokens.len() != expected_tokens.len() { + error!( + "Token lengths do not match: actual: {} != expected: {}", + actual_tokens.len(), + expected_tokens.len() + ); + for (i, actual) in actual_tokens.iter().enumerate() { + trace!("Actual token at index {i}: {actual:?}"); + } + + for (i, expected) in expected_tokens.iter().enumerate() { + trace!("Expected token at index {i}: {expected:?}"); + } + return false; + } + + for (actual, expected) in actual_tokens.iter().zip(expected_tokens.iter()) { + let equal = compare_token_tree(actual, expected); + if !equal { + error!("Tokens do not match: {actual:?} != {expected:?}"); + return false; + } + } + true + } +} diff --git a/sdk/core/azure_core_macros/src/tracing_client.rs b/sdk/core/azure_core_macros/src/tracing_client.rs new file mode 100644 index 0000000000..c401d16419 --- /dev/null +++ b/sdk/core/azure_core_macros/src/tracing_client.rs @@ -0,0 +1,99 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use proc_macro2::TokenStream; +use quote::quote; +use syn::{spanned::Spanned, ItemStruct, Result}; + +const INVALID_SERVICE_CLIENT_MESSAGE: &str = + "client attribute must be applied to a public struct with no generic type parameters"; + +/// Parse the token stream for an Azure Service client declaration. +/// +/// An Azure Service client is a public struct that represents a client for an Azure service. +/// +/// This macro will ensure that the struct is public and has a `tracer` field of type `Option`. +/// +pub fn parse_client(_attr: TokenStream, item: TokenStream) -> Result { + if !is_client_declaration(&item) { + return Err(syn::Error::new(item.span(), INVALID_SERVICE_CLIENT_MESSAGE)); + } + + let ItemStruct { + vis, ident, fields, .. + } = syn::parse2(item.clone())?; + + let fields = fields.iter(); + Ok(quote! { + #vis + struct #ident { + #(#fields),*, + pub(crate) tracer: Option>, + } + }) +} + +/// Returns true if the item at the head of the token stream is a valid service client declaration. +fn is_client_declaration(item: &TokenStream) -> bool { + let ItemStruct { vis, generics, .. } = match syn::parse2(item.clone()) { + Ok(struct_item) => struct_item, + Err(_) => return false, + }; + + if !generics.params.is_empty() { + // Service clients must not have generic type parameters. + return false; + } + + // Service clients must be public structs. + if !matches!(vis, syn::Visibility::Public(_)) { + return false; + } + true +} + +#[cfg(test)] +mod tests { + use super::*; + + // cspell: ignore punct + #[test] + fn parse_service_client() { + let attr = TokenStream::new(); + let item = quote! { + pub struct ServiceClient { + name: &'static str, + endpoint: Url, + } + }; + let actual = parse_client(attr, item).expect("Failed to parse client declaration"); + let expected = quote! { + pub struct ServiceClient { + name: &'static str, + endpoint: Url, + pub(crate) tracer: Option>, + } + }; + // println!("Parsed tokens: {:?}", tokens); + // println!("Expected tokens: {:?}", expected); + + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + } + + #[test] + fn parse_not_service_client() { + let attr = TokenStream::new(); + let item = quote! { + fn NotServiceClient(&self, name: &'static str) -> Result<(), Box> { + Ok(()) + } + }; + assert!( + parse_client(attr, item).is_err(), + "Expected error for non-client declaration" + ); + } +} diff --git a/sdk/core/azure_core_macros/src/tracing_function.rs b/sdk/core/azure_core_macros/src/tracing_function.rs new file mode 100644 index 0000000000..cb29d97426 --- /dev/null +++ b/sdk/core/azure_core_macros/src/tracing_function.rs @@ -0,0 +1,555 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{parse::Parse, spanned::Spanned, ItemFn, Member, Result, Token}; + +const INVALID_PUBLIC_FUNCTION_MESSAGE: &str = + "function attribute must be applied to a public function returning a Result"; + +// cspell: ignore asyncness + +/// Parse the token stream for an Azure Service client "new" declaration. +/// +/// An Azure Service client "new" declaration is a public function whose name starts with +/// `new` and returns either a new client instance or an error. +/// +/// This macro will ensure that the fn is public and returns one of the following: +/// 1) `Self` +/// 1) `Arc` +/// 1) `Result` +/// 1) `Result, E>` +/// +pub fn parse_function(attr: TokenStream, item: TokenStream) -> Result { + if !is_function_declaration(&item) { + println!("Not a function declaration: {item}"); + return Err(syn::Error::new( + item.span(), + INVALID_PUBLIC_FUNCTION_MESSAGE, + )); + } + + let function_name_and_attributes: FunctionNameAndAttributes = syn::parse2(attr)?; + + let api_name = function_name_and_attributes.function_name; + + let ItemFn { + attrs, + vis, + sig, + block, + } = syn::parse2(item)?; + + let attributes: TokenStream = if function_name_and_attributes.arguments.is_empty() { + quote! {Vec::new()} + } else { + let attribute_vec = function_name_and_attributes + .arguments + .into_iter() + .map(|(name, value)| { + quote! { + ::typespec_client_core::tracing::Attribute{key: #name.into(), value: #value.into()} + } + }) + .collect::>(); + quote! { vec![#(#attribute_vec),*] } + }; + + let preamble = quote! { + let options = { + let mut options = options.unwrap_or_default(); + + let public_api_info = azure_core::tracing::PublicApiInstrumentationInformation { + api_name: #api_name, + attributes: #attributes, + }; + // Add the span to the tracer. + let mut ctx = options.method_options.context.with_value(public_api_info); + // If the service has a tracer, we add it to the context. + if let Some(tracer) = &self.tracer { + ctx = ctx.with_value(tracer.clone()); + } + options.method_options.context = ctx; + Some(options) + }; + }; + + // Clear the actual test method parameters. + Ok(quote! { + #(#attrs)* + #vis #sig { + #preamble + #block + } + }) +} + +#[derive(Debug)] +struct FunctionNameAndAttributes { + function_name: String, + arguments: Vec<(String, syn::Expr)>, +} + +fn name_from_expr(expr: &syn::Expr) -> Result { + match expr { + syn::Expr::Lit(lit) => match &lit.lit { + syn::Lit::Str(lit_str) => Ok(lit_str.value()), + _ => Err(syn::Error::new(lit.span(), "Unsupported literal type")), + }, + syn::Expr::Path(expr_path) => expr_path + .path + .get_ident() + .ok_or_else(|| { + syn::Error::new( + expr_path.span(), + "Expected an identifier in path expression", + ) + }) + .map(|ident| ident.to_string()), + syn::Expr::Field(expr_field) => { + // If it's a field, we can extract the base and path + // This assumes the field is a path like `az.foo.bar.namespace` + // and we want to extract `az.foo.bar.namespace` + let base = name_from_expr(expr_field.base.as_ref())?; + let member = match &expr_field.member { + Member::Named(ident) => ident.to_string(), + Member::Unnamed(_) => { + println!("Anonymous member"); + // If it's an unnamed member, we can use the index or some other identifier + // Here we assume it's a named member for simplicity + format!("{:?}", expr_field.member.to_token_stream()) + } + }; + Ok(format!("{base}.{member}")) + } + _ => Err(syn::Error::new(expr.span(), "Unsupported expression type")), + } +} + +impl Parse for FunctionNameAndAttributes { + fn parse(input: syn::parse::ParseStream) -> Result { + let function_name = input.parse::()?.value(); + // If the next character is a comma, we expect a list of attributes. + if input.peek(Token!(,)) { + input.parse::()?; + if input.peek(syn::token::Paren) { + let content; + let _ = syn::parenthesized!(content in input); + let mut arguments: syn::punctuated::Punctuated = + syn::punctuated::Punctuated::new(); + if !content.is_empty() { + arguments = content.parse_terminated(syn::ExprAssign::parse, syn::Token![,])?; + } + let arguments_result = arguments + .into_iter() + .map(|arg| { + let syn::ExprAssign { left, right, .. } = arg; + let (left, right) = (left, right); + let name = name_from_expr(left.as_ref())?; + Ok((name, *right)) + }) + .collect::>>()?; + + Ok(FunctionNameAndAttributes { + function_name, + arguments: arguments_result, + }) + } else { + Err(syn::Error::new( + input.span(), + "Expected parentheses after function name.", + )) + } + } else { + Ok(FunctionNameAndAttributes { + function_name, + arguments: vec![], + }) + } + } +} + +fn is_function_declaration(item: &TokenStream) -> bool { + let item_fn: ItemFn = match syn::parse2(item.clone()) { + Ok(fn_item) => fn_item, + Err(_) => { + return false; + } + }; + + // Function must be public. + if !matches!(item_fn.vis, syn::Visibility::Public(_)) { + return false; + } + + // Function must return a Result type. + if let syn::ReturnType::Type(_, ty) = &item_fn.sig.output { + if !matches!(ty.as_ref(), syn::Type::Path(_)) { + return false; + } + } else { + return false; + } + + true +} + +#[cfg(test)] +mod tests { + use std::vec; + + use syn::parse_quote; + + use super::*; + + #[test] + fn test_parse_function_name_and_attributes() { + { + let test_stream = quote! { "Test Function", (arg1 = 42, arg2 = "value") }; + let parsed: FunctionNameAndAttributes = + syn::parse2(test_stream).expect("Failed to parse"); + assert_eq!(parsed.function_name, "Test Function"); + assert_eq!( + parsed.arguments, + vec![ + ("arg1".to_string(), parse_quote!(42)), + ("arg2".to_string(), parse_quote!("value")) + ] + ); + } + } + #[test] + fn test_parse_function_name_and_attributes_with_string_name() { + { + let test_stream = quote! { "Test Function", ("az.namespace" = "my namespace", az.test_value = "value") }; + let parsed: FunctionNameAndAttributes = + syn::parse2(test_stream).expect("Failed to parse"); + assert_eq!(parsed.function_name, "Test Function"); + assert_eq!( + parsed.arguments, + vec![ + ("az.namespace".to_string(), parse_quote!("my namespace")), + ("az.test_value".to_string(), parse_quote!("value")), + ] + ); + } + } + #[test] + fn test_parse_function_name_and_attributes_with_dotted_name() { + { + let test_stream = quote! { "Test Function", (az.namespace = "my namespace", az.test_value = "value") }; + let parsed: FunctionNameAndAttributes = + syn::parse2(test_stream).expect("Failed to parse"); + assert_eq!(parsed.function_name, "Test Function"); + assert_eq!( + parsed.arguments, + vec![ + ("az.namespace".to_string(), parse_quote!("my namespace")), + ("az.test_value".to_string(), parse_quote!("value")) + ] + ); + } + } + #[test] + fn test_parse_function_name_and_attributes_with_identifier_argument() { + { + let test_stream = quote! {"macros_get_with_tracing", (az.path = path, az.info = "Test", az.number = 42)}; + let parsed: FunctionNameAndAttributes = + syn::parse2(test_stream).expect("Failed to parse"); + assert_eq!(parsed.function_name, "macros_get_with_tracing"); + assert_eq!( + parsed.arguments, + vec![ + ("az.path".to_string(), parse_quote!(path)), + ("az.info".to_string(), parse_quote!("Test")), + ("az.number".to_string(), parse_quote!(42)), + ] + ); + } + } + #[test] + fn test_parse_function_name_and_attributes_with_identifier_name() { + { + let test_stream = quote! { "Test Function", (az.foo.bar.namespace = "my namespace", az.test_value = "value") }; + let parsed: FunctionNameAndAttributes = + syn::parse2(test_stream).expect("Failed to parse"); + assert_eq!(parsed.function_name, "Test Function"); + assert_eq!( + parsed.arguments, + vec![ + ( + "az.foo.bar.namespace".to_string(), + parse_quote!("my namespace") + ), + ("az.test_value".to_string(), parse_quote!("value")) + ] + ); + } + } + #[test] + fn test_parse_function_name_and_attributes_with_comma_no_attributes() { + { + let test_stream = quote! { "Test Function", }; + + syn::parse2::(test_stream) + .expect_err("Should fail to parse."); + } + } + #[test] + fn test_parse_function_name_and_attributes_invalid_attribute_name() { + { + let test_stream = quote! { "Test Function",(23.5= "value") }; + + syn::parse2::(test_stream) + .expect_err("Should fail to parse."); + } + } + #[test] + fn test_parse_function_name_and_attributes_empty_attributes() { + { + let test_stream = quote! { "Test Function", ()}; + + syn::parse2::(test_stream).expect("No attributes are ok."); + } + } + + #[test] + fn test_is_function_declaration() { + let valid_fn = quote! { + pub async fn my_function() -> Result<(), Box> { + } + }; + let invalid_fn = quote! { + pub fn my_function() { + } + }; + + assert!(is_function_declaration(&valid_fn)); + assert!(!is_function_declaration(&invalid_fn)); + } + + #[test] + fn test_parse_function() -> std::result::Result<(), syn::Error> { + let attr = quote! { "TestFunction" }; + let item = quote! { + pub async fn my_function(&self, path: &str) -> Result<(), Box> { + let options = options.unwrap_or_default(); + + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + + let mut request = Request::new(url, azure_core::http::Method::Get); + + let response = self + .pipeline + .send(&options.method_options.context, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()), + )); + } + Ok(response) + } + }; + + let actual = parse_function(attr, item)?; + let expected = quote! { + pub async fn my_function(&self, path: &str) -> Result<(), Box> { + let options = { + let mut options = options.unwrap_or_default(); + let public_api_info = azure_core::tracing::PublicApiInstrumentationInformation { + api_name: "TestFunction", + attributes: Vec::new(), + }; + let mut ctx = options.method_options.context.with_value(public_api_info); + if let Some(tracer) = &self.tracer { + ctx = ctx.with_value(tracer.clone()); + } + options.method_options.context = ctx; + Some(options) + }; + { + let options = options.unwrap_or_default(); + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + let mut request = Request::new(url, azure_core::http::Method::Get); + let response = self + .pipeline + .send(&options.method_options.context, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()), + )); + } + Ok(response) + } + } + }; + + println!("Parsed tokens: {:?}", actual.to_string()); + println!("Expected tokens: {:?}", expected.to_string()); + + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + Ok(()) + } + + // cspell: ignore deletedsecrets + #[test] + fn test_parse_pageable_function() -> std::result::Result<(), syn::Error> { + let attr = quote! { "TestFunction" }; + let item = quote! { + pub fn list_deleted_secret_properties( + &self, + options: Option>, + ) -> Result> { + let options = options.unwrap_or_default().into_owned(); + let pipeline = self.pipeline.clone(); + let mut first_url = self.endpoint.clone(); + first_url = first_url.join("deletedsecrets")?; + first_url + .query_pairs_mut() + .append_pair("api-version", &self.api_version); + if let Some(maxresults) = options.maxresults { + first_url + .query_pairs_mut() + .append_pair("maxresults", &maxresults.to_string()); + } + let api_version = self.api_version.clone(); + Ok(Pager::from_callback(move |next_link: Option| { + let url = match next_link { + Some(next_link) => { + let qp = next_link + .query_pairs() + .filter(|(name, _)| name.ne("api-version")); + let mut next_link = next_link.clone(); + next_link + .query_pairs_mut() + .clear() + .extend_pairs(qp) + .append_pair("api-version", &api_version); + next_link + } + None => first_url.clone(), + }; + let mut request = Request::new(url, Method::Get); + request.insert_header("accept", "application/json"); + let ctx = options.method_options.context.clone(); + let pipeline = pipeline.clone(); + async move { + let rsp: RawResponse = pipeline.send(&ctx, &mut request).await?; + let (status, headers, body) = rsp.deconstruct(); + let bytes = body.collect().await?; + let res: ListDeletedSecretPropertiesResult = json::from_json(&bytes)?; + let rsp = RawResponse::from_bytes(status, headers, bytes).into(); + Ok(match res.next_link { + Some(next_link) if !next_link.is_empty() => PagerResult::More { + response: rsp, + continuation: next_link.parse()?, + }, + _ => PagerResult::Done { response: rsp }, + }) + } + })) + } + }; + + let actual = parse_function(attr, item)?; + let expected = quote! { + pub fn list_deleted_secret_properties( + &self, + options: Option>, + ) -> Result> { + let options = { + let mut options = options.unwrap_or_default(); + let public_api_info = azure_core::tracing::PublicApiInstrumentationInformation { + api_name: "TestFunction", + attributes: Vec::new(), + }; + let mut ctx = options.method_options.context.with_value(public_api_info); + if let Some(tracer) = &self.tracer { + ctx = ctx.with_value(tracer.clone()); + } + options.method_options.context = ctx; + Some(options) + }; + { + let options = options.unwrap_or_default().into_owned(); + let pipeline = self.pipeline.clone(); + let mut first_url = self.endpoint.clone(); + first_url = first_url.join("deletedsecrets")?; + first_url + .query_pairs_mut() + .append_pair("api-version", &self.api_version); + if let Some(maxresults) = options.maxresults { + first_url + .query_pairs_mut() + .append_pair("maxresults", &maxresults.to_string()); + } + let api_version = self.api_version.clone(); + Ok(Pager::from_callback(move |next_link: Option| { + let url = match next_link { + Some(next_link) => { + let qp = next_link + .query_pairs() + .filter(|(name, _)| name.ne("api-version")); + let mut next_link = next_link.clone(); + next_link + .query_pairs_mut() + .clear() + .extend_pairs(qp) + .append_pair("api-version", &api_version); + next_link + } + None => first_url.clone(), + }; + let mut request = Request::new(url, Method::Get); + request.insert_header("accept", "application/json"); + let ctx = options.method_options.context.clone(); + let pipeline = pipeline.clone(); + async move { + let rsp: RawResponse = pipeline.send(&ctx, &mut request).await?; + let (status, headers, body) = rsp.deconstruct(); + let bytes = body.collect().await?; + let res: ListDeletedSecretPropertiesResult = json::from_json(&bytes)?; + let rsp = RawResponse::from_bytes(status, headers, bytes).into(); + Ok(match res.next_link { + Some(next_link) if !next_link.is_empty() => PagerResult::More { + response: rsp, + continuation: next_link.parse()?, + }, + _ => PagerResult::Done { response: rsp }, + }) + } + })) + } + } + }; + + println!("Parsed tokens: {:?}", actual.to_string()); + println!("Expected tokens: {:?}", expected.to_string()); + + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + Ok(()) + } +} diff --git a/sdk/core/azure_core_macros/src/tracing_new.rs b/sdk/core/azure_core_macros/src/tracing_new.rs new file mode 100644 index 0000000000..54b113d30d --- /dev/null +++ b/sdk/core/azure_core_macros/src/tracing_new.rs @@ -0,0 +1,815 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{ + parse::Parse, spanned::Spanned, AngleBracketedGenericArguments, ExprStruct, ItemFn, Result, +}; +use tracing::{error, trace}; + +const INVALID_SERVICE_CLIENT_NEW_MESSAGE: &str = + "new attribute must be applied to a public function which returns Self, a Result and/or Arc containing Self"; + +struct NamespaceAttribute { + client_namespace: String, +} + +impl Parse for NamespaceAttribute { + fn parse(input: syn::parse::ParseStream) -> Result { + let client_namespace = input.parse::()?.value(); + Ok(NamespaceAttribute { client_namespace }) + } +} + +fn parse_struct_expr( + client_namespace: &str, + struct_body: &ExprStruct, + default: TokenStream, + is_ok: bool, +) -> TokenStream { + if struct_body.path.is_ident("Self") { + let fields = struct_body.fields.iter(); + let tracer_init = quote! { + if let Some(tracer_options) = &options.client_options.request_instrumentation { + tracer_options + .tracer_provider + .as_ref() + .map(|tracer_provider| { + tracer_provider.get_tracer( + Some(#client_namespace), + option_env!("CARGO_PKG_NAME").unwrap_or("UNKNOWN"), + option_env!("CARGO_PKG_VERSION"), + ) + }) + } else { + None + } + }; + if is_ok { + quote! { + Ok(Self { + tracer: #tracer_init, + #(#fields),*, + }) + } + } else { + quote! { + Self { + tracer: #tracer_init, + #(#fields),*, + } + } + } + } else { + default + } +} + +fn is_arc_new_call(func: &syn::Expr) -> bool { + if let syn::Expr::Path(path) = func { + if path.path.segments.len() < 2 { + return false; + } + if path.path.segments[path.path.segments.len() - 2].ident != "Arc" { + return false; + } + if path.path.segments.last().unwrap().ident != "new" { + return false; + } + return true; + } + false +} + +// Parse a function call expression statement that initializes a struct with `Arc::new(Self {})` or `Ok(Arc::new(Self {}))`. +fn parse_call_expr(namespace: &str, call: &syn::ExprCall) -> TokenStream { + debug_assert_eq!( + call.args.len(), + 1, + "Call expression must have exactly one argument" + ); + if let syn::Expr::Path(path) = call.func.as_ref() { + if path.path.segments.last().unwrap().ident == "Ok" { + match call.args.first().unwrap() { + syn::Expr::Struct(struct_body) => { + parse_struct_expr(namespace, struct_body, call.to_token_stream(), true) + } + syn::Expr::Call(call) => { + // Let's make sure that we're doing a call to Arc::new before we recurse. + // Arc::new takes only a single argument, so we can check that first. + if call.args.len() != 1 { + trace!("Call expression does not have exactly one argument, emitting expression: {call:?}"); + return call.to_token_stream(); + } + if is_arc_new_call(call.func.as_ref()) { + let call_expr = parse_call_expr(namespace, call); + quote!(Ok(#call_expr)) + } else { + trace!("Call expression is not Arc::new(), emitting expression: {call:?}"); + call.to_token_stream() + } + } + _ => { + trace!( + "Call expression is not a struct or call, emitting expression: {call:?}" + ); + call.to_token_stream() + } + } + } else if is_arc_new_call(call.func.as_ref()) { + if let syn::Expr::Struct(struct_body) = call.args.first().unwrap() { + let struct_expr = + parse_struct_expr(namespace, struct_body, call.to_token_stream(), false); + quote! { + Arc::new(#struct_expr) + } + } else { + trace!("Call expression is not a struct, emitting expression: {call:?}"); + call.to_token_stream() + } + } else { + trace!("Call expression is not an Arc or Ok, emitting expression: {call:?}"); + call.to_token_stream() + } + } else { + trace!("Call expression is not a path, emitting expression: {call:?}"); + call.to_token_stream() + } +} + +/// Parse the token stream for an Azure Service client "new" declaration. +/// +/// An Azure Service client "new" declaration is a public function whose name starts with +/// `new` and returns either a new client instance or an error. +/// +/// This macro will ensure that the fn is public and returns one of the following: +/// 1) `Self` +/// 1) `Arc` +/// 1) `Result` +/// 1) `Result, E>` +/// +pub fn parse_new(attr: TokenStream, item: TokenStream) -> Result { + if let Err(reason) = is_new_declaration(&item) { + return Err(syn::Error::new( + item.span(), + format!("{INVALID_SERVICE_CLIENT_NEW_MESSAGE}: {reason}"), + )); + } + + let namespace_attrs: NamespaceAttribute = syn::parse2(attr)?; + + let ItemFn { + vis, + sig, + block, + attrs, + } = syn::parse2(item.clone())?; + + let ident = &sig.ident; + let inputs = sig.inputs.iter(); + let body =block.stmts.iter().map(|stmt| { + // Ensure that the body of the new function initializes the `tracer` field. + + match stmt { + syn::Stmt::Expr(expr, _) => + match expr { + syn::Expr::Call(c) => { + // If the expression is a call, we need to check if it is a struct initialization. + if c.args.len() != 1 { + trace!("Call expression does not have exactly one argument, emitting statement: {stmt:?}"); + // If the call does not have exactly one argument, just return it as is. + stmt.to_token_stream() + } + else { + parse_call_expr(namespace_attrs.client_namespace.as_str(), c) + } + } + syn::Expr::Struct(struct_body) => { + // If the expression is a struct, we need to parse it. + parse_struct_expr( + namespace_attrs.client_namespace.as_str(), + struct_body, + stmt.to_token_stream(), + false, + ) + } + _ => { + // If the expression is not a struct or call, just return it as is (for + // instance an "if" statement is an expression) + stmt.to_token_stream() + } + } + _ => { + // If the statement is not an expression, just return it as is. + stmt.to_token_stream() + } + } + }); + let output = &sig.output; + Ok(quote! { + #(#attrs)* + #vis + fn #ident(#(#inputs),*) #output { + #(#body)* + } + }) +} + +fn is_arc_of_self(path: &syn::Path) -> std::result::Result<(), String> { + let segment = path.segments.last().unwrap(); + if segment.ident != "Arc" { + error!( + "Invalid return type for new function: Arc must be the first segment, found {:?}", + segment.ident + ); + return Err( + "Invalid return type for new function: Arc must be the first segment".to_string(), + ); + } + if segment.arguments.is_empty() { + error!( + "Invalid return type for new function: Arc must have arguments, found {:?}", + segment.arguments + ); + return Err("Invalid return type for new function: Arc must have arguments".to_string()); + } + match &segment.arguments { + syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments { args, .. }) => { + if args.len() != 1 { + error!( + "Invalid return type for new function: Arc must have one argument, found {args:?}", + ); + return Err( + "Invalid return type for new function: Arc must have one argument".to_string(), + ); + } + if let syn::GenericArgument::Type(syn::Type::Path(path)) = &args[0] { + if path.path.is_ident("Self") { + Ok(()) + } else { + error!( + "Invalid return type for new function: Arc argument must be Self, found {:?}", + path.path + ); + Err( + "Invalid return type for new function: Arc argument must be Self" + .to_string(), + ) + } + } else { + error!( + "Invalid return type for new function: Arc argument must be Self, found {:?}", + args[0] + ); + Err("Invalid return type for new function: Arc argument must be Self".to_string()) + } + } + _ => { + error!("Invalid return type for new function: Arc arguments must be angle bracketed"); + Err( + "Invalid return type for new function: Arc arguments must be angle bracketed" + .to_string(), + ) + } + } +} + +fn is_valid_arc_of_self_call(expr: &syn::Expr) -> std::result::Result<(), String> { + if let syn::Expr::Struct(struct_body) = expr { + if struct_body.path.is_ident("Self") { + Ok(()) + } else { + error!( + "Invalid new function body: expected struct initialization with Self, found {:?}", + struct_body.path + ); + Err("expected struct initialization with Self".to_string()) + } + } else { + error!( + "Invalid new function body: expected call to `Arc`, found {:?}", + expr + ); + Err("expected last parameter to Arc to be Self".to_string()) + } +} +fn is_valid_ok_call( + args: &syn::punctuated::Punctuated, +) -> std::result::Result<(), String> { + if args.len() != 1 { + error!( + "Invalid new function body: expected call to `Ok` with one argument, found {args:?}" + ); + return Err( + "Invalid new function body: expected call to `Ok` with one argument".to_string(), + ); + } + match &args[0] { + syn::Expr::Struct(struct_body) => { + if struct_body.path.is_ident("Self") { + Ok(()) + } else { + error!( + "Invalid new function body: expected struct initialization with Self, found {:?}", + struct_body.path + ); + Err("expected struct initialization with Self".to_string()) + } + } + syn::Expr::Call(call) => { + if call.args.len() == 1 { + if is_arc_new_call(call.func.as_ref()) { + is_valid_arc_of_self_call(call.args.last().unwrap()) + } else { + error!( + "Invalid new function body: expected function named Arc, found {:?}", + call.func.as_ref() + ); + Err("expected Arc path".to_string()) + } + } else { + error!( + "Invalid new function body: expected call to function with one argument, found {:?}", + args[0] + ); + Err("expected call to Arc with one argument".to_string()) + } + } + _ => { + error!( + "Invalid new function body: expected a structure or call to function, found {:?}", + args[0] + ); + Err("Invalid new function body: expected a structure or call to function".to_string()) + } + } +} + +fn is_valid_new_body(stmts: &[syn::Stmt]) -> std::result::Result<(), String> { + if stmts.is_empty() { + return Err("New function body must have at least one statement".to_string()); + } + let last_stmt = stmts.last().unwrap(); + if let syn::Stmt::Expr(expr, _) = last_stmt { + match expr { + syn::Expr::Struct(struct_body) => { + if struct_body.path.is_ident("Self") { + Ok(()) + } else { + error!("Invalid new function body: expected struct initialization with Self, found {:?}", struct_body.path); + Err("Expected struct initialization with Self".to_string()) + } + } + syn::Expr::Call(call) => { + if let syn::Expr::Path(path) = call.func.as_ref() { + if path.path.is_ident("Ok") { + is_valid_ok_call(&call.args) + } else if is_arc_new_call(call.func.as_ref()) { + is_valid_arc_of_self_call(call.args.last().unwrap()) + } else { + error!( + "Invalid new function body: expected call to `Ok` or `Arc`, found {:?}", + path + ); + Err("Invalid new function body: expected call to `Ok` or `Arc`".to_string()) + } + } else { + error!( + "Invalid new function body - expected Path, got {:?}", + call.func + ); + Err("Invalid new function body - expected Path".to_string()) + } + } + _ => { + error!( + "Invalid new function body: expected call or struct statement, found {:?}", + last_stmt + ); + Err("Expected call or struct statement".to_string()) + } + } + } else { + error!( + "Invalid new function body: expected expression statement, found {:?}", + last_stmt + ); + Err("Expected final statement to be an expression".to_string()) + } +} + +fn is_valid_new_return(return_type: &syn::ReturnType) -> std::result::Result<(), String> { + match return_type { + syn::ReturnType::Default => Err("Default return type is not allowed".to_string()), + syn::ReturnType::Type(_, ty) => { + let syn::Type::Path(p) = ty.as_ref() else { + error!("Invalid return type for new function, expected path: {ty:?}"); + return Err("Invalid return type for new function, expected path".to_string()); + }; + if p.path.segments.is_empty() { + error!("Invalid return type for new function: Path is empty"); + return Err("Invalid return type for new function: Path is empty".to_string()); + } + if p.path.is_ident("Self") { + Ok(()) + } else { + // segments.last to allow for std::arc::Arc or azure_core::Result + let segment = p.path.segments.last().unwrap(); + + if segment.ident == "Result" { + match &segment.arguments { + syn::PathArguments::AngleBracketed(AngleBracketedGenericArguments { + args, + .. + }) => { + if args.len() != 1 && args.len() != 2 { + error!("Invalid return type for new function: Result must have one or two arguments"); + return Err("Invalid return type for new function: Result must have one or two arguments".to_string()); + } + if let syn::GenericArgument::Type(syn::Type::Path(path)) = &args[0] { + if path.path.is_ident("Self") { + Ok(()) + } else { + is_arc_of_self(&path.path) + } + } else { + error!("Invalid return type for new function: Result first argument must be Self, found {:?}", args[0]); + Err("Invalid return type for new function: Result first argument must be Self".to_string()) + } + } + _ => { + error!("Invalid return type for new function: Result arguments must be angle bracketed"); + Err("Invalid return type for new function: Result arguments must be angle bracketed".to_string()) + } + } + } else if segment.ident == "Arc" { + is_arc_of_self(&p.path) + } else { + Err("Invalid return type for new function: Expected Self, Result, or Arc".to_string()) + } + } + } + } +} + +/// Returns true if the item at the head of the token stream is a valid service client declaration. +/// +/// # Returns +/// - None if the item is a valid service new declaration. +/// - Some(String) if the item is NOT a valid service new declaration +fn is_new_declaration(item: &TokenStream) -> std::result::Result<(), String> { + // The item must be a function declaration. + let item_fn: ItemFn = syn::parse2(item.clone()) + .map_err(|e| format!("Failed to parse item as function declaration: {e}"))?; + + // Service clients new functions must be public. + if !matches!(item_fn.vis, syn::Visibility::Public(_)) { + error!("Service client new function must be public"); + Err("`tracing::new` function must be public".to_string()) + } else { + // Verify that this function returns a type that is either Self, Result, Arc, or Result, E>. + is_valid_new_return(&item_fn.sig.output)?; + // Look at the function body to ensure that the last statement is a struct initialization. + is_valid_new_body(&item_fn.block.stmts)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tracing::tests::setup_tracing; + + #[test] + fn is_new_declaration_simple_self() { + setup_tracing(); + assert!(is_new_declaration("e! {pub fn new_client(a:u32)-> Self { Self {}}}).is_ok()); + } + #[test] + fn is_new_declaration_arc_self() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> Arc { Arc::new(Self {})}} + ) + .is_ok()); + } + #[test] + fn is_new_declaration_arc_self_long() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> std::sync::Arc { std::sync::Arc::new(Self {})}} + ).is_ok()); + } + #[test] + fn is_new_declaration_result_self() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> Result { Ok(Self {})}} + ) + .is_ok()); + } + #[test] + fn is_new_declaration_result_self_std_result() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> std::result::Result { Ok(Self {})}} + ).is_ok()); + } + #[test] + fn is_new_declaration_result_arc_self() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> Result> { Ok(Arc::new(Self {}) )}} + ) + .is_ok()); + } + #[test] + fn is_new_declaration_result_arc_self_long() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> Result> { Ok(std::sync::Arc::new(Self {}) )}} + ) + .is_ok()); + } + #[test] + fn is_new_declaration_invalid_return_type() { + setup_tracing(); + + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> u64 { Ok(Arc::new(Self {}) )}} + ) + .is_err()); + } + #[test] + fn is_new_declaration_result_not_self() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> Result { Ok(Arc::new(Self {}) )}} + ) + .is_err()); + } + #[test] + fn is_new_declaration_result_not_arc_self() { + setup_tracing(); + assert!(is_new_declaration( + "e! {pub fn new_client(a:u32)-> Result> { Ok(Arc::new(Self {}) )}} + ) + .is_err()); + } + + #[test] + fn parse_new_function() { + setup_tracing(); + let attr = quote!("Az.Namespace"); + let item = quote! { + pub fn new_service_client(name: &'static str, endpoint: Url) -> Self { + let function = newtype::new(); + println!("Function: {:?}", function); + i = i + 1; + let this = Self { + name, + endpoint, + }; + Self { + name, + endpoint, + } + } + }; + let actual = parse_new(attr, item).expect("Failed to parse new function declaration"); + println!("Parsed tokens: {actual}"); + + let expected = quote! { + pub fn new_service_client(name: &'static str, endpoint: Url) -> Self { + let function = newtype::new(); + println!("Function: {:?}", function); + i = i + 1; + let this = Self { + name, + endpoint, + }; + + Self { + tracer: if let Some(tracer_options) = + &options.client_options.request_instrumentation + { + tracer_options + .tracer_provider + .as_ref() + .map(|tracer_provider| { + tracer_provider.get_tracer( + Some("Az.Namespace"), + option_env!("CARGO_PKG_NAME").unwrap_or("UNKNOWN"), + option_env!("CARGO_PKG_VERSION"), + ) + }) + } else { + None + }, + name, + endpoint, + } + } + }; + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + } + + #[test] + fn parse_generated_new() { + setup_tracing(); + let attr = quote!("Az.GeneratedNamespace"); + let new_function = quote! { + pub fn new( + endpoint: &str, + credential: Arc, + options: Option, + ) -> Result { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + let auth_policy: Arc = Arc::new(BearerTokenCredentialPolicy::new( + credential, + vec!["https://vault.azure.net/.default"], + )); + Ok(Self { + endpoint, + api_version: options.api_version, + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + vec![auth_policy], + ), + }) + } + }; + let actual = + parse_new(attr, new_function).expect("Failed to parse new function declaration"); + + println!("Parsed tokens: {actual}"); + + // I am not at all sure why the parameters to `new` are not being parsed correctly - + // the trailing comma in the `new_function` token stream is not present. + let expected = quote! { + pub fn new( + endpoint: &str, + credential: Arc, + options: Option + ) -> Result { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + let auth_policy: Arc = Arc::new(BearerTokenCredentialPolicy::new( + credential, + vec!["https://vault.azure.net/.default"], + )); + Ok(Self { + tracer: if let Some(tracer_options) = + &options.client_options.request_instrumentation + { + tracer_options + .tracer_provider + .as_ref() + .map(|tracer_provider| { + tracer_provider.get_tracer( + Some("Az.GeneratedNamespace"), + option_env!("CARGO_PKG_NAME").unwrap_or("UNKNOWN"), + option_env!("CARGO_PKG_VERSION"), + ) + }) + } else { + None + }, + endpoint, + api_version: options.api_version, + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + vec![auth_policy], + ), + }) + } + }; + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + } + + #[test] + fn parse_arc_new() { + setup_tracing(); + let attr = quote!("Az.GeneratedNamespace"); + let new_function = quote! { + pub fn new( + endpoint: &str, + credential: Arc, + options: Option, + ) -> Result> { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + let auth_policy: Arc = Arc::new(BearerTokenCredentialPolicy::new( + credential, + vec!["https://vault.azure.net/.default"], + )); + Ok(Arc::new(Self { + endpoint, + api_version: options.api_version, + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + vec![auth_policy], + ), + })) + } + }; + let actual = + parse_new(attr, new_function).expect("Failed to parse new function declaration"); + + println!("Parsed tokens: {actual}"); + + // I am not at all sure why the parameters to `new` are not being parsed correctly - + // the trailing comma in the `new_function` token stream is not present. + let expected = quote! { + pub fn new( + endpoint: &str, + credential: Arc, + options: Option + ) -> Result> { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + let auth_policy: Arc = Arc::new(BearerTokenCredentialPolicy::new( + credential, + vec!["https://vault.azure.net/.default"], + )); + Ok(Arc::new(Self { + tracer: if let Some(tracer_options) = + &options.client_options.request_instrumentation + { + tracer_options + .tracer_provider + .as_ref() + .map(|tracer_provider| { + tracer_provider.get_tracer( + Some("Az.GeneratedNamespace"), + option_env!("CARGO_PKG_NAME").unwrap_or("UNKNOWN"), + option_env!("CARGO_PKG_VERSION"), + ) + }) + } else { + None + }, + endpoint, + api_version: options.api_version, + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + vec![auth_policy], + ), + })) + } + }; + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + } +} diff --git a/sdk/core/azure_core_macros/src/tracing_subclient.rs b/sdk/core/azure_core_macros/src/tracing_subclient.rs new file mode 100644 index 0000000000..4773d8c792 --- /dev/null +++ b/sdk/core/azure_core_macros/src/tracing_subclient.rs @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +use proc_macro2::TokenStream; +use quote::{quote, ToTokens}; +use syn::{spanned::Spanned, ExprStruct, ItemFn, Result}; +use tracing::error; + +const INVALID_SUBCLIENT_MESSAGE: &str = + "subclient attribute must be applied to a public function which returns a client type"; + +/// Parse the token stream for an Azure Service subclient declaration. +/// +/// An Azure Service client is a public struct that represents a client for an Azure service. +/// +/// +pub fn parse_subclient(_attr: TokenStream, item: TokenStream) -> Result { + if !is_subclient_declaration(&item) { + return Err(syn::Error::new(item.span(), INVALID_SUBCLIENT_MESSAGE)); + } + + let ItemFn { + vis, + sig, + block, + attrs, + } = syn::parse2(item.clone())?; + + let body = block.stmts; + + let ExprStruct { fields, path, .. } = syn::parse2(body.first().unwrap().to_token_stream())?; + + let fields = fields.iter(); + + Ok(quote! { + #(#attrs)* + #vis #sig { + #path { + #(#fields),*, + tracer: self.tracer.clone(), + } + } + }) +} + +fn is_subclient_declaration(item: &TokenStream) -> bool { + let ItemFn { + vis, block, sig, .. + } = match syn::parse2(item.clone()) { + Ok(fn_item) => fn_item, + Err(e) => { + error!("Failed to parse function: {}", e); + return false; + } + }; + + // Subclient constructors must be public functions. + if !matches!(vis, syn::Visibility::Public(_)) { + error!("Subclient constructors must be public functions"); + return false; + } + + // Subclient constructors must have a body with a single statement. + if block.stmts.len() != 1 { + error!("Subclient constructors must have a single statement in their body"); + return false; + } + + // Subclient constructors must have a return type that is a client type. + if let syn::ReturnType::Type(_, ty) = &sig.output { + if !matches!(ty.as_ref(), syn::Type::Path(p) if p.path.segments.last().unwrap().ident.to_string().ends_with("Client")) + { + return false; + } + } else { + return false; + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tracing::tests::setup_tracing; + use proc_macro2::TokenStream; + use quote::quote; + use tracing::trace; + + #[test] + fn test_is_subclient_declaration() { + setup_tracing(); + assert!(is_subclient_declaration("e! { + pub fn get_operation_templates_lro_client(&self) -> OperationTemplatesLroClient { + OperationTemplatesLroClient { + api_version: self.api_version.clone(), + endpoint: self.endpoint.clone(), + pipeline: self.pipeline.clone(), + subscription_id: self.subscription_id.clone(), + } + } + }),); + + assert!(!is_subclient_declaration("e! { + pub fn not_a_subclient() {} + })); + + assert!(is_subclient_declaration("e! { + pub fn operation_templates_lro_client() -> OperationTemplatesLroClient { + OperationTemplatesLroClient { + api_version: "2021-01-01".to_string(), + endpoint: "https://example.com".to_string(), + pipeline: "pipeline".to_string(), + subscription_id: "subscription_id".to_string(), + } + } + })); + } + + #[test] + fn test_parse_subclient() { + setup_tracing(); + let attr = TokenStream::new(); + let item = quote! { + pub fn get_operation_templates_lro_client(&self) -> OperationTemplatesLroClient { + OperationTemplatesLroClient { + api_version: self.api_version.clone(), + endpoint: self.endpoint.clone(), + pipeline: self.pipeline.clone(), + subscription_id: self.subscription_id.clone(), + } + } + }; + + let actual = parse_subclient(attr.clone(), item.clone()) + .expect("Failed to parse subclient declaration"); + trace!("Actual:{actual}"); + let expected = quote! { + pub fn get_operation_templates_lro_client(&self) -> OperationTemplatesLroClient { + OperationTemplatesLroClient { + api_version: self.api_version.clone(), + endpoint: self.endpoint.clone(), + pipeline: self.pipeline.clone(), + subscription_id: self.subscription_id.clone(), + tracer: self.tracer.clone(), + } + } + }; + assert!( + crate::tracing::tests::compare_token_stream(actual, expected), + "Parsed tokens do not match expected tokens" + ); + } +} diff --git a/sdk/core/azure_core_opentelemetry/Cargo.toml b/sdk/core/azure_core_opentelemetry/Cargo.toml index 5dd0903b06..729bfdb6fd 100644 --- a/sdk/core/azure_core_opentelemetry/Cargo.toml +++ b/sdk/core/azure_core_opentelemetry/Cargo.toml @@ -16,15 +16,22 @@ edition.workspace = true [dependencies] azure_core.workspace = true -log.workspace = true opentelemetry = { version = "0.30", features = ["trace"] } +opentelemetry-http = "0.30.0" +opentelemetry_sdk = "0.30" +reqwest.workspace = true tracing.workspace = true typespec_client_core.workspace = true + [dev-dependencies] +azure_core_test = { workspace = true, features = ["tracing"] } +azure_core_test_macros.workspace = true +azure_identity.workspace = true opentelemetry_sdk = { version = "0.30", features = ["testing"] } tokio.workspace = true tracing-subscriber.workspace = true +url.workspace = true [lints] workspace = true diff --git a/sdk/core/azure_core_opentelemetry/README.md b/sdk/core/azure_core_opentelemetry/README.md index f2679a1302..033e4c2ca2 100644 --- a/sdk/core/azure_core_opentelemetry/README.md +++ b/sdk/core/azure_core_opentelemetry/README.md @@ -1,17 +1,105 @@ # Azure Core OpenTelemetry Tracing -This crate provides OpenTelemetry distributed tracing support for the Azure SDK for Rust. It enables automatic span creation, context propagation, and telemetry collection for Azure service operations. +This crate provides [OpenTelemetry](https://opentelemetry.io) distributed tracing support for the Azure SDK for Rust. +It bridges the standardized `azure_core` tracing traits with the OpenTelemetry for Rust implementation, +enabling automatic span creation, context propagation, and telemetry collection for Azure services. -## Features +It allows Rust applications which use the [OpenTelemetry](https://opentelemetry.io/) APIs to generate OpenTelemetry spans for Azure SDK for Rust Clients. -## Usage +## OpenTelemetry integration with the Azure SDK for Rust -### Basic Setup +To integrate the OpenTelemetry APIs with the Azure SDK for Rust, you create a `OpenTelemetryTracerProvider` and pass it into your SDK ClientOptions. -### Creating Spans +```rust no_run +# use azure_identity::DefaultAzureCredential; +# use azure_core::{http::{ClientOptions, RequestInstrumentationOptions}}; +# #[derive(Default)] +# struct ServiceClientOptions { +# client_options: ClientOptions, +# } +use azure_core_opentelemetry::OpenTelemetryTracerProvider; +use opentelemetry_sdk::trace::SdkTracerProvider; +use std::sync::Arc; -### Error Handling +# fn test_fn() -> azure_core::Result<()> { +// Create an OpenTelemetry tracer provider adapter from an OpenTelemetry TracerProvider +let otel_tracer_provider = Arc::new(SdkTracerProvider::builder().build()); -## Azure Conventions +let azure_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); -## Integration +let options = ServiceClientOptions { + client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + +# Ok(()) +# } +``` + +If it is more convenient to use the global OpenTelemetry provider, then the `OpenTelemetryTracerProvider::new_from_global_provider` method will configure the OpenTelemetry support to use the global provider instead of a custom configured provider. + +```rust no_run +# use azure_identity::DefaultAzureCredential; +# use azure_core::{http::{ClientOptions, RequestInstrumentationOptions}}; + +# #[derive(Default)] +# struct ServiceClientOptions { +# client_options: ClientOptions, +# } +use azure_core_opentelemetry::OpenTelemetryTracerProvider; +use opentelemetry_sdk::trace::SdkTracerProvider; +use std::sync::Arc; + +# fn test_fn() -> azure_core::Result<()> { + +let azure_provider = OpenTelemetryTracerProvider::new_from_global_provider(); + +let options = ServiceClientOptions { + client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, +}; + +# Ok(()) +# } +``` + +Once the `OpenTelemetryTracerProvider` is integrated with the Azure Service ClientOptions, the Azure SDK will be configured to capture per-API and per-HTTP operation tracing options, and the HTTP requests will be annotated with [W3C Trace Context headers](https://www.w3.org/TR/trace-context/). + +## Troubleshooting + +## General + +## Logging + +## Contributing + +See the [CONTRIBUTING.md] for details on building, testing, and contributing to these libraries. + +This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit . + +When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repos using our CLA. + +This project has adopted the [Microsoft Open Source Code of Conduct]. For more information see the [Code of Conduct FAQ] or contact with any additional questions or comments. + +## Reporting security issues and security bugs + +Security issues and bugs should be reported privately, via email, to the Microsoft Security Response Center (MSRC) . You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Further information, including the MSRC PGP key, can be found in the [Security TechCenter](https://www.microsoft.com/msrc/faqs-report-an-issue). + +## License + +Azure SDK for Rust is licensed under the [MIT](https://github.com/Azure/azure-sdk-for-cpp/blob/main/LICENSE.txt) license. + + + +[Microsoft Open Source Code of Conduct]: https://opensource.microsoft.com/codeofconduct/ +[CONTRIBUTING.md]: https://github.com/Azure/azure-sdk-for-rust/blob/main/CONTRIBUTING.md +[Code of Conduct FAQ]: https://opensource.microsoft.com/codeofconduct/faq/ diff --git a/sdk/core/azure_core_opentelemetry/assets.json b/sdk/core/azure_core_opentelemetry/assets.json new file mode 100644 index 0000000000..5f90bbd318 --- /dev/null +++ b/sdk/core/azure_core_opentelemetry/assets.json @@ -0,0 +1,6 @@ +{ + "AssetsRepo": "Azure/azure-sdk-assets", + "AssetsRepoPrefixPath": "rust", + "Tag": "rust/azure_core_opentelemetry_58be03e82f", + "TagPrefix": "rust/azure_core_opentelemetry" +} \ No newline at end of file diff --git a/sdk/core/azure_core_opentelemetry/src/attributes.rs b/sdk/core/azure_core_opentelemetry/src/attributes.rs index b15101cc7c..086ccfa131 100644 --- a/sdk/core/azure_core_opentelemetry/src/attributes.rs +++ b/sdk/core/azure_core_opentelemetry/src/attributes.rs @@ -5,13 +5,17 @@ // Re-export typespec_client_core tracing attributes for convenience use azure_core::tracing::{ - AttributeArray as AzureAttributeArray, AttributeValue as AzureAttributeValue, + Attribute as AzureAttribute, AttributeArray as AzureAttributeArray, + AttributeValue as AzureAttributeValue, }; +use opentelemetry::KeyValue; pub(super) struct AttributeArray(AzureAttributeArray); pub(super) struct AttributeValue(pub AzureAttributeValue); +pub(super) struct OpenTelemetryAttribute(pub AzureAttribute); + impl From for AttributeValue { fn from(value: bool) -> Self { AttributeValue(AzureAttributeValue::Bool(value)) @@ -24,9 +28,9 @@ impl From for AttributeValue { } } -impl From for AttributeValue { - fn from(value: u64) -> Self { - AttributeValue(AzureAttributeValue::U64(value)) +impl From for AttributeValue { + fn from(value: f64) -> Self { + AttributeValue(AzureAttributeValue::I64(value as i64)) } } @@ -47,10 +51,9 @@ impl From> for AttributeArray { AttributeArray(AzureAttributeArray::I64(values)) } } - -impl From> for AttributeArray { - fn from(values: Vec) -> Self { - AttributeArray(AzureAttributeArray::U64(values)) +impl From> for AttributeArray { + fn from(values: Vec) -> Self { + AttributeArray(AzureAttributeArray::F64(values)) } } @@ -60,13 +63,22 @@ impl From> for AttributeArray { } } +impl From for KeyValue { + fn from(attr: OpenTelemetryAttribute) -> Self { + KeyValue::new( + opentelemetry::Key::from(attr.0.key.to_string()), + opentelemetry::Value::from(AttributeValue(attr.0.value)), + ) + } +} + /// Conversion from typespec_client_core AttributeValue to OpenTelemetry Value impl From for opentelemetry::Value { fn from(value: AttributeValue) -> Self { match value.0 { AzureAttributeValue::Bool(b) => opentelemetry::Value::Bool(b), AzureAttributeValue::I64(i) => opentelemetry::Value::I64(i), - AzureAttributeValue::U64(u) => opentelemetry::Value::I64(u as i64), + AzureAttributeValue::F64(f) => opentelemetry::Value::F64(f), AzureAttributeValue::String(s) => opentelemetry::Value::String(s.into()), AzureAttributeValue::Array(arr) => { opentelemetry::Value::Array(opentelemetry::Array::from(AttributeArray(arr))) @@ -81,10 +93,7 @@ impl From for opentelemetry::Array { match array.0 { AzureAttributeArray::Bool(values) => values.into(), AzureAttributeArray::I64(values) => values.into(), - AzureAttributeArray::U64(values) => { - let i64_values: Vec = values.into_iter().map(|v| v as i64).collect(); - i64_values.into() - } + AzureAttributeArray::F64(values) => values.into(), AzureAttributeArray::String(values) => { let string_values: Vec = values.into_iter().map(|s| s.into()).collect(); diff --git a/sdk/core/azure_core_opentelemetry/src/lib.rs b/sdk/core/azure_core_opentelemetry/src/lib.rs index a192707bd3..37bed17edb 100644 --- a/sdk/core/azure_core_opentelemetry/src/lib.rs +++ b/sdk/core/azure_core_opentelemetry/src/lib.rs @@ -1,11 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -//! Azure Core OpenTelemetry tracing integration. -//! -//! This crate provides OpenTelemetry distributed tracing support for the Azure SDK for Rust. -//! It bridges the standardized typespec_client_core tracing traits with OpenTelemetry implementation, -//! enabling automatic span creation, context propagation, and telemetry collection for Azure services. +#![doc = include_str!("../README.md")] +#![cfg_attr(docsrs, feature(doc_auto_cfg))] mod attributes; mod span; diff --git a/sdk/core/azure_core_opentelemetry/src/span.rs b/sdk/core/azure_core_opentelemetry/src/span.rs index 2981e3c10c..a9ee351527 100644 --- a/sdk/core/azure_core_opentelemetry/src/span.rs +++ b/sdk/core/azure_core_opentelemetry/src/span.rs @@ -5,10 +5,12 @@ use crate::attributes::AttributeValue as ConversionAttributeValue; use azure_core::{ + http::headers::{HeaderName, HeaderValue}, tracing::{AsAny, AttributeValue, Span, SpanGuard, SpanStatus}, - Result, }; -use opentelemetry::trace::TraceContextExt; +use opentelemetry::{propagation::TextMapPropagator, trace::TraceContextExt}; +use opentelemetry_http::HeaderInjector; +use opentelemetry_sdk::propagation::TraceContextPropagator; use std::{error::Error as StdError, sync::Arc}; /// newtype for Azure Core SpanKind to enable conversion to OpenTelemetry SpanKind @@ -41,52 +43,78 @@ impl OpenTelemetrySpan { } impl Span for OpenTelemetrySpan { - fn end(&self) -> Result<()> { + fn is_recording(&self) -> bool { + self.context.span().is_recording() + } + + fn end(&self) { self.context.span().end(); - Ok(()) } fn span_id(&self) -> [u8; 8] { self.context.span().span_context().span_id().to_bytes() } - fn set_attribute(&self, key: &'static str, value: AttributeValue) -> Result<()> { + fn set_attribute(&self, key: &'static str, value: AttributeValue) { let otel_value = opentelemetry::Value::from(ConversionAttributeValue(value)); self.context .span() .set_attribute(opentelemetry::KeyValue::new(key, otel_value)); - Ok(()) } - fn record_error(&self, error: &dyn StdError) -> Result<()> { + fn record_error(&self, error: &dyn StdError) { self.context.span().record_error(error); self.context .span() .set_status(opentelemetry::trace::Status::error(error.to_string())); - Ok(()) } - fn set_status(&self, status: SpanStatus) -> Result<()> { + fn set_status(&self, status: SpanStatus) { let otel_status = match status { SpanStatus::Unset => opentelemetry::trace::Status::Unset, - SpanStatus::Ok => opentelemetry::trace::Status::Ok, SpanStatus::Error { description } => opentelemetry::trace::Status::error(description), }; self.context.span().set_status(otel_status); - Ok(()) } - fn set_current( - &self, - _context: &azure_core::http::Context, - ) -> typespec_client_core::Result> { + fn propagate_headers(&self, request: &mut azure_core::http::Request) { + // A TraceContextPropagator is used to inject trace context information into HTTP headers. + let trace_propagator = TraceContextPropagator::new(); + // We need to map between a reqwest header map (which is what the OpenTelemetry SDK requires) + // and the Azure Core request headers. + // + // We start with an empty header map and inject the OpenTelemetry headers into it. + let mut header_map = reqwest::header::HeaderMap::new(); + trace_propagator.inject_context(&self.context, &mut HeaderInjector(&mut header_map)); + + // We then insert each of the headers from the OpenTelemetry header map into the + // Request's header map. + for (key, value) in header_map { + // Note: The OpenTelemetry HeaderInjector will always produce unique header names, so we don't need to + // handle the multiple headers case here. + + if let Some(key) = key { + request.insert_header( + HeaderName::from(key.as_str().to_owned()), + // The value is guaranteed to be a valid UTF-8 string by the OpenTelemetry SDK, + // so we can safely unwrap it. + HeaderValue::from(value.to_str().unwrap().to_owned()), + ); + } else { + // If the key is a duplicate of the previous header, we ignore it + tracing::warn!("Duplicate header key detected. Skipping this header."); + } + } + } + + fn set_current(&self, _context: &azure_core::http::Context) -> Box { // Create a context with the current span let context_guard = self.context.clone().attach(); - Ok(Box::new(OpenTelemetrySpanGuard { + Box::new(OpenTelemetrySpanGuard { _inner: context_guard, - })) + }) } } @@ -101,9 +129,8 @@ struct OpenTelemetrySpanGuard { } impl SpanGuard for OpenTelemetrySpanGuard { - fn end(self) -> Result<()> { + fn end(self) { // The span is ended when the guard is dropped, so no action needed here. - Ok(()) } } @@ -116,13 +143,14 @@ impl Drop for OpenTelemetrySpanGuard { #[cfg(test)] mod tests { use crate::telemetry::OpenTelemetryTracerProvider; - use azure_core::http::Context as AzureContext; - use azure_core::tracing::{AttributeValue, SpanKind, SpanStatus, TracerProvider}; + use azure_core::http::{Context as AzureContext, Url}; + use azure_core::tracing::{Attribute, AttributeValue, SpanKind, SpanStatus, TracerProvider}; use opentelemetry::trace::TraceContextExt; use opentelemetry::{Context, Key, KeyValue, Value}; use opentelemetry_sdk::trace::{in_memory_exporter::InMemorySpanExporter, SdkTracerProvider}; use std::io::{Error, ErrorKind}; use std::sync::Arc; + use tracing::trace; fn create_exportable_tracer_provider() -> (Arc, InMemorySpanExporter) { let otel_exporter = InMemorySpanExporter::default(); @@ -138,10 +166,10 @@ mod tests { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); - let span = tracer.start_span("test_span", SpanKind::Client).unwrap(); - assert!(span.end().is_ok()); + let tracer = + tracer_provider.get_tracer(Some("Microsoft.SpecialCase"), "test", Some("0.1.0")); + let span = tracer.start_span("test_span", SpanKind::Client, vec![]); + span.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 1); @@ -153,19 +181,45 @@ mod tests { } } + // cspell: ignore traceparent tracestate + #[test] + fn test_open_telemetry_span_propagate() { + let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); + + let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); + let tracer = tracer_provider.get_tracer(Some("Microsoft.SpecialCase"), "test", None); + let span = tracer.start_span("test_span", SpanKind::Client, vec![]); + let mut request = azure_core::http::Request::new( + Url::parse("http://example.com").unwrap(), + azure_core::http::Method::Get, + ); + span.propagate_headers(&mut request); + trace!("Request headers after propagation: {:?}", request.headers()); + let traceparent = azure_core::http::headers::HeaderName::from("traceparent"); + let tracestate = azure_core::http::headers::HeaderName::from("tracestate"); + request.headers().get_as::(&traceparent).unwrap(); + request.headers().get_as::(&tracestate).unwrap(); + span.end(); + + let finished_spans = otel_exporter.get_finished_spans().unwrap(); + assert_eq!(finished_spans.len(), 1); + } + #[test] fn test_open_telemetry_span_hierarchy() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); - let parent_span = tracer.start_span("parent_span", SpanKind::Server).unwrap(); - let child_span = tracer - .start_span_with_parent("child_span", SpanKind::Client, parent_span.clone()) - .unwrap(); + let tracer = tracer_provider.get_tracer(Some("Special Name"), "test", Some("0.1.0")); + let parent_span = tracer.start_span("parent_span", SpanKind::Server, vec![]); + let child_span = tracer.start_span_with_parent( + "child_span", + SpanKind::Client, + vec![], + parent_span.clone(), + ); - assert!(child_span.end().is_ok()); - assert!(parent_span.end().is_ok()); + child_span.end(); + parent_span.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 2); @@ -186,17 +240,15 @@ mod tests { fn test_open_telemetry_span_start_with_parent() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); - let span1 = tracer.start_span("span1", SpanKind::Internal).unwrap(); - let span2 = tracer.start_span("span2", SpanKind::Server).unwrap(); - let child_span = tracer - .start_span_with_parent("child_span", SpanKind::Client, span1.clone()) - .unwrap(); - - assert!(child_span.end().is_ok()); - assert!(span2.end().is_ok()); - assert!(span1.end().is_ok()); + let tracer = tracer_provider.get_tracer(Some("MyNamespace"), "test", Some("0.1.0")); + let span1 = tracer.start_span("span1", SpanKind::Internal, vec![]); + let span2 = tracer.start_span("span2", SpanKind::Server, vec![]); + let child_span = + tracer.start_span_with_parent("child_span", SpanKind::Client, vec![], span1.clone()); + + child_span.end(); + span2.end(); + span1.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 3); @@ -217,20 +269,15 @@ mod tests { fn test_open_telemetry_span_start_with_current() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); - let span1 = tracer.start_span("span1", SpanKind::Internal).unwrap(); - let span2 = tracer.start_span("span2", SpanKind::Server).unwrap(); - let _span_guard = span2 - .set_current(&azure_core::http::Context::new()) - .unwrap(); - let child_span = tracer - .start_span_with_current("child_span", SpanKind::Client) - .unwrap(); - - assert!(child_span.end().is_ok()); - assert!(span2.end().is_ok()); - assert!(span1.end().is_ok()); + let tracer = tracer_provider.get_tracer(Some("Namespace"), "test", Some("0.1.0")); + let span1 = tracer.start_span("span1", SpanKind::Internal, vec![]); + let span2 = tracer.start_span("span2", SpanKind::Server, vec![]); + let _span_guard = span2.set_current(&azure_core::http::Context::new()); + let child_span = tracer.start_span("child_span", SpanKind::Client, vec![]); + + child_span.end(); + span2.end(); + span1.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 3); @@ -251,14 +298,11 @@ mod tests { fn test_open_telemetry_span_set_attribute() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); - let span = tracer.start_span("test_span", SpanKind::Internal).unwrap(); + let tracer = tracer_provider.get_tracer(Some("ThisNamespace"), "test", Some("0.1.0")); + let span = tracer.start_span("test_span", SpanKind::Internal, vec![]); - assert!(span - .set_attribute("test_key", AttributeValue::String("test_value".to_string())) - .is_ok()); - assert!(span.end().is_ok()); + span.set_attribute("test_key", AttributeValue::String("test_value".to_string())); + span.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 1); @@ -277,13 +321,12 @@ mod tests { fn test_open_telemetry_span_record_error() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); - let span = tracer.start_span("test_span", SpanKind::Client).unwrap(); + let tracer = tracer_provider.get_tracer(Some("namespace"), "test", Some("0.1.0")); + let span = tracer.start_span("test_span", SpanKind::Client, vec![]); let error = Error::new(ErrorKind::NotFound, "resource not found"); - assert!(span.record_error(&error).is_ok()); - assert!(span.end().is_ok()); + span.record_error(&error); + span.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 1); @@ -304,44 +347,36 @@ mod tests { fn test_open_telemetry_span_set_status() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); + let tracer = tracer_provider.get_tracer(Some("Namespace"), "test", Some("0.1.0")); - // Test Ok status - let span = tracer.start_span("test_span_ok", SpanKind::Server).unwrap(); - assert!(span.set_status(SpanStatus::Ok).is_ok()); - assert!(span.end().is_ok()); + // Test Unset status + let span = tracer.start_span("test_span_unset", SpanKind::Server, vec![]); + span.end(); // Test Error status - let span = tracer - .start_span("test_span_error", SpanKind::Server) - .unwrap(); - assert!(span - .set_status(SpanStatus::Error { - description: "test error".to_string() - }) - .is_ok()); - assert!(span.end().is_ok()); + let span = tracer.start_span("test_span_error", SpanKind::Server, vec![]); + span.set_status(SpanStatus::Error { + description: "test error".to_string(), + }); + span.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 2); - let ok_span = spans.iter().find(|s| s.name == "test_span_ok").unwrap(); - assert_eq!(ok_span.status, opentelemetry::trace::Status::Ok); - let error_span = spans.iter().find(|s| s.name == "test_span_error").unwrap(); assert_eq!( error_span.status, opentelemetry::trace::Status::error("test error") ); + let unset_span = spans.iter().find(|s| s.name == "test_span_unset").unwrap(); + assert_eq!(unset_span.status, opentelemetry::trace::Status::Unset); } #[tokio::test] async fn test_open_telemetry_span_futures() { let (otel_tracer_provider, otel_exporter) = create_exportable_tracer_provider(); let tracer_provider = OpenTelemetryTracerProvider::new(otel_tracer_provider); - assert!(tracer_provider.is_ok()); - let tracer = tracer_provider.unwrap().get_tracer("test", "0.1.0"); + let tracer = tracer_provider.get_tracer(Some("Namespace"), "test", Some("0.1.0")); let future = async { let context = Context::current(); @@ -355,28 +390,36 @@ mod tests { 42 }; - let span = tracer.start_span("test_span", SpanKind::Client).unwrap(); + let span = tracer.start_span( + "test_span", + SpanKind::Client, + vec![Attribute { + key: "test_key".into(), + value: "test_value".into(), + }], + ); let azure_context = AzureContext::new(); let azure_context = azure_context.with_value(span.clone()); - let _guard = span.set_current(&azure_context).unwrap(); + let _guard = span.set_current(&azure_context); let result = future.await; assert_eq!(result, 42); - span.end().unwrap(); + span.end(); let spans = otel_exporter.get_finished_spans().unwrap(); assert_eq!(spans.len(), 1); for span in &spans { + trace!("Span: {:?}", span); assert_eq!(span.name, "test_span"); assert_eq!(span.events.len(), 1); - assert_eq!(span.attributes.len(), 1); + assert_eq!(span.attributes.len(), 2); assert_eq!(span.attributes[0].key, "test_key".into()); assert_eq!( format!("{:?}", span.attributes[0].value), - "String(Static(\"test_value\"))" + "String(Owned(\"test_value\"))" ); } } diff --git a/sdk/core/azure_core_opentelemetry/src/telemetry.rs b/sdk/core/azure_core_opentelemetry/src/telemetry.rs index ea7f1ebf57..55dc10a9f2 100644 --- a/sdk/core/azure_core_opentelemetry/src/telemetry.rs +++ b/sdk/core/azure_core_opentelemetry/src/telemetry.rs @@ -3,38 +3,81 @@ use crate::tracer::OpenTelemetryTracer; use azure_core::tracing::TracerProvider; -use azure_core::Result; use opentelemetry::{ global::{BoxedTracer, ObjectSafeTracerProvider}, InstrumentationScope, }; -use std::sync::Arc; +use std::{fmt::Debug, sync::Arc}; /// Enum to hold different OpenTelemetry tracer provider implementations. pub struct OpenTelemetryTracerProvider { - inner: Arc, + inner: Option>, } impl OpenTelemetryTracerProvider { /// Creates a new Azure telemetry provider with the given SDK tracer provider. - #[allow(dead_code)] - pub fn new(provider: Arc) -> Result { - Ok(Self { inner: provider }) + /// + /// # Arguments + /// - `provider`: An `Arc` to an object-safe tracer provider that implements the + /// `ObjectSafeTracerProvider` trait. + /// + /// # Returns + /// An `Arc` to the newly created `OpenTelemetryTracerProvider`. + /// + /// + pub fn new(provider: Arc) -> Arc { + Arc::new(Self { + inner: Some(provider), + }) + } + + /// Creates a new Azure telemetry provider that uses the global OpenTelemetry tracer provider. + /// + /// This is useful when you want to use the global OpenTelemetry provider without + /// explicitly instantiating a specific provider. + /// + /// # Returns + /// An `Arc` to the newly created `OpenTelemetryTracerProvider` that uses the global provider. + /// + pub fn new_from_global_provider() -> Arc { + Arc::new(Self { inner: None }) + } +} + +impl Debug for OpenTelemetryTracerProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenTelemetryTracerProvider") + .finish_non_exhaustive() } } impl TracerProvider for OpenTelemetryTracerProvider { fn get_tracer( &self, - name: &'static str, - package_version: &'static str, - ) -> Box { - let scope = InstrumentationScope::builder(name) - .with_version(package_version) + namespace: Option<&'static str>, + crate_name: &'static str, + crate_version: Option<&'static str>, + ) -> Arc { + let mut builder = InstrumentationScope::builder(crate_name); + if let Some(crate_version) = crate_version { + builder = builder.with_version(crate_version); + } + let scope = builder + .with_schema_url("https://opentelemetry.io/schemas/1.23.0") .build(); - Box::new(OpenTelemetryTracer::new(BoxedTracer::new( - self.inner.boxed_tracer(scope), - ))) + if let Some(provider) = &self.inner { + // If we have a specific provider set, use it to create the tracer. + Arc::new(OpenTelemetryTracer::new( + namespace, + BoxedTracer::new(provider.boxed_tracer(scope)), + )) + } else { + // Use the global tracer if no specific provider has been set. + Arc::new(OpenTelemetryTracer::new( + namespace, + opentelemetry::global::tracer_with_scope(scope), + )) + } } } @@ -47,14 +90,27 @@ mod tests { #[test] fn test_create_tracer_provider_sdk_tracer() { let provider = Arc::new(SdkTracerProvider::builder().build()); - let tracer_provider = OpenTelemetryTracerProvider::new(provider); - assert!(tracer_provider.is_ok()); + let _tracer_provider = OpenTelemetryTracerProvider::new(provider); } #[test] fn test_create_tracer_provider_noop_tracer() { let provider = Arc::new(NoopTracerProvider::new()); - let tracer_provider = OpenTelemetryTracerProvider::new(provider); - assert!(tracer_provider.is_ok()); + let _tracer_provider = OpenTelemetryTracerProvider::new(provider); + } + + #[test] + fn test_create_tracer_provider_from_global() { + let tracer_provider = OpenTelemetryTracerProvider::new_from_global_provider(); + let _tracer = tracer_provider.get_tracer(Some("My Namespace"), "test", Some("0.1.0")); + } + + #[test] + fn test_create_tracer_provider_from_global_provider_set() { + let provider = SdkTracerProvider::builder().build(); + opentelemetry::global::set_tracer_provider(provider); + + let tracer_provider = OpenTelemetryTracerProvider::new_from_global_provider(); + let _tracer = tracer_provider.get_tracer(Some("My Namespace"), "test", Some("0.1.0")); } } diff --git a/sdk/core/azure_core_opentelemetry/src/tracer.rs b/sdk/core/azure_core_opentelemetry/src/tracer.rs index 9eb9e0fe7f..8b8ba1869d 100644 --- a/sdk/core/azure_core_opentelemetry/src/tracer.rs +++ b/sdk/core/azure_core_opentelemetry/src/tracer.rs @@ -1,80 +1,96 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use crate::span::{OpenTelemetrySpan, OpenTelemetrySpanKind}; -use azure_core::{ - tracing::{SpanKind, Tracer}, - Result, +use crate::{ + attributes::OpenTelemetryAttribute, + span::{OpenTelemetrySpan, OpenTelemetrySpanKind}, }; + +use azure_core::tracing::{SpanKind, Tracer}; use opentelemetry::{ global::BoxedTracer, trace::{TraceContextExt, Tracer as OpenTelemetryTracerTrait}, - Context, + Context, KeyValue, }; -use std::sync::Arc; +use std::{fmt::Debug, sync::Arc}; pub struct OpenTelemetryTracer { + namespace: Option<&'static str>, inner: BoxedTracer, } impl OpenTelemetryTracer { /// Creates a new OpenTelemetry tracer with the given inner tracer. - pub(super) fn new(tracer: BoxedTracer) -> Self { - Self { inner: tracer } + pub(super) fn new(namespace: Option<&'static str>, tracer: BoxedTracer) -> Self { + Self { + namespace, + inner: tracer, + } } } -impl Tracer for OpenTelemetryTracer { - fn start_span( - &self, - name: &'static str, - kind: SpanKind, - ) -> Result> { - let span_builder = opentelemetry::trace::SpanBuilder::from_name(name) - .with_kind(OpenTelemetrySpanKind(kind).into()); - let context = Context::new(); - let span = self.inner.build_with_context(span_builder, &context); +impl Debug for OpenTelemetryTracer { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenTelemetryTracer") + .field("namespace", &self.namespace) + .finish_non_exhaustive() + } +} - Ok(OpenTelemetrySpan::new(context.with_span(span))) +impl Tracer for OpenTelemetryTracer { + fn namespace(&self) -> Option<&'static str> { + self.namespace } - fn start_span_with_current( + fn start_span( &self, name: &'static str, kind: SpanKind, - ) -> Result> { + attributes: Vec, + ) -> Arc { let span_builder = opentelemetry::trace::SpanBuilder::from_name(name) - .with_kind(OpenTelemetrySpanKind(kind).into()); + .with_kind(OpenTelemetrySpanKind(kind).into()) + .with_attributes( + attributes + .iter() + .map(|attr| KeyValue::from(OpenTelemetryAttribute(attr.clone()))), + ); let context = Context::current(); let span = self.inner.build_with_context(span_builder, &context); - Ok(OpenTelemetrySpan::new(context.with_span(span))) + OpenTelemetrySpan::new(context.with_span(span)) } fn start_span_with_parent( &self, name: &'static str, kind: SpanKind, - parent: Arc, - ) -> Result> { - let span_builder = opentelemetry::trace::SpanBuilder::from_name(name) - .with_kind(OpenTelemetrySpanKind(kind).into()); + attributes: Vec, + parent: Arc, + ) -> Arc { + let span_builder = opentelemetry::trace::SpanBuilder::from_name(name.to_owned()) + .with_kind(OpenTelemetrySpanKind(kind).into()) + .with_attributes( + attributes + .iter() + .map(|attr| KeyValue::from(OpenTelemetryAttribute(attr.clone()))), + ); // Cast the parent span to Any type let context = parent .as_any() .downcast_ref::() - .ok_or_else(|| { - azure_core::Error::message( - azure_core::error::ErrorKind::DataConversion, - "Could not downcast parent span to OpenTelemetrySpan", + .unwrap_or_else(|| { + panic!( + "Could not downcast parent span to OpenTelemetrySpan. Actual type: {}", + std::any::type_name::() ) - })? + }) .context() .clone(); let span = self.inner.build_with_context(span_builder, &context); - Ok(OpenTelemetrySpan::new(context.with_span(span))) + OpenTelemetrySpan::new(context.with_span(span)) } } @@ -89,24 +105,24 @@ mod tests { #[test] fn test_create_tracer() { let noop_tracer = NoopTracerProvider::new(); - let otel_provider = OpenTelemetryTracerProvider::new(Arc::new(noop_tracer)).unwrap(); - let tracer = otel_provider.get_tracer("test_tracer", "1.0.0"); - let span = tracer.start_span("test_span", SpanKind::Internal).unwrap(); - assert!(span.end().is_ok()); + let otel_provider = OpenTelemetryTracerProvider::new(Arc::new(noop_tracer)); + let tracer = otel_provider.get_tracer(Some("name"), "test_tracer", Some("1.0.0")); + let span = tracer.start_span("test_span", SpanKind::Internal, vec![]); + span.end(); } #[test] fn test_create_tracer_with_sdk_tracer() { let provider = SdkTracerProvider::builder().build(); - let otel_provider = OpenTelemetryTracerProvider::new(Arc::new(provider)).unwrap(); - let _tracer = otel_provider.get_tracer("test_tracer", "1.0.0"); + let otel_provider = OpenTelemetryTracerProvider::new(Arc::new(provider)); + let _tracer = otel_provider.get_tracer(Some("My.Namespace"), "test_tracer", Some("1.0.0")); } #[test] fn test_create_span_from_tracer() { let provider = SdkTracerProvider::builder().build(); - let otel_provider = OpenTelemetryTracerProvider::new(Arc::new(provider)).unwrap(); - let tracer = otel_provider.get_tracer("test_tracer", "1.0.0"); - let _span = tracer.start_span("test_span", SpanKind::Internal); + let otel_provider = OpenTelemetryTracerProvider::new(Arc::new(provider)); + let tracer = otel_provider.get_tracer(Some("My.Namespace"), "test_tracer", Some("1.0.0")); + let _span = tracer.start_span("test_span", SpanKind::Internal, vec![]); } } diff --git a/sdk/core/azure_core_opentelemetry/tests/integration_test.rs b/sdk/core/azure_core_opentelemetry/tests/otel_span_tests.rs similarity index 82% rename from sdk/core/azure_core_opentelemetry/tests/integration_test.rs rename to sdk/core/azure_core_opentelemetry/tests/otel_span_tests.rs index 65c17b31af..b0b4dc2000 100644 --- a/sdk/core/azure_core_opentelemetry/tests/integration_test.rs +++ b/sdk/core/azure_core_opentelemetry/tests/otel_span_tests.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -use azure_core::tracing::{SpanKind, TracerProvider}; +use azure_core::tracing::{SpanKind, TracerProvider as _}; use azure_core_opentelemetry::OpenTelemetryTracerProvider; use opentelemetry_sdk::trace::SdkTracerProvider; use std::error::Error; @@ -11,26 +11,26 @@ use std::sync::Arc; async fn test_span_creation() -> Result<(), Box> { // Set up a tracer provider for testing let sdk_provider = Arc::new(SdkTracerProvider::builder().build()); - let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider)?; + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); // Get a tracer from the Azure provider - let tracer = azure_provider.get_tracer("test_tracer", "1.0.0"); + let tracer = azure_provider.get_tracer(Some("test_namespace"), "test_tracer", Some("1.0.0")); // Create a span using the Azure tracer - let span = tracer.start_span("test_span", SpanKind::Internal).unwrap(); + let span = tracer.start_span("test_span", SpanKind::Internal, vec![]); // Add attributes to the span using individual set_attribute calls span.set_attribute( "test_key", azure_core::tracing::AttributeValue::String("test_value".to_string()), - )?; + ); span.set_attribute( "service.name", azure_core::tracing::AttributeValue::String("azure-test".to_string()), - )?; + ); // End the span - span.end()?; + span.end(); Ok(()) } @@ -39,12 +39,12 @@ async fn test_span_creation() -> Result<(), Box> { async fn test_tracer_provider_creation() -> Result<(), Box> { // Create multiple tracer provider instances to test initialization let sdk_provider = Arc::new(SdkTracerProvider::builder().build()); - let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider)?; + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); // Get a tracer and verify it works - let tracer = azure_provider.get_tracer("test_tracer", "1.0.0"); - let span = tracer.start_span("test_span", SpanKind::Internal).unwrap(); - span.end()?; + let tracer = azure_provider.get_tracer(Some("test.namespace"), "test_tracer", Some("1.0.0")); + let span = tracer.start_span("test_span", SpanKind::Internal, vec![]); + span.end(); Ok(()) } @@ -53,30 +53,30 @@ async fn test_tracer_provider_creation() -> Result<(), Box> { async fn test_span_attributes() -> Result<(), Box> { // Set up a tracer provider for testing let sdk_provider = Arc::new(SdkTracerProvider::builder().build()); - let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider)?; + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); // Get a tracer from the Azure provider - let tracer = azure_provider.get_tracer("test_tracer", "1.0.0"); + let tracer = azure_provider.get_tracer(Some("test.namespace"), "test_tracer", Some("1.0.0")); // Create span with multiple attributes - let span = tracer.start_span("test_span", SpanKind::Internal).unwrap(); + let span = tracer.start_span("test_span", SpanKind::Internal, vec![]); // Add attributes using individual set_attribute calls span.set_attribute( "service.name", azure_core::tracing::AttributeValue::String("test-service".to_string()), - )?; + ); span.set_attribute( "operation.name", azure_core::tracing::AttributeValue::String("test-operation".to_string()), - )?; + ); span.set_attribute( "request.id", azure_core::tracing::AttributeValue::String("req-123".to_string()), - )?; + ); // End the span - span.end()?; + span.end(); Ok(()) } diff --git a/sdk/core/azure_core_opentelemetry/tests/telemetry_service_implementation.rs b/sdk/core/azure_core_opentelemetry/tests/telemetry_service_implementation.rs new file mode 100644 index 0000000000..1cc0be4ef9 --- /dev/null +++ b/sdk/core/azure_core_opentelemetry/tests/telemetry_service_implementation.rs @@ -0,0 +1,516 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// cspell: ignore azuresdkforcpp azurewebsites + +//! This file contains an Azure SDK for Rust fake service client API. +//! +use azure_core::{ + credentials::TokenCredential, + fmt::SafeDebug, + http::{ + ClientMethodOptions, ClientOptions, Pipeline, RawResponse, Request, + RequestInstrumentationOptions, Url, + }, + tracing::{PublicApiInstrumentationInformation, Tracer}, + Result, +}; +use azure_core_opentelemetry::OpenTelemetryTracerProvider; +use opentelemetry_sdk::trace::{InMemorySpanExporter, SdkTracerProvider}; +use std::sync::Arc; + +#[derive(Clone, SafeDebug)] +pub struct TestServiceClientOptions { + pub azure_client_options: ClientOptions, + pub api_version: Option, +} + +impl Default for TestServiceClientOptions { + fn default() -> Self { + Self { + azure_client_options: ClientOptions::default(), + api_version: Some("2023-10-01".to_string()), + } + } +} + +pub struct TestServiceClient { + endpoint: Url, + api_version: String, + pipeline: Pipeline, + tracer: Option>, +} + +#[derive(Default, SafeDebug)] +pub struct TestServiceClientGetMethodOptions<'a> { + pub method_options: ClientMethodOptions<'a>, +} + +impl TestServiceClient { + pub fn new( + endpoint: &str, + _credential: Arc, + options: Option, + ) -> Result { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + + let tracer = + if let Some(tracer_options) = &options.azure_client_options.request_instrumentation { + tracer_options + .tracer_provider + .as_ref() + .map(|tracer_provider| { + tracer_provider.get_tracer( + Some("Az.TestServiceClient"), + option_env!("CARGO_PKG_NAME").unwrap_or("UNKNOWN"), + option_env!("CARGO_PKG_VERSION"), + ) + }) + } else { + None + }; + + Ok(Self { + endpoint, + api_version: options.api_version.unwrap_or_default(), + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.azure_client_options, + Vec::default(), + Vec::default(), + ), + tracer, + }) + } + + /// Returns the Url associated with this client. + pub fn endpoint(&self) -> &Url { + &self.endpoint + } + + /// Returns the result of a Get verb against the configured endpoint with the specified path. + /// + /// This method demonstrates a service client which does not have per-method spans but which will create + /// HTTP client spans if the `RequestInstrumentationOptions` are configured in the client options. + /// + pub async fn get( + &self, + path: &str, + options: Option>, + ) -> Result { + let options = options.unwrap_or_default(); + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + + let mut request = Request::new(url, azure_core::http::Method::Get); + + let response = self + .pipeline + .send(&options.method_options.context, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()), + )); + } + Ok(response) + } + + /// Returns the result of a Get verb against the configured endpoint with the specified path. + /// + /// This method demonstrates a service client which has per-method spans and uses the configured tracing + /// tracing provider to create per-api spans for the function. + /// + /// To configure per-api spans, your service implementation needs to do the following: + /// 1. If the client is configured with a [`Tracer`], it will create a span whose name matches the function. + /// 1. The span should be created with the `SpanKind::Internal` kind, and + /// 2. The span should have the `az.namespace` attribute set to the namespace of the service client. + /// 2. The function should add the span created in step 1 to the ClientMethodOptions context. + /// 3. The function should add the tracer to the ClientMethodOptions context so that the pipeline can use it to populate the `az.namespace` property in the request span. + /// 4. The function should then perform the normal client operations after setting up the context. + /// 5. After the client operation completes, if the function failed, it should add an `error.type` attribute to the span + /// with the error type. + /// + /// # Note + /// This applies to most HTTP client operations, but not all. CosmosDB has its own set of conventions as listed + /// [here](https://github.com/open-telemetry/semantic-conventions/blob/main/docs/database/cosmosdb.md) + /// + pub async fn get_with_function_tracing( + &self, + path: &str, + options: Option>, + ) -> Result { + let mut options = options.unwrap_or_default(); + + let public_api_info = PublicApiInstrumentationInformation { + api_name: "get_with_tracing", + attributes: vec![], + }; + // Add the span to the tracer. + let mut ctx = options.method_options.context.with_value(public_api_info); + // If the service has a tracer, we add it to the context. + if let Some(tracer) = &self.tracer { + ctx = ctx.with_value(tracer.clone()); + } + options.method_options.context = ctx; + self.get(path, Some(options)).await + } +} + +use azure_core_test::{recorded, TestContext}; +use opentelemetry::trace::{SpanKind as OpenTelemetrySpanKind, Status as OpenTelemetrySpanStatus}; +use opentelemetry::Value as OpenTelemetryAttributeValue; +use tracing::{info, trace}; + +fn create_exportable_tracer_provider() -> (Arc, InMemorySpanExporter) { + let otel_exporter = InMemorySpanExporter::default(); + let otel_tracer_provider = SdkTracerProvider::builder() + .with_simple_exporter(otel_exporter.clone()) + .build(); + let otel_tracer_provider = Arc::new(otel_tracer_provider); + (otel_tracer_provider, otel_exporter) +} + +// Span verification utility functions. + +struct ExpectedSpan { + name: &'static str, + kind: OpenTelemetrySpanKind, + parent_span_id: Option, + status: OpenTelemetrySpanStatus, + attributes: Vec<(&'static str, OpenTelemetryAttributeValue)>, +} + +fn verify_span(span: &opentelemetry_sdk::trace::SpanData, expected: ExpectedSpan) -> Result<()> { + assert_eq!(span.name, expected.name); + assert_eq!(span.span_kind, expected.kind); + assert_eq!(span.status, expected.status); + assert_eq!( + span.parent_span_id, + expected + .parent_span_id + .unwrap_or(opentelemetry::trace::SpanId::INVALID) + ); + + for attr in span.attributes.iter() { + println!("Attribute: {} = {:?}", attr.key, attr.value); + let mut found = false; + for (key, value) in expected.attributes.iter() { + if attr.key.as_str() == (*key) { + found = true; + // Skip checking the value for "" as it is a placeholder + if *value != OpenTelemetryAttributeValue::String("".into()) { + assert_eq!(attr.value, *value, "Attribute mismatch for key: {}", *key); + } + break; + } + } + if !found { + panic!("Unexpected attribute: {} = {:?}", attr.key, attr.value); + } + } + for (key, value) in expected.attributes.iter() { + if !span.attributes.iter().any(|attr| attr.key == (*key).into()) { + panic!("Expected attribute not found: {} = {:?}", key, value); + } + } + + Ok(()) +} + +// Basic functionality tests. +#[recorded::test()] +async fn test_service_client_new(ctx: TestContext) -> Result<()> { + let recording = ctx.recording(); + let endpoint = "https://www.microsoft.com"; + let credential = recording.credential().clone(); + let options = TestServiceClientOptions { + ..Default::default() + }; + + let client = TestServiceClient::new(endpoint, credential, Some(options)).unwrap(); + assert_eq!(client.endpoint().as_str(), "https://www.microsoft.com/"); + assert_eq!(client.api_version, "2023-10-01"); + + Ok(()) +} + +// Ensure that the the test client actually does what it's supposed to do without telemetry. +#[recorded::test()] +async fn test_service_client_get(ctx: TestContext) -> Result<()> { + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + + let client = TestServiceClient::new(endpoint, credential, None).unwrap(); + let response = client.get("get", None).await; + info!("Response: {:?}", response); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.status(), azure_core::http::StatusCode::Ok); + Ok(()) +} + +#[recorded::test()] +async fn test_service_client_get_with_tracing(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + let options = TestServiceClientOptions { + azure_client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + + let client = TestServiceClient::new(endpoint, credential, Some(options)).unwrap(); + let response = client.get("get", None).await; + info!("Response: {:?}", response); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.status(), azure_core::http::StatusCode::Ok); + + let spans = otel_exporter.get_finished_spans().unwrap(); + for (i, span) in spans.iter().enumerate() { + trace!("Span {i}: {span:?}"); + } + assert_eq!(spans.len(), 1); + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + status: OpenTelemetrySpanStatus::Unset, + parent_span_id: None, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!("{}{}", client.endpoint(), "get?api-version=2023-10-01").into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("http.response.status_code", 200.into()), + ], + }, + )?; + + Ok(()) +} + +#[recorded::test()] +async fn test_service_client_get_tracing_error(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + let options = TestServiceClientOptions { + azure_client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + + let client = TestServiceClient::new(endpoint, credential, Some(options)).unwrap(); + let response = client.get("failing_url", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + for (i, span) in spans.iter().enumerate() { + trace!("Span {i}: {span:?}"); + } + assert_eq!(spans.len(), 1); + + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Error { + description: "".into(), + }, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("error.type", "404".into()), + ("http.response.status_code", 404.into()), + ], + }, + )?; + + Ok(()) +} + +#[recorded::test()] +async fn test_service_client_get_with_function_tracing(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + let options = TestServiceClientOptions { + azure_client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + + let client = TestServiceClient::new(endpoint, credential, Some(options)).unwrap(); + let response = client.get_with_function_tracing("get", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + for (i, span) in spans.iter().enumerate() { + trace!("Span {i}: {span:?}"); + } + assert_eq!(spans.len(), 2); + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[1].span_context.span_id()), + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!("{}{}", client.endpoint(), "get?api-version=2023-10-01").into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("http.response.status_code", 200.into()), + ], + }, + )?; + verify_span( + &spans[1], + ExpectedSpan { + name: "get_with_tracing", + kind: OpenTelemetrySpanKind::Internal, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![("az.namespace", "Az.TestServiceClient".into())], + }, + )?; + + Ok(()) +} + +#[recorded::test()] +async fn test_service_client_get_with_function_tracing_error(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + let options = TestServiceClientOptions { + azure_client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + + let client = TestServiceClient::new(endpoint, credential, Some(options)).unwrap(); + let response = client.get_with_function_tracing("failing_url", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + for (i, span) in spans.iter().enumerate() { + trace!("Span {i}: {span:?}"); + } + assert_eq!(spans.len(), 2); + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[1].span_context.span_id()), + status: OpenTelemetrySpanStatus::Error { + description: "".into(), + }, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("http.response.status_code", 404.into()), + ("error.type", "404".into()), + ], + }, + )?; + verify_span( + &spans[1], + ExpectedSpan { + name: "get_with_tracing", + kind: OpenTelemetrySpanKind::Internal, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("az.namespace", "Az.TestServiceClient".into()), + ("error.type", "404".into()), + ], + }, + )?; + + Ok(()) +} diff --git a/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs b/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs new file mode 100644 index 0000000000..1561fd5e8c --- /dev/null +++ b/sdk/core/azure_core_opentelemetry/tests/telemetry_service_macros.rs @@ -0,0 +1,678 @@ +// Copyright (C) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// cspell: ignore azuresdkforcpp invalidtopleveldomain azurewebsites +//! This file contains an Azure SDK for Rust fake service client API. +//! +use azure_core::{ + credentials::TokenCredential, + fmt::SafeDebug, + http::{ + ClientMethodOptions, ClientOptions, Pipeline, RawResponse, Request, + RequestInstrumentationOptions, Url, + }, + tracing, Result, +}; +use azure_core_opentelemetry::OpenTelemetryTracerProvider; +use opentelemetry_sdk::trace::{InMemorySpanExporter, SdkTracerProvider}; +use std::sync::Arc; + +#[derive(Clone, SafeDebug)] +pub struct TestServiceClientWithMacrosOptions { + pub client_options: ClientOptions, + pub api_version: Option, +} + +impl Default for TestServiceClientWithMacrosOptions { + fn default() -> Self { + Self { + client_options: ClientOptions::default(), + api_version: Some("2023-10-01".to_string()), + } + } +} + +/// Define a TestServiceClient which is a fake service client for testing purposes. +/// This client demonstrates how to implement a service client using the tracing convenience proc macros. +#[tracing::client] +pub struct TestServiceClientWithMacros { + endpoint: Url, + api_version: String, + pipeline: Pipeline, +} + +#[derive(Default, SafeDebug)] +pub struct TestServiceClientWithMacrosGetMethodOptions<'a> { + pub method_options: ClientMethodOptions<'a>, +} + +impl TestServiceClientWithMacros { + /// Creates a new instance of the TestServiceClient. + /// + /// This function demonstrates how to create a service client using the tracing convenience proc macros. + /// + /// # Arguments + /// * `endpoint` - The endpoint URL for the service. + /// * `_credential` - The credential used for authentication (not used in this example). + /// * `options` - Optional client options to configure the client. + /// + #[tracing::new("Az.TestServiceClient")] + pub fn new( + endpoint: &str, + _credential: Arc, + options: Option, + ) -> Result { + let options = options.unwrap_or_default(); + let mut endpoint = Url::parse(endpoint)?; + if !endpoint.scheme().starts_with("http") { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::Other, + format!("{endpoint} must use http(s)"), + )); + } + endpoint.set_query(None); + + Ok(Self { + endpoint, + api_version: options.api_version.unwrap_or_default(), + pipeline: Pipeline::new( + option_env!("CARGO_PKG_NAME"), + option_env!("CARGO_PKG_VERSION"), + options.client_options, + Vec::default(), + Vec::default(), + ), + }) + } + + /// Returns the Url associated with this client. + pub fn endpoint(&self) -> &Url { + &self.endpoint + } + + pub async fn get( + &self, + path: &str, + options: Option>, + ) -> Result { + let options = options.unwrap_or_default(); + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + + let mut request = Request::new(url, azure_core::http::Method::Get); + + let response = self + .pipeline + .send(&options.method_options.context, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()), + )); + } + Ok(response) + } + + /// Returns the result of a Get verb against the configured endpoint with the specified path. + /// + /// This method demonstrates a service client which has per-method spans and uses the configured tracing + /// tracing provider to create per-api spans for the function. + /// + /// To configure per-api spans, your service implementation needs to do the following: + /// 1. If the client is configured with a [`Tracer`], it will create a span whose name matches the function. + /// 1. The span should be created with the `SpanKind::Internal` kind, and + /// 2. The span should have the `az.namespace` attribute set to the namespace of the service client. + /// 2. The function should add the span created in step 1 to the ClientMethodOptions context. + /// 3. The function should add the tracer to the ClientMethodOptions context so that the pipeline can use it to populate the `az.namespace` property in the request span. + /// 4. The function should then perform the normal client operations after setting up the context. + /// 5. After the client operation completes, if the function failed, it should add an `error.type` attribute to the span + /// with the error type. + /// + /// # Note + /// This applies to most HTTP client operations, but not all. CosmosDB has its own set of conventions as listed + /// [here](https://github.com/open-telemetry/semantic-conventions/blob/main/docs/database/cosmosdb.md) + /// + #[tracing::function("macros_get_with_tracing",(a.b=1,az.telemetry="Abc","string attribute"=path))] + pub async fn get_with_function_tracing( + &self, + path: &str, + options: Option>, + ) -> Result { + let options = options.unwrap_or_default(); + + let mut url = self.endpoint.clone(); + url.set_path(path); + url.query_pairs_mut() + .append_pair("api-version", &self.api_version); + + let mut request = Request::new(url, azure_core::http::Method::Get); + + let response = self + .pipeline + .send(&options.method_options.context, &mut request) + .await?; + if !response.status().is_success() { + return Err(azure_core::Error::message( + azure_core::error::ErrorKind::HttpResponse { + status: response.status(), + error_code: None, + }, + format!("Failed to GET {}: {}", request.url(), response.status()), + )); + } + Ok(response) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ::tracing::{info, trace}; + use azure_core::http::{ExponentialRetryOptions, RetryOptions}; + use azure_core::tracing::TracerProvider; + use azure_core::Result; + use azure_core_test::{recorded, TestContext}; + use opentelemetry::trace::{ + SpanKind as OpenTelemetrySpanKind, Status as OpenTelemetrySpanStatus, + }; + use opentelemetry::Value as OpenTelemetryAttributeValue; + + fn create_exportable_tracer_provider() -> (Arc, InMemorySpanExporter) { + let otel_exporter = InMemorySpanExporter::default(); + let otel_tracer_provider = SdkTracerProvider::builder() + .with_simple_exporter(otel_exporter.clone()) + .build(); + let otel_tracer_provider = Arc::new(otel_tracer_provider); + (otel_tracer_provider, otel_exporter) + } + + fn create_service_client( + ctx: TestContext, + azure_provider: Arc, + ) -> TestServiceClientWithMacros { + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + let options = TestServiceClientWithMacrosOptions { + client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + + TestServiceClientWithMacros::new(endpoint, credential, Some(options)).unwrap() + } + + // Span verification utility functions. + + struct ExpectedSpan { + name: &'static str, + kind: OpenTelemetrySpanKind, + parent_span_id: Option, + status: OpenTelemetrySpanStatus, + attributes: Vec<(&'static str, OpenTelemetryAttributeValue)>, + } + + fn verify_span( + span: &opentelemetry_sdk::trace::SpanData, + expected: ExpectedSpan, + ) -> Result<()> { + assert_eq!(span.name, expected.name); + assert_eq!(span.span_kind, expected.kind); + assert_eq!(span.status, expected.status); + assert_eq!( + span.parent_span_id, + expected + .parent_span_id + .unwrap_or(opentelemetry::trace::SpanId::INVALID) + ); + + for attr in span.attributes.iter() { + println!("Attribute: {} = {:?}", attr.key, attr.value); + let mut found = false; + for (key, value) in expected.attributes.iter() { + if attr.key.as_str() == (*key) { + found = true; + // Skip checking the value for "" as it is a placeholder + if *value != OpenTelemetryAttributeValue::String("".into()) { + assert_eq!(attr.value, *value, "Attribute mismatch for key: {}", *key); + } + break; + } + } + if !found { + panic!("Unexpected attribute: {} = {:?}", attr.key, attr.value); + } + } + for (key, value) in expected.attributes.iter() { + if !span.attributes.iter().any(|attr| attr.key == (*key).into()) { + panic!("Expected attribute not found: {} = {:?}", key, value); + } + } + + Ok(()) + } + + // Basic functionality tests. + #[recorded::test()] + async fn test_macro_service_client_new(ctx: TestContext) -> Result<()> { + let recording = ctx.recording(); + let endpoint = "https://microsoft.com"; + let credential = recording.credential().clone(); + let options = TestServiceClientWithMacrosOptions { + ..Default::default() + }; + + let client = TestServiceClientWithMacros::new(endpoint, credential, Some(options)).unwrap(); + assert_eq!(client.endpoint().as_str(), "https://microsoft.com/"); + assert_eq!(client.api_version, "2023-10-01"); + + Ok(()) + } + + #[recorded::test()] + async fn test_macro_service_client_get(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let client = create_service_client(ctx, azure_provider.clone()); + + let response = client.get("get", None).await; + info!("Response: {:?}", response); + assert!(response.is_ok()); + let response = response.unwrap(); + assert_eq!(response.status(), azure_core::http::StatusCode::Ok); + + let spans = otel_exporter.get_finished_spans().unwrap(); + assert_eq!(spans.len(), 1); + for span in &spans { + trace!("Span: {:?}", span); + + verify_span( + span, + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + status: OpenTelemetrySpanStatus::Unset, + parent_span_id: None, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!("{}{}", client.endpoint(), "get?api-version=2023-10-01").into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("http.response.status_code", 200.into()), + ], + }, + )?; + } + + Ok(()) + } + + #[recorded::test()] + async fn test_macro_service_client_get_with_error(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let client = create_service_client(ctx, azure_provider.clone()); + + let response = client.get("failing_url", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + assert_eq!(spans.len(), 1); + for span in &spans { + trace!("Span: {:?}", span); + + verify_span( + span, + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Error { + description: "".into(), + }, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("error.type", "404".into()), + ("http.response.status_code", 404.into()), + ], + }, + )?; + } + + Ok(()) + } + + #[recorded::test()] + async fn test_macro_service_client_get_with_function_tracing(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let client = create_service_client(ctx, azure_provider.clone()); + + let response = client.get_with_function_tracing("get", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + assert_eq!(spans.len(), 2); + for span in &spans { + trace!("Span: {:?}", span); + } + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[1].span_context.span_id()), + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!("{}{}", client.endpoint(), "get?api-version=2023-10-01").into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("http.response.status_code", 200.into()), + ], + }, + )?; + verify_span( + &spans[1], + ExpectedSpan { + name: "macros_get_with_tracing", + kind: OpenTelemetrySpanKind::Internal, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("az.namespace", "Az.TestServiceClient".into()), + ("a.b", 1.into()), // added by tracing macro. + ("az.telemetry", "Abc".into()), // added by tracing macro + ("string attribute", "get".into()), // added by tracing macro. + ], + }, + )?; + + Ok(()) + } + + #[recorded::test()] + async fn test_macro_service_client_get_function_tracing_error(ctx: TestContext) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.net"; + let credential = recording.credential().clone(); + let options = TestServiceClientWithMacrosOptions { + client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + ..Default::default() + }, + ..Default::default() + }; + + let client = TestServiceClientWithMacros::new(endpoint, credential, Some(options)).unwrap(); + let response = client.get_with_function_tracing("failing_url", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + assert_eq!(spans.len(), 2); + for span in &spans { + trace!("Span: {:?}", span); + } + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[1].span_context.span_id()), + status: OpenTelemetrySpanStatus::Error { + description: "".into(), + }, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ("server.address", "azuresdkforcpp.azurewebsites.net".into()), + ("server.port", 443.into()), + ("http.response.status_code", 404.into()), + ("error.type", "404".into()), + ], + }, + )?; + verify_span( + &spans[1], + ExpectedSpan { + name: "macros_get_with_tracing", + kind: OpenTelemetrySpanKind::Internal, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("az.namespace", "Az.TestServiceClient".into()), + ("error.type", "404".into()), + ("a.b", 1.into()), // added by tracing macro. + ("az.telemetry", "Abc".into()), // added by tracing macro + ("string attribute", "failing_url".into()), // added by tracing macro. + ], + }, + )?; + + Ok(()) + } + + #[recorded::test()] + async fn test_macro_service_client_get_function_tracing_dns_error( + ctx: TestContext, + ) -> Result<()> { + let (sdk_provider, otel_exporter) = create_exportable_tracer_provider(); + let azure_provider = OpenTelemetryTracerProvider::new(sdk_provider); + + let recording = ctx.recording(); + let endpoint = "https://azuresdkforcpp.azurewebsites.invalidtopleveldomain"; + let credential = recording.credential().clone(); + let options = TestServiceClientWithMacrosOptions { + client_options: ClientOptions { + request_instrumentation: Some(RequestInstrumentationOptions { + tracer_provider: Some(azure_provider), + }), + retry: Some(RetryOptions::exponential(ExponentialRetryOptions { + max_retries: 3, + ..Default::default() + })), + ..Default::default() + }, + ..Default::default() + }; + + let client = TestServiceClientWithMacros::new(endpoint, credential, Some(options)).unwrap(); + let response = client.get_with_function_tracing("failing_url", None).await; + info!("Response: {:?}", response); + + let spans = otel_exporter.get_finished_spans().unwrap(); + assert_eq!(spans.len(), 5); + for span in &spans { + trace!("Span: {:?}", span); + } + verify_span( + &spans[0], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[4].span_context.span_id()), + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ( + "server.address", + "azuresdkforcpp.azurewebsites.invalidtopleveldomain".into(), + ), + ("server.port", 443.into()), + ("error.type", "Io".into()), + ], + }, + )?; + verify_span( + &spans[1], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[4].span_context.span_id()), + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ( + "server.address", + "azuresdkforcpp.azurewebsites.invalidtopleveldomain".into(), + ), + ("server.port", 443.into()), + ("http.request.resend_count", 1.into()), + ("error.type", "Io".into()), + ], + }, + )?; + verify_span( + &spans[2], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[4].span_context.span_id()), + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ( + "server.address", + "azuresdkforcpp.azurewebsites.invalidtopleveldomain".into(), + ), + ("server.port", 443.into()), + ("http.request.resend_count", 2.into()), + ("error.type", "Io".into()), + ], + }, + )?; + verify_span( + &spans[3], + ExpectedSpan { + name: "GET", + kind: OpenTelemetrySpanKind::Client, + parent_span_id: Some(spans[4].span_context.span_id()), + status: OpenTelemetrySpanStatus::Unset, + attributes: vec![ + ("http.request.method", "GET".into()), + ("az.namespace", "Az.TestServiceClient".into()), + ("az.client_request_id", "".into()), + ( + "url.full", + format!( + "{}{}", + client.endpoint(), + "failing_url?api-version=2023-10-01" + ) + .into(), + ), + ( + "server.address", + "azuresdkforcpp.azurewebsites.invalidtopleveldomain".into(), + ), + ("server.port", 443.into()), + ("http.request.resend_count", 3.into()), + ("error.type", "Io".into()), + ], + }, + )?; + + verify_span( + &spans[4], + ExpectedSpan { + name: "macros_get_with_tracing", + kind: OpenTelemetrySpanKind::Internal, + parent_span_id: None, + status: OpenTelemetrySpanStatus::Error { + description: "Io".into(), + }, + attributes: vec![ + ("az.namespace", "Az.TestServiceClient".into()), + ("error.type", "Io".into()), + ("a.b", 1.into()), // added by tracing macro. + ("az.telemetry", "Abc".into()), // added by tracing macro + ("string attribute", "failing_url".into()), // added by tracing macro. + ], + }, + )?; + + Ok(()) + } +} diff --git a/sdk/core/azure_core_test/src/lib.rs b/sdk/core/azure_core_test/src/lib.rs index 2b7ac5f2da..9f86b4b83b 100644 --- a/sdk/core/azure_core_test/src/lib.rs +++ b/sdk/core/azure_core_test/src/lib.rs @@ -11,6 +11,7 @@ mod recording; #[cfg(doctest)] mod root_readme; pub mod stream; +pub mod tracing; use azure_core::Error; pub use azure_core::{error::ErrorKind, test::TestMode}; @@ -177,7 +178,7 @@ impl TestContext { /// * `cargo_dir` - The directory of the Cargo package, typically the value of the `CARGO_MANIFEST_DIR` environment variable. pub fn load_dotenv_file(cargo_dir: impl AsRef) -> azure_core::Result<()> { if let Ok(path) = find_ancestor_file(cargo_dir, ".env") { - tracing::debug!("loading environment variables from {}", path.display()); + ::tracing::debug!("loading environment variables from {}", path.display()); use azure_core::error::ResultExt as _; dotenvy::from_filename(&path).with_context(azure_core::error::ErrorKind::Io, || { diff --git a/sdk/core/azure_core_test/src/tracing.rs b/sdk/core/azure_core_test/src/tracing.rs new file mode 100644 index 0000000000..f2cfca01e6 --- /dev/null +++ b/sdk/core/azure_core_test/src/tracing.rs @@ -0,0 +1,255 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// cspell: ignore traceparent +use std::sync::{Arc, Mutex}; +use tracing::trace; +use typespec_client_core::{ + http::{headers::HeaderName, Context, Request}, + tracing::{ + AsAny, Attribute, AttributeValue, Span, SpanKind, SpanStatus, Tracer, TracerProvider, + }, +}; + +#[derive(Debug)] +pub struct MockTracingProvider { + tracers: Mutex>>, +} + +impl MockTracingProvider { + pub fn new() -> Self { + Self { + tracers: Mutex::new(Vec::new()), + } + } +} + +impl Default for MockTracingProvider { + fn default() -> Self { + Self::new() + } +} + +impl TracerProvider for MockTracingProvider { + fn get_tracer( + &self, + azure_namespace: Option<&'static str>, + crate_name: &'static str, + crate_version: Option<&'static str>, + ) -> Arc { + let mut tracers = self.tracers.lock().unwrap(); + let tracer = Arc::new(MockTracer { + namespace: azure_namespace, + package_name: crate_name, + package_version: crate_version, + spans: Mutex::new(Vec::new()), + }); + + tracers.push(tracer.clone()); + tracer + } +} + +#[derive(Debug)] +pub struct MockTracer { + pub namespace: Option<&'static str>, + pub package_name: &'static str, + pub package_version: Option<&'static str>, + pub spans: Mutex>>, +} + +impl Tracer for MockTracer { + fn namespace(&self) -> Option<&'static str> { + self.namespace + } + + fn start_span_with_parent( + &self, + name: &str, + kind: SpanKind, + attributes: Vec, + _parent: Arc, + ) -> Arc { + let span = Arc::new(MockSpan::new(name, kind, attributes.clone())); + self.spans.lock().unwrap().push(span.clone()); + span + } + + fn start_span( + &self, + name: &'static str, + kind: SpanKind, + attributes: Vec, + ) -> Arc { + let attributes = attributes + .into_iter() + .map(|attr| Attribute { + key: attr.key.clone(), + value: attr.value.clone(), + }) + .collect(); + let span = Arc::new(MockSpan::new(name, kind, attributes)); + self.spans.lock().unwrap().push(span.clone()); + span + } +} + +#[derive(Debug)] +pub struct MockSpan { + pub name: String, + pub kind: SpanKind, + pub attributes: Mutex>, + pub state: Mutex, + pub is_open: Mutex, +} +impl MockSpan { + fn new(name: &str, kind: SpanKind, attributes: Vec) -> Self { + println!("Creating MockSpan: {}", name); + println!("Attributes: {:?}", attributes); + println!("Converted attributes: {:?}", attributes); + Self { + name: name.to_string(), + kind, + attributes: Mutex::new(attributes), + state: Mutex::new(SpanStatus::Unset), + is_open: Mutex::new(true), + } + } +} + +impl Span for MockSpan { + fn set_attribute(&self, key: &'static str, value: AttributeValue) { + println!("{}: Setting attribute {}: {:?}", self.name, key, value); + let mut attributes = self.attributes.lock().unwrap(); + attributes.push(Attribute { + key: key.into(), + value, + }); + } + + fn set_status(&self, status: crate::tracing::SpanStatus) { + println!("{}: Setting span status: {:?}", self.name, status); + let mut state = self.state.lock().unwrap(); + *state = status; + } + + fn end(&self) { + println!("Ending span: {}", self.name); + let mut is_open = self.is_open.lock().unwrap(); + *is_open = false; + } + + fn is_recording(&self) -> bool { + true + } + + fn span_id(&self) -> [u8; 8] { + [0; 8] // Mock span ID + } + + fn record_error(&self, _error: &dyn std::error::Error) { + todo!() + } + + fn set_current(&self, _context: &Context) -> Box { + todo!() + } + + /// Insert two dummy headers for distributed tracing. + // cspell: ignore traceparent tracestate + fn propagate_headers(&self, request: &mut Request) { + request.insert_header( + HeaderName::from_static("traceparent"), + "00---01", + ); + request.insert_header(HeaderName::from_static("tracestate"), "="); + } +} + +impl AsAny for MockSpan { + fn as_any(&self) -> &dyn std::any::Any { + // Convert to an object that doesn't expose the lifetime parameter + // We're essentially erasing the lifetime here to satisfy the static requirement + self as &dyn std::any::Any + } +} + +#[derive(Debug)] +pub struct ExpectedTracerInformation<'a> { + pub name: &'a str, + pub version: Option<&'a str>, + pub namespace: Option<&'a str>, + pub spans: Vec>, +} + +#[derive(Debug)] +pub struct ExpectedSpanInformation<'a> { + pub span_name: &'a str, + pub status: SpanStatus, + pub kind: SpanKind, + pub attributes: Vec<(&'a str, AttributeValue)>, +} + +pub fn check_instrumentation_result( + mock_tracer: Arc, + expected_tracers: Vec>, +) { + assert_eq!( + mock_tracer.tracers.lock().unwrap().len(), + expected_tracers.len(), + "Unexpected number of tracers", + ); + let tracers = mock_tracer.tracers.lock().unwrap(); + for (index, expected) in expected_tracers.iter().enumerate() { + trace!("Checking tracer {}: {}", index, expected.name); + let tracer = &tracers[index]; + assert_eq!(tracer.package_name, expected.name); + assert_eq!(tracer.package_version, expected.version); + assert_eq!(tracer.namespace, expected.namespace); + + let spans = tracer.spans.lock().unwrap(); + assert_eq!( + spans.len(), + expected.spans.len(), + "Unexpected number of spans for tracer {}", + expected.name + ); + + for (span_index, span_expected) in expected.spans.iter().enumerate() { + println!( + "Checking span {} of tracer {}: {}", + span_index, expected.name, span_expected.span_name + ); + check_span_information(&spans[span_index], span_expected); + } + } +} + +fn check_span_information(span: &Arc, expected: &ExpectedSpanInformation<'_>) { + assert_eq!(span.name, expected.span_name); + assert_eq!(span.kind, expected.kind); + assert_eq!(*span.state.lock().unwrap(), expected.status); + let attributes = span.attributes.lock().unwrap(); + for (index, attr) in attributes.iter().enumerate() { + println!("Attribute {}: {} = {:?}", index, attr.key, attr.value); + let mut found = false; + for (key, value) in &expected.attributes { + if attr.key == *key { + assert_eq!(attr.value, *value, "Attribute mismatch for key: {}", key); + found = true; + break; + } + } + if !found { + panic!("Unexpected attribute: {} = {:?}", attr.key, attr.value); + } + } + for (key, value) in expected.attributes.iter() { + if !attributes + .iter() + .any(|attr| attr.key == *key && attr.value == *value) + { + panic!("Expected attribute not found: {} = {:?}", key, value); + } + } +} diff --git a/sdk/typespec/typespec_client_core/src/http/method.rs b/sdk/typespec/typespec_client_core/src/http/method.rs index 3c185d66af..30f4f362a3 100644 --- a/sdk/typespec/typespec_client_core/src/http/method.rs +++ b/sdk/typespec/typespec_client_core/src/http/method.rs @@ -110,6 +110,18 @@ impl Method { pub fn is_safe(&self) -> bool { matches!(self, Method::Get | Method::Head) } + + /// Returns the HTTP method as a static string slice. + pub const fn as_str(&self) -> &'static str { + match self { + Method::Delete => "DELETE", + Method::Get => "GET", + Method::Head => "HEAD", + Method::Patch => "PATCH", + Method::Post => "POST", + Method::Put => "PUT", + } + } } #[cfg(any(feature = "json", feature = "xml"))] @@ -194,7 +206,7 @@ impl<'a> std::convert::TryFrom<&'a str> for Method { } impl AsRef for Method { - fn as_ref(&self) -> &str { + fn as_ref(&self) -> &'static str { match self { Self::Delete => "DELETE", Self::Get => "GET", diff --git a/sdk/typespec/typespec_client_core/src/http/pipeline.rs b/sdk/typespec/typespec_client_core/src/http/pipeline.rs index d132ff7ff7..75017fdd57 100644 --- a/sdk/typespec/typespec_client_core/src/http/pipeline.rs +++ b/sdk/typespec/typespec_client_core/src/http/pipeline.rs @@ -3,7 +3,7 @@ use crate::http::{ policies::{CustomHeadersPolicy, Policy, TransportPolicy}, - ClientOptions, Context, RawResponse, Request, RetryOptions, + ClientOptions, Context, RawResponse, Request, }; use std::sync::Arc; @@ -49,8 +49,7 @@ impl Pipeline { pipeline.extend_from_slice(&per_call_policies); pipeline.extend_from_slice(&options.per_call_policies); - // TODO: Consider whether this should be initially customizable as we onboard more services. - let retry_policy = RetryOptions::default().to_policy(); + let retry_policy = options.retry.unwrap_or_default().to_policy(); pipeline.push(retry_policy); pipeline.push(Arc::new(CustomHeadersPolicy::default())); diff --git a/sdk/typespec/typespec_client_core/src/http/policies/retry/mod.rs b/sdk/typespec/typespec_client_core/src/http/policies/retry/mod.rs index b4ea0e5098..f50a14fdf0 100644 --- a/sdk/typespec/typespec_client_core/src/http/policies/retry/mod.rs +++ b/sdk/typespec/typespec_client_core/src/http/policies/retry/mod.rs @@ -20,7 +20,7 @@ use crate::{ time::{self, Duration, OffsetDateTime}, }; use async_trait::async_trait; -use std::sync::Arc; +use std::{ops::Deref, sync::Arc}; use tracing::{debug, trace}; use typespec::error::{Error, ErrorKind, ResultExt}; @@ -68,6 +68,19 @@ pub fn get_retry_after(headers: &Headers, now: DateTimeFn) -> Option { }) } +/// A wrapper around a retry count to be used in the context of a retry policy. +/// +/// This allows a post-retry policy to access the retry count +pub struct RetryPolicyCount(u32); + +impl Deref for RetryPolicyCount { + type Target = u32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + /// A retry policy. /// /// In the simple form, the policies need only differ in how @@ -131,7 +144,8 @@ where "failed to reset body stream before retrying request", )?; } - let result = next[0].send(ctx, request, &next[1..]).await; + let ctx = ctx.clone().with_value(RetryPolicyCount(retry_count)); + let result = next[0].send(&ctx, request, &next[1..]).await; // only start keeping track of time after the first request is made let start = start.get_or_insert_with(OffsetDateTime::now_utc); let (last_error, retry_after) = match result { diff --git a/sdk/typespec/typespec_client_core/src/lib.rs b/sdk/typespec/typespec_client_core/src/lib.rs index da7b18a5e1..0baa57dfc5 100644 --- a/sdk/typespec/typespec_client_core/src/lib.rs +++ b/sdk/typespec/typespec_client_core/src/lib.rs @@ -18,6 +18,7 @@ pub mod json; pub mod sleep; pub mod stream; pub mod time; +#[cfg(feature = "http")] pub mod tracing; #[cfg(feature = "xml")] pub mod xml; diff --git a/sdk/typespec/typespec_client_core/src/tracing/attributes.rs b/sdk/typespec/typespec_client_core/src/tracing/attributes.rs index 66401bdd06..9b94fb9a18 100644 --- a/sdk/typespec/typespec_client_core/src/tracing/attributes.rs +++ b/sdk/typespec/typespec_client_core/src/tracing/attributes.rs @@ -1,17 +1,191 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +use std::borrow::Cow; + +/// An array of homogeneous attribute values. +#[derive(Debug, PartialEq, Clone)] pub enum AttributeArray { + /// An array of boolean values. Bool(Vec), + /// An array of 64-bit signed integers. I64(Vec), - U64(Vec), + /// An array of 64bit floating point values. + F64(Vec), + /// An array of strings. String(Vec), } +/// Represents a single attribute value, which can be of various types +#[derive(Debug, PartialEq, Clone)] pub enum AttributeValue { + /// A boolean attribute value. Bool(bool), + /// A signed 64-bit integer attribute value. I64(i64), - U64(u64), + /// A 64-bit floating point attribute value + F64(f64), + /// A string attribute value. String(String), + /// An array of attribute values. Array(AttributeArray), } + +/// Represents a key-value pair attribute, which is used for tracing and telemetry. +/// +/// Attributes are used to provide additional context and metadata about a span or event. +/// They can be of various types, including strings, integers, booleans, and arrays. +/// +/// Attributes are typically used to enrich telemetry data with additional information +/// that can be useful for debugging, monitoring, and analysis. +#[derive(Debug, PartialEq, Clone)] +pub struct Attribute { + /// A key-value pair attribute. + pub key: Cow<'static, str>, + pub value: AttributeValue, +} + +impl PartialEq<&str> for AttributeValue { + fn eq(&self, other: &&str) -> bool { + match self { + AttributeValue::String(s) => s == *other, + _ => false, + } + } +} + +impl PartialEq for AttributeValue { + fn eq(&self, other: &i64) -> bool { + match self { + AttributeValue::I64(i) => i == other, + _ => false, + } + } +} + +impl From for AttributeValue { + fn from(value: String) -> Self { + AttributeValue::String(value) + } +} + +impl From<&str> for AttributeValue { + fn from(value: &str) -> Self { + AttributeValue::String(value.to_string()) + } +} + +impl From for AttributeValue { + fn from(value: bool) -> Self { + AttributeValue::Bool(value) + } +} + +impl From for AttributeValue { + fn from(value: i32) -> Self { + AttributeValue::I64(value as i64) + } +} + +impl From for AttributeValue { + fn from(value: u16) -> Self { + AttributeValue::I64(value as i64) + } +} + +impl From for AttributeValue { + fn from(value: u32) -> Self { + AttributeValue::I64(value as i64) + } +} + +impl From for AttributeValue { + fn from(value: i64) -> Self { + AttributeValue::I64(value) + } +} + +impl From for AttributeValue { + fn from(value: f64) -> Self { + AttributeValue::F64(value) + } +} + +impl From> for AttributeValue { + fn from(value: Vec) -> Self { + AttributeValue::Array(AttributeArray::Bool(value)) + } +} + +impl From> for AttributeValue { + fn from(value: Vec) -> Self { + AttributeValue::Array(AttributeArray::I64(value)) + } +} + +impl From> for AttributeValue { + fn from(value: Vec) -> Self { + AttributeValue::Array(AttributeArray::F64(value)) + } +} + +impl From> for AttributeValue { + fn from(value: Vec) -> Self { + AttributeValue::Array(AttributeArray::String(value)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_attribute_value_equality() { + let attr1 = AttributeValue::String("test".into()); + let attr2 = AttributeValue::String("test".into()); + let attr3 = AttributeValue::String("different".into()); + + assert_eq!(attr1, attr2); + assert_ne!(attr1, attr3); + } + + #[test] + fn test_attribute_array_equality() { + let array1 = AttributeArray::String(vec!["test".into(), "test2".into()]); + let array2 = AttributeArray::String(vec!["test".into(), "test2".into()]); + let array3 = AttributeArray::String(vec!["different".into()]); + + assert_eq!(array1, array2); + assert_ne!(array1, array3); + } + + #[test] + fn test_attribute_key_from_string() { + let key = "test_key".to_string(); + let key = key + " value"; + let attr = Attribute { + key: key.into(), + value: AttributeValue::String("test_value".into()), + }; + assert_eq!(attr.key, "test_key value"); + } + + #[test] + fn test_attribute_equality() { + let attr1 = Attribute { + key: "test".into(), + value: AttributeValue::String("value".into()), + }; + let attr2 = Attribute { + key: "test".into(), + value: AttributeValue::String("value".into()), + }; + let attr3 = Attribute { + key: "test".into(), + value: AttributeValue::String("different".into()), + }; + + assert_eq!(attr1, attr2); + assert_ne!(attr1, attr3); + } +} diff --git a/sdk/typespec/typespec_client_core/src/tracing/mod.rs b/sdk/typespec/typespec_client_core/src/tracing/mod.rs index bc2c25c59f..3fb9a88078 100644 --- a/sdk/typespec/typespec_client_core/src/tracing/mod.rs +++ b/sdk/typespec/typespec_client_core/src/tracing/mod.rs @@ -3,9 +3,8 @@ //! Distributed tracing trait definitions //! -use crate::http::Context; -use crate::Result; -use std::sync::Arc; +use crate::http::{Context, Request}; +use std::{fmt::Debug, sync::Arc}; /// Overall architecture for distributed tracing in the SDK. /// @@ -19,77 +18,92 @@ use std::sync::Arc; mod attributes; mod with_context; -pub use attributes::{AttributeArray, AttributeValue}; +pub use attributes::{Attribute, AttributeArray, AttributeValue}; pub use with_context::{FutureExt, WithContext}; /// The TracerProvider trait is the entrypoint for distributed tracing in the SDK. /// /// It provides a method to get a tracer for a specific name and package version. -pub trait TracerProvider { +pub trait TracerProvider: Send + Sync + Debug { /// Returns a tracer for the given name. /// /// Arguments: - /// - `package_name`: The name of the package for which the tracer is requested. - /// - `package_version`: The version of the package for which the tracer is requested. + /// - `namespace_name`: The namespace of the package for which the tracer is requested. See + /// [this page](https://learn.microsoft.com/azure/azure-resource-manager/management/azure-services-resource-providers) + /// for more information on namespace names. + /// - `crate_name`: The name of the crate for which the tracer is requested. + /// - `crate_version`: The version of the crate for which the tracer is requested. fn get_tracer( &self, - package_name: &'static str, - package_version: &'static str, - ) -> Box; + namespace_name: Option<&'static str>, + crate_name: &'static str, + crate_version: Option<&'static str>, + ) -> Arc; } -pub trait Tracer { +pub trait Tracer: Send + Sync + Debug { /// Starts a new span with the given name and type. /// - /// # Arguments - /// - `name`: The name of the span to start. - /// - `kind`: The type of the span to start. - /// - /// # Returns - /// An `Arc` representing the started span. - /// - fn start_span(&self, name: &'static str, kind: SpanKind) - -> Result>; - - /// Starts a new span with the given type, using the current span as the parent span. + /// The newly created span will have the "current" span as a parent. /// /// # Arguments /// - `name`: The name of the span to start. /// - `kind`: The type of the span to start. + /// - `attributes`: A vector of attributes to associate with the span. /// /// # Returns - /// An `Arc` representing the started span. + /// An `Arc` representing the started span. /// - fn start_span_with_current( + fn start_span( &self, name: &'static str, kind: SpanKind, - ) -> Result>; + attributes: Vec, + ) -> Arc; /// Starts a new child with the given name, type, and parent span. /// /// # Arguments /// - `name`: The name of the span to start. /// - `kind`: The type of the span to start. + /// - `attributes`: A vector of attributes to associate with the span. /// - `parent`: The parent span to use for the new span. /// /// # Returns - /// An `Arc` representing the started span + /// An `Arc` representing the started span + /// + /// Note: This method may panic if the parent span cannot be downcasted to the expected type. /// fn start_span_with_parent( &self, name: &'static str, kind: SpanKind, - parent: Arc, - ) -> Result>; + attributes: Vec, + parent: Arc, + ) -> Arc; + + /// Returns the namespace the tracer was configured with (if any). + /// + /// # Returns + /// An `Option<&'static str>` representing the namespace of the tracer, + fn namespace(&self) -> Option<&'static str>; } + +/// The status of a span. +/// +/// This enum represents the possible statuses of a span in distributed tracing. +/// It can be either `Unset`, indicating that the span has not been set to any specific status, +/// or `Error`, which contains a description of the error that occurred during the span's execution +/// +/// Note that OpenTelemetry defines an `Ok` status but that status is reserved for application and service developers, +/// so libraries should never set it. +#[derive(Debug, PartialEq)] pub enum SpanStatus { Unset, - Ok, Error { description: String }, } -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq, Eq)] pub enum SpanKind { #[default] Internal, @@ -101,15 +115,21 @@ pub enum SpanKind { pub trait SpanGuard { /// Ends the span when dropped. - fn end(self) -> crate::Result<()>; + fn end(self); } -pub trait Span: AsAny { +/// A trait that represents a span in distributed tracing. +/// +/// This trait defines the methods that a span must implement to be used in distributed tracing. +/// It includes methods for setting attributes, recording errors, and managing the span's lifecycle. +pub trait Span: AsAny + Send + Sync { + fn is_recording(&self) -> bool; + /// The 8 byte value which identifies the span. fn span_id(&self) -> [u8; 8]; /// Ends the current span. - fn end(&self) -> crate::Result<()>; + fn end(&self); /// Sets the status of the current span. /// # Arguments @@ -118,14 +138,15 @@ pub trait Span: AsAny { /// # Returns /// A `Result` indicating success or failure of the operation. /// - fn set_status(&self, status: SpanStatus) -> crate::Result<()>; + fn set_status(&self, status: SpanStatus); /// Sets an attribute on the current span. - fn set_attribute( - &self, - key: &'static str, - value: attributes::AttributeValue, - ) -> crate::Result<()>; + /// + /// # Arguments + /// - `key`: The key of the attribute to set. + /// - `value`: The value of the attribute to set. + /// + fn set_attribute(&self, key: &'static str, value: attributes::AttributeValue); /// Records a Rust standard error on the current span. /// @@ -135,9 +156,10 @@ pub trait Span: AsAny { /// # Returns /// A `Result` indicating success or failure of the operation. /// - fn record_error(&self, error: &dyn std::error::Error) -> crate::Result<()>; + fn record_error(&self, error: &dyn std::error::Error); /// Temporarily sets the span as the current active span in the context. + /// /// # Arguments /// - `context`: The context in which to set the current span. /// @@ -147,7 +169,18 @@ pub trait Span: AsAny { /// This method allows the span to be set as the current span in the context, /// enabling it to be used for tracing operations within that context. /// - fn set_current(&self, context: &Context) -> crate::Result>; + fn set_current(&self, context: &Context) -> Box; + + /// Adds telemetry headers to the request for distributed tracing. + /// + /// # Arguments + /// - `request`: A mutable reference to the request to which headers will be added. + /// + /// This method should be called before sending the request to ensure that the tracing information + /// is included in the request headers. It typically adds the [W3C Distributed Tracing](https://www.w3.org/TR/trace-context/) + /// headers to the request. + /// + fn propagate_headers(&self, request: &mut Request); } /// A trait that allows an object to be downcast to a reference of type `Any`. diff --git a/sdk/typespec/typespec_client_core/src/tracing/with_context.rs b/sdk/typespec/typespec_client_core/src/tracing/with_context.rs index 14ef2808bf..8dfcf050b6 100644 --- a/sdk/typespec/typespec_client_core/src/tracing/with_context.rs +++ b/sdk/typespec/typespec_client_core/src/tracing/with_context.rs @@ -20,8 +20,8 @@ impl std::future::Future for WithContext<'_, T> { fn poll(self: Pin<&mut Self>, task_cx: &mut TaskContext<'_>) -> Poll { let this = self.project(); - if let Some(span) = this.context.value::>() { - let _guard = span.set_current(this.context).unwrap(); + if let Some(span) = this.context.value::>() { + let _guard = span.set_current(this.context); this.inner.poll(task_cx) } else {