Skip to content

Added RateLimitedSampler #41925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

- Detect synthetically created telemetry based on the user-agent header
([#41733](https://github.com/Azure/azure-sdk-for-python/pull/41733))
- Added `RateLimitedSampler`
([#41925](https://github.com/Azure/azure-sdk-for-python/pull/41925))

### Breaking Changes

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from azure.monitor.opentelemetry.exporter.export.metrics._exporter import AzureMonitorMetricExporter
from azure.monitor.opentelemetry.exporter.export.trace._exporter import AzureMonitorTraceExporter
from azure.monitor.opentelemetry.exporter.export.trace._sampling import ApplicationInsightsSampler
from azure.monitor.opentelemetry.exporter.export.trace._rate_limited_sampling import RateLimitedSampler
from ._version import VERSION

__all__ = [
"ApplicationInsightsSampler",
"AzureMonitorMetricExporter",
"AzureMonitorLogExporter",
"AzureMonitorTraceExporter",
"RateLimitedSampler",
]
__version__ = VERSION
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@
# sampleRate

_SAMPLE_RATE_KEY = "_MS.sampleRate"
_HASH = 5381
_INTEGER_MAX: int = 2**31 - 1
_INTEGER_MIN: int = -2**31

# AAD Auth

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import math
import threading
import time
from typing import Callable, Optional, Sequence
from opentelemetry.context import Context
from opentelemetry.trace import Link, SpanKind, format_trace_id, get_current_span
from opentelemetry.sdk.trace.sampling import (
Decision,
Sampler,
SamplingResult,
_get_parent_trace_state,
)
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes

from azure.monitor.opentelemetry.exporter._constants import _SAMPLE_RATE_KEY

from azure.monitor.opentelemetry.exporter.export.trace._utils import _get_djb2_sample_score, _round_down_to_nearest

class _State:
def __init__(self, effective_window_count: float, effective_window_nanos: float, last_nano_time: int):
self.effective_window_count = effective_window_count
self.effective_window_nanos = effective_window_nanos
self.last_nano_time = last_nano_time

class RateLimitedSamplingPercentage:
    """Adaptively computes a sampling percentage that keeps the number of
    sampled spans per second near a configured target.

    Each call to :meth:`get` folds the elapsed time into an exponentially
    decayed window (span count and window duration) and derives the
    percentage from the ratio of allowed to observed span throughput.
    """

    def __init__(self, target_spans_per_second_limit: float,
                 nano_time_supplier: Optional[Callable[[], int]] = None, round_to_nearest: bool = True):
        """
        :param target_spans_per_second_limit: Target number of sampled spans per second; must be nonnegative.
        :param nano_time_supplier: Clock returning nanoseconds as an int; defaults to ``time.time_ns``.
        :param round_to_nearest: When True, round the result down to a percentage of the form 100/N.
        :raises ValueError: If the limit is negative.
        """
        if target_spans_per_second_limit < 0.0:
            raise ValueError("Limit for sampled spans per second must be nonnegative!")
        self._nano_time_supplier = nano_time_supplier or time.time_ns
        # Hardcoded adaptation time of 0.1 seconds for adjusting to sudden changes in telemetry volumes
        adaptation_time_seconds = 0.1
        self._inverse_adaptation_time_nanos = 1e-9 / adaptation_time_seconds
        self._target_spans_per_nanosecond_limit = 1e-9 * target_spans_per_second_limit
        # State triple (effective_window_count, effective_window_nanos, last_nano_time),
        # replaced as a whole under the lock so readers always see a consistent snapshot.
        self._state = (0.0, 0.0, self._nano_time_supplier())
        self._lock = threading.Lock()
        self._round_to_nearest = round_to_nearest

    def _update_state(self, old_state, current_nano_time: int):
        """Return the decayed window state advanced to *current_nano_time* with one more span counted."""
        window_count, window_nanos, last_nano_time = old_state
        if current_nano_time <= last_nano_time:
            # Clock did not advance (or went backwards): only count the span.
            return (window_count + 1, window_nanos, last_nano_time)
        elapsed_nanos = current_nano_time - last_nano_time
        decay_factor = math.exp(-elapsed_nanos * self._inverse_adaptation_time_nanos)
        return (
            window_count * decay_factor + 1,
            window_nanos * decay_factor + elapsed_nanos,
            current_nano_time,
        )

    def get(self) -> float:
        """Get the current sampling percentage (0.0 to 100.0)."""
        now = self._nano_time_supplier()

        with self._lock:
            updated = self._update_state(self._state, now)
            self._state = updated

        window_count, window_nanos, _ = updated

        # No spans observed yet: sample everything.
        if window_count == 0:
            return 100.0

        probability = window_nanos * self._target_spans_per_nanosecond_limit / window_count
        percentage = 100 * min(probability, 1.0)

        if self._round_to_nearest:
            percentage = _round_down_to_nearest(percentage)

        return percentage


class RateLimitedSampler(Sampler):
    """OpenTelemetry sampler that caps the rate of sampled spans per second.

    Spans with a valid local parent follow the parent's decision (and inherit
    its recorded sample rate when present) so traces stay internally
    consistent. Root and remote-parent spans are decided by comparing a djb2
    hash of the trace id against the dynamically computed percentage, making
    the decision deterministic per trace id.
    """

    def __init__(self, target_spans_per_second_limit: float):
        """
        :param target_spans_per_second_limit: Target number of sampled spans per second;
            must be nonnegative (enforced by RateLimitedSamplingPercentage).
        """
        self._sampling_percentage_generator = RateLimitedSamplingPercentage(target_spans_per_second_limit)
        self._description = f"RateLimitedSampler{{{target_spans_per_second_limit}}}"

    @staticmethod
    def _stamped_attributes(attributes: Attributes, sample_rate) -> dict:
        """Copy the caller's attributes (never mutating the input) and record the effective sample rate."""
        new_attributes = dict(attributes) if attributes is not None else {}
        new_attributes[_SAMPLE_RATE_KEY] = sample_rate
        return new_attributes

    def should_sample(
        self,
        parent_context: Optional[Context],
        trace_id: int,
        name: str,
        kind: Optional[SpanKind] = None,
        attributes: Attributes = None,
        links: Optional[Sequence["Link"]] = None,
        trace_state: Optional["TraceState"] = None,
    ) -> "SamplingResult":
        """Decide whether to sample the span, stamping the sample rate attribute on the result."""
        if parent_context is not None:
            parent_span = get_current_span(parent_context)
            parent_span_context = parent_span.get_span_context()

            # Only a valid, local (in-process) parent constrains the decision.
            if parent_span_context.is_valid and not parent_span_context.is_remote:
                if not parent_span.is_recording():
                    # Parent was dropped: drop this child too so traces stay complete.
                    return SamplingResult(
                        Decision.DROP,
                        self._stamped_attributes(attributes, 0.0),
                        _get_parent_trace_state(parent_context),
                    )

                # Parent is recording; honor its sample rate when it has one.
                # BUG FIX: a span's `attributes` attribute can exist but be None,
                # which made the original `getattr(..., {})` default useless and
                # `.get()` raise AttributeError — guard with `or {}`.
                parent_attributes = getattr(parent_span, 'attributes', None) or {}
                parent_sample_rate = parent_attributes.get(_SAMPLE_RATE_KEY)
                if parent_sample_rate is not None:
                    return SamplingResult(
                        Decision.RECORD_AND_SAMPLE,
                        self._stamped_attributes(attributes, parent_sample_rate),
                        _get_parent_trace_state(parent_context),
                    )

        # Root span, remote parent, or parent without a recorded rate:
        # decide from the trace-id hash vs. the current rate-limited percentage.
        sampling_percentage = self._sampling_percentage_generator.get()
        sampling_score = _get_djb2_sample_score(format_trace_id(trace_id).lower())

        if sampling_score < sampling_percentage:
            decision = Decision.RECORD_AND_SAMPLE
        else:
            decision = Decision.DROP

        return SamplingResult(
            decision,
            self._stamped_attributes(attributes, sampling_percentage),
            _get_parent_trace_state(parent_context),
        )

    def get_description(self) -> str:
        """Return the sampler's description string."""
        return self._description
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import no_type_check, Optional, Tuple
from urllib.parse import urlparse
import math

from opentelemetry.semconv.attributes import (
client_attributes,
Expand All @@ -13,6 +14,11 @@
from opentelemetry.semconv.trace import DbSystemValues, SpanAttributes
from opentelemetry.util.types import Attributes

from azure.monitor.opentelemetry.exporter._constants import (
_HASH,
_INTEGER_MAX,
_INTEGER_MIN
)

# pylint:disable=too-many-return-statements
def _get_default_port_db(db_system: str) -> int:
Expand Down Expand Up @@ -320,3 +326,33 @@ def _get_url_for_http_request(attributes: Attributes) -> Optional[str]:
http_target,
)
return url

def _get_djb2_sample_score(trace_id_hex: str) -> float:
# This algorithm uses 32bit integers
hash_value = _HASH
for char in trace_id_hex:
hash_value = ((hash_value << 5) + hash_value) + ord(char)
hash_value &= 0xFFFFFFFF # simulate 32-bit integer overflow

# Convert to signed 32-bit int
if hash_value & 0x80000000:
hash_value = -((~hash_value & 0xFFFFFFFF) + 1)

if hash_value == _INTEGER_MIN:
hash_value = int(_INTEGER_MAX)
else:
hash_value = abs(hash_value)

return 100.0 * (float(hash_value) / _INTEGER_MAX)

def _round_down_to_nearest(sampling_percentage: float) -> float:
if sampling_percentage == 0:
return 0
# Handle extremely small percentages that would cause overflow
if sampling_percentage <= _INTEGER_MIN: # Extremely small threshold
return 0.0
item_count = 100.0 / sampling_percentage
# Handle case where item_count is infinity or too large for math.ceil
if not math.isfinite(item_count) or item_count >= _INTEGER_MAX:
return 0.0
return 100.0 / math.ceil(item_count)
Loading
Loading