diff --git a/sentry_sdk/client.py b/sentry_sdk/client.py
index 6cb5ca5826..d60a099fbe 100644
--- a/sentry_sdk/client.py
+++ b/sentry_sdk/client.py
@@ -31,6 +31,7 @@
 )
 from sentry_sdk.serializer import serialize
 from sentry_sdk.tracing import trace
+from sentry_sdk.tracing_utils import _generate_sample_rand
 from sentry_sdk.transport import BaseHttpTransport, make_transport
 from sentry_sdk.consts import (
     SPANDATA,
@@ -181,7 +182,9 @@ class BaseClient:
 
     def __init__(self, options=None):
         # type: (Optional[Dict[str, Any]]) -> None
-        self.options = options if options is not None else DEFAULT_OPTIONS  # type: Dict[str, Any]
+        self.options = (
+            options if options is not None else DEFAULT_OPTIONS
+        )  # type: Dict[str, Any]
 
         self.transport = None  # type: Optional[Transport]
         self.monitor = None  # type: Optional[Monitor]
@@ -614,7 +617,9 @@ def _prepare_event(
                 event_scrubber.scrub_event(event)
 
         if scope is not None and scope._gen_ai_original_message_count:
-            spans = event.get("spans", [])  # type: List[Dict[str, Any]] | AnnotatedValue
+            spans = event.get(
+                "spans", []
+            )  # type: List[Dict[str, Any]] | AnnotatedValue
             if isinstance(spans, list):
                 for span in spans:
                     span_id = span.get("span_id", None)
@@ -1000,6 +1005,57 @@ def _capture_metric(self, metric):
         current_scope = sentry_sdk.get_current_scope()
         isolation_scope = sentry_sdk.get_isolation_scope()
 
+        # Determine trace_id and span_id using the same logic as the original metrics.py
+        trace_id = None
+        span_id = None
+        if current_scope.span is not None:
+            trace_id = current_scope.span.trace_id
+            span_id = current_scope.span.span_id
+        elif current_scope._propagation_context is not None:
+            trace_id = current_scope._propagation_context.trace_id
+            span_id = current_scope._propagation_context.span_id
+
+        sample_rate = metric["attributes"].get("sentry.client_sample_rate")
+        if sample_rate is not None:
+            try:
+                sample_rate = float(sample_rate)
+            except (TypeError, ValueError):
+                # Non-numeric input is reported as an invalid rate below
+                # instead of raising into the caller's code.
+                sample_rate = float("nan")
+
+            # Always validate sample_rate range, regardless of trace context.
+            # The negated chained comparison also rejects NaN, which would
+            # slip through `<= 0.0 or > 1.0` because NaN fails every comparison.
+            if not (0.0 < sample_rate <= 1.0):
+                if self.transport is not None:
+                    self.transport.record_lost_event(
+                        "invalid_sample_rate",
+                        data_category="trace_metric",
+                        quantity=1,
+                    )
+                return
+
+            # If there's no trace context, remove the sample_rate attribute and continue
+            if trace_id is None:
+                del metric["attributes"]["sentry.client_sample_rate"]
+            else:
+                # There is a trace context, apply sampling logic
+                if sample_rate < 1.0:
+                    sample_rand = _generate_sample_rand(trace_id)
+                    if sample_rand >= sample_rate:
+                        if self.transport is not None:
+                            self.transport.record_lost_event(
+                                "sample_rate",
+                                data_category="trace_metric",
+                                quantity=1,
+                            )
+                        return
+
+                # If sample_rate is 1.0, remove the attribute as it's implied
+                if sample_rate == 1.0:
+                    del metric["attributes"]["sentry.client_sample_rate"]
+
         metric["attributes"]["sentry.sdk.name"] = SDK_INFO["name"]
         metric["attributes"]["sentry.sdk.version"] = SDK_INFO["version"]
 
@@ -1011,10 +1067,6 @@ def _capture_metric(self, metric):
         if release is not None and "sentry.release" not in metric["attributes"]:
             metric["attributes"]["sentry.release"] = release
 
-        trace_context = current_scope.get_trace_context()
-        trace_id = trace_context.get("trace_id")
-        span_id = trace_context.get("span_id")
-
         metric["trace_id"] = trace_id or "00000000-0000-0000-0000-000000000000"
         if span_id is not None:
             metric["span_id"] = span_id
diff --git a/sentry_sdk/metrics.py b/sentry_sdk/metrics.py
index 03bde137bd..a060cf7715 100644
--- a/sentry_sdk/metrics.py
+++ b/sentry_sdk/metrics.py
@@ -1,15 +1,16 @@
 """
-NOTE: This file contains experimental code that may be changed or removed at any
-time without prior notice.
+NOTE: This file contains experimental code that may be changed or removed at
+any time without prior notice.
 """
 
 import time
-from typing import Any, Optional, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 import sentry_sdk
 from sentry_sdk.utils import safe_repr
 
 if TYPE_CHECKING:
+    from typing import Any, Optional, Union
     from sentry_sdk._types import Metric, MetricType
 
 
@@ -19,6 +20,7 @@ def _capture_metric(
     value,  # type: float
     unit=None,  # type: Optional[str]
     attributes=None,  # type: Optional[dict[str, Any]]
+    sample_rate=None,  # type: Optional[float]
 ):
     # type: (...) -> None
     client = sentry_sdk.get_client()
@@ -37,6 +39,9 @@ def _capture_metric(
                 else safe_repr(v)
             )
 
+    if sample_rate is not None:
+        attrs["sentry.client_sample_rate"] = sample_rate
+
     metric = {
         "timestamp": time.time(),
         "trace_id": None,
@@ -56,9 +61,10 @@ def count(
     value,  # type: float
     unit=None,  # type: Optional[str]
     attributes=None,  # type: Optional[dict[str, Any]]
+    sample_rate=None,  # type: Optional[float]
 ):
     # type: (...) -> None
-    _capture_metric(name, "counter", value, unit, attributes)
+    _capture_metric(name, "counter", value, unit, attributes, sample_rate)
 
 
 def gauge(
@@ -66,9 +72,10 @@ def gauge(
     value,  # type: float
     unit=None,  # type: Optional[str]
     attributes=None,  # type: Optional[dict[str, Any]]
+    sample_rate=None,  # type: Optional[float]
 ):
     # type: (...) -> None
-    _capture_metric(name, "gauge", value, unit, attributes)
+    _capture_metric(name, "gauge", value, unit, attributes, sample_rate)
 
 
 def distribution(
@@ -76,6 +83,7 @@ def distribution(
     value,  # type: float
     unit=None,  # type: Optional[str]
     attributes=None,  # type: Optional[dict[str, Any]]
+    sample_rate=None,  # type: Optional[float]
 ):
     # type: (...) -> None
-    _capture_metric(name, "distribution", value, unit, attributes)
+    _capture_metric(name, "distribution", value, unit, attributes, sample_rate)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
index c7b786beb4..b192270d79 100644
--- a/tests/test_metrics.py
+++ b/tests/test_metrics.py
@@ -1,6 +1,7 @@
 import json
 import sys
 from typing import List, Any, Mapping
+from unittest import mock
 
 import pytest
 import sentry_sdk
@@ -267,3 +268,118 @@ def record_lost_event(reason, data_category, quantity):
     assert len(lost_event_calls) == 5
     for lost_event_call in lost_event_calls:
         assert lost_event_call == ("queue_overflow", "trace_metric", 1)
+
+
+def test_metrics_sample_rate_basic(sentry_init, capture_envelopes):
+    sentry_init()
+    envelopes = capture_envelopes()
+
+    sentry_sdk.metrics.count("test.counter", 1, sample_rate=0.5)
+    sentry_sdk.metrics.gauge("test.gauge", 42, sample_rate=0.8)
+    sentry_sdk.metrics.distribution("test.distribution", 200, sample_rate=1.0)
+
+    get_client().flush()
+    metrics = envelopes_to_metrics(envelopes)
+
+    assert len(metrics) == 3
+
+    assert metrics[0]["name"] == "test.counter"
+    # No sentry.client_sample_rate when there's no trace context
+    assert "sentry.client_sample_rate" not in metrics[0]["attributes"]
+
+    assert metrics[1]["name"] == "test.gauge"
+    # No sentry.client_sample_rate when there's no trace context
+    assert "sentry.client_sample_rate" not in metrics[1]["attributes"]
+
+    assert metrics[2]["name"] == "test.distribution"
+    assert "sentry.client_sample_rate" not in metrics[2]["attributes"]
+
+
+def test_metrics_sample_rate_normalization(sentry_init, capture_envelopes, monkeypatch):
+    sentry_init()
+    envelopes = capture_envelopes()
+    client = sentry_sdk.get_client()
+
+    lost_event_calls = []
+
+    def record_lost_event(reason, data_category, quantity):
+        lost_event_calls.append((reason, data_category, quantity))
+
+    monkeypatch.setattr(client.transport, "record_lost_event", record_lost_event)
+
+    sentry_sdk.metrics.count("test.counter1", 1, sample_rate=0.0)  # <= 0
+    sentry_sdk.metrics.count("test.counter2", 1, sample_rate=-0.5)  # < 0
+    sentry_sdk.metrics.count("test.counter3", 1, sample_rate=0.5)  # > 0 but < 1.0
+    sentry_sdk.metrics.count("test.counter4", 1, sample_rate=1.0)  # = 1.0
+    sentry_sdk.metrics.count("test.counter5", 1, sample_rate=1.5)  # > 1.0
+    sentry_sdk.metrics.count("test.counter6", 1, sample_rate=float("nan"))  # NaN
+
+    client.flush()
+    metrics = envelopes_to_metrics(envelopes)
+
+    assert len(metrics) == 2
+
+    # No sentry.client_sample_rate when there's no trace context
+    assert "sentry.client_sample_rate" not in metrics[0]["attributes"]
+    assert (
+        "sentry.client_sample_rate" not in metrics[1]["attributes"]
+    )  # 1.0 does not need a sample rate, it's implied to be 1.0
+
+    assert len(lost_event_calls) == 4
+    assert lost_event_calls[0] == ("invalid_sample_rate", "trace_metric", 1)
+    assert lost_event_calls[1] == ("invalid_sample_rate", "trace_metric", 1)
+    assert lost_event_calls[2] == ("invalid_sample_rate", "trace_metric", 1)
+    assert lost_event_calls[3] == ("invalid_sample_rate", "trace_metric", 1)
+
+
+def test_metrics_no_sample_rate(sentry_init, capture_envelopes):
+    sentry_init()
+    envelopes = capture_envelopes()
+
+    sentry_sdk.metrics.count("test.counter", 1)
+
+    get_client().flush()
+    metrics = envelopes_to_metrics(envelopes)
+
+    assert len(metrics) == 1
+
+    assert "sentry.client_sample_rate" not in metrics[0]["attributes"]
+
+
+@pytest.mark.parametrize("sample_rand", (0.0, 0.25, 0.5, 0.75))
+@pytest.mark.parametrize("sample_rate", (0.0, 0.25, 0.5, 0.75, 1.0))
+def test_metrics_sampling_decision(
+    sentry_init, capture_envelopes, sample_rate, sample_rand, monkeypatch
+):
+    sentry_init(traces_sample_rate=1.0)
+    envelopes = capture_envelopes()
+    client = sentry_sdk.get_client()
+
+    lost_event_calls = []
+
+    def record_lost_event(reason, data_category, quantity):
+        lost_event_calls.append((reason, data_category, quantity))
+
+    monkeypatch.setattr(client.transport, "record_lost_event", record_lost_event)
+
+    with mock.patch(
+        "sentry_sdk.tracing_utils.Random.randrange",
+        return_value=int(sample_rand * 1000000),
+    ):
+        with sentry_sdk.start_transaction() as _:
+            sentry_sdk.metrics.count("test.counter", 1, sample_rate=sample_rate)
+
+    get_client().flush()
+    metrics = envelopes_to_metrics(envelopes)
+
+    should_be_sampled = sample_rand < sample_rate and sample_rate > 0.0
+    assert len(metrics) == int(should_be_sampled)
+
+    if sample_rate <= 0.0:
+        assert len(lost_event_calls) == 1
+        assert lost_event_calls[0] == ("invalid_sample_rate", "trace_metric", 1)
+    elif not should_be_sampled:
+        assert len(lost_event_calls) == 1
+        assert lost_event_calls[0] == ("sample_rate", "trace_metric", 1)
+    else:
+        assert len(lost_event_calls) == 0