Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions sentry_sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
)
from sentry_sdk.serializer import serialize
from sentry_sdk.tracing import trace
from sentry_sdk.tracing_utils import _generate_sample_rand
from sentry_sdk.transport import BaseHttpTransport, make_transport
from sentry_sdk.consts import (
SPANDATA,
Expand Down Expand Up @@ -181,7 +182,9 @@ class BaseClient:

def __init__(self, options=None):
# type: (Optional[Dict[str, Any]]) -> None
self.options = options if options is not None else DEFAULT_OPTIONS # type: Dict[str, Any]
self.options = (
options if options is not None else DEFAULT_OPTIONS
) # type: Dict[str, Any]

self.transport = None # type: Optional[Transport]
self.monitor = None # type: Optional[Monitor]
Expand Down Expand Up @@ -614,7 +617,9 @@ def _prepare_event(
event_scrubber.scrub_event(event)

if scope is not None and scope._gen_ai_original_message_count:
spans = event.get("spans", []) # type: List[Dict[str, Any]] | AnnotatedValue
spans = event.get(
"spans", []
) # type: List[Dict[str, Any]] | AnnotatedValue
if isinstance(spans, list):
for span in spans:
span_id = span.get("span_id", None)
Expand Down Expand Up @@ -1000,6 +1005,50 @@ def _capture_metric(self, metric):
current_scope = sentry_sdk.get_current_scope()
isolation_scope = sentry_sdk.get_isolation_scope()

# Determine trace_id and span_id using the same logic as the original metrics.py
trace_id = None
span_id = None
if current_scope.span is not None:
trace_id = current_scope.span.trace_id
span_id = current_scope.span.span_id
elif current_scope._propagation_context is not None:
trace_id = current_scope._propagation_context.trace_id
span_id = current_scope._propagation_context.span_id

sample_rate = metric["attributes"].get("sentry.client_sample_rate")
if sample_rate is not None:
sample_rate = float(sample_rate)

# Always validate sample_rate range, regardless of trace context
if sample_rate <= 0.0 or sample_rate > 1.0:
if self.transport is not None:
self.transport.record_lost_event(
"invalid_sample_rate",
data_category="trace_metric",
quantity=1,
)
return

# If there's no trace context, remove the sample_rate attribute and continue
if trace_id is None:
del metric["attributes"]["sentry.client_sample_rate"]
else:
# There is a trace context, apply sampling logic
if sample_rate < 1.0:
sample_rand = _generate_sample_rand(trace_id)
if sample_rand >= sample_rate:
if self.transport is not None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would also add a debug log.

self.transport.record_lost_event(
"sample_rate",
data_category="trace_metric",
quantity=1,
)
return

# If sample_rate is 1.0, remove the attribute as it's implied
if sample_rate == 1.0:
del metric["attributes"]["sentry.client_sample_rate"]

metric["attributes"]["sentry.sdk.name"] = SDK_INFO["name"]
metric["attributes"]["sentry.sdk.version"] = SDK_INFO["version"]

Expand All @@ -1011,10 +1060,6 @@ def _capture_metric(self, metric):
if release is not None and "sentry.release" not in metric["attributes"]:
metric["attributes"]["sentry.release"] = release

trace_context = current_scope.get_trace_context()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please do not revert this and use the underlying get_trace_context as a source of the trace_id, span_id, this is intentional to centralize getting the active trace and span for other features.

trace_id = trace_context.get("trace_id")
span_id = trace_context.get("span_id")

metric["trace_id"] = trace_id or "00000000-0000-0000-0000-000000000000"
if span_id is not None:
metric["span_id"] = span_id
Expand Down
20 changes: 14 additions & 6 deletions sentry_sdk/metrics.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""
NOTE: This file contains experimental code that may be changed or removed at any
time without prior notice.
NOTE: This file contains experimental code that may be changed or removed at
any time without prior notice.
"""

import time
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import TYPE_CHECKING

import sentry_sdk
from sentry_sdk.utils import safe_repr

if TYPE_CHECKING:
from typing import Any, Optional, Union
from sentry_sdk._types import Metric, MetricType


Expand All @@ -19,6 +20,7 @@ def _capture_metric(
value, # type: float
unit=None, # type: Optional[str]
attributes=None, # type: Optional[dict[str, Any]]
sample_rate=None, # type: Optional[float]
):
# type: (...) -> None
client = sentry_sdk.get_client()
Expand All @@ -37,6 +39,9 @@ def _capture_metric(
else safe_repr(v)
)

if sample_rate is not None:
attrs["sentry.client_sample_rate"] = sample_rate

metric = {
"timestamp": time.time(),
"trace_id": None,
Expand All @@ -56,26 +61,29 @@ def count(
value, # type: float
unit=None, # type: Optional[str]
attributes=None, # type: Optional[dict[str, Any]]
sample_rate=None, # type: Optional[float]
):
# type: (...) -> None
_capture_metric(name, "counter", value, unit, attributes)
_capture_metric(name, "counter", value, unit, attributes, sample_rate)


def gauge(
name, # type: str
value, # type: float
unit=None, # type: Optional[str]
attributes=None, # type: Optional[dict[str, Any]]
sample_rate=None, # type: Optional[float]
):
# type: (...) -> None
_capture_metric(name, "gauge", value, unit, attributes)
_capture_metric(name, "gauge", value, unit, attributes, sample_rate)


def distribution(
name, # type: str
value, # type: float
unit=None, # type: Optional[str]
attributes=None, # type: Optional[dict[str, Any]]
sample_rate=None, # type: Optional[float]
):
# type: (...) -> None
_capture_metric(name, "distribution", value, unit, attributes)
_capture_metric(name, "distribution", value, unit, attributes, sample_rate)
114 changes: 114 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import sys
from typing import List, Any, Mapping
from unittest import mock
import pytest

import sentry_sdk
Expand Down Expand Up @@ -267,3 +268,116 @@ def record_lost_event(reason, data_category, quantity):
assert len(lost_event_calls) == 5
for lost_event_call in lost_event_calls:
assert lost_event_call == ("queue_overflow", "trace_metric", 1)


def test_metrics_sample_rate_basic(sentry_init, capture_envelopes):
sentry_init()
envelopes = capture_envelopes()

sentry_sdk.metrics.count("test.counter", 1, sample_rate=0.5)
sentry_sdk.metrics.gauge("test.gauge", 42, sample_rate=0.8)
sentry_sdk.metrics.distribution("test.distribution", 200, sample_rate=1.0)

get_client().flush()
metrics = envelopes_to_metrics(envelopes)

assert len(metrics) == 3

assert metrics[0]["name"] == "test.counter"
# No sentry.client_sample_rate when there's no trace context
assert "sentry.client_sample_rate" not in metrics[0]["attributes"]

assert metrics[1]["name"] == "test.gauge"
# No sentry.client_sample_rate when there's no trace context
assert "sentry.client_sample_rate" not in metrics[1]["attributes"]

assert metrics[2]["name"] == "test.distribution"
assert "sentry.client_sample_rate" not in metrics[2]["attributes"]


def test_metrics_sample_rate_normalization(sentry_init, capture_envelopes, monkeypatch):
sentry_init()
envelopes = capture_envelopes()
client = sentry_sdk.get_client()

lost_event_calls = []

def record_lost_event(reason, data_category, quantity):
lost_event_calls.append((reason, data_category, quantity))

monkeypatch.setattr(client.transport, "record_lost_event", record_lost_event)

sentry_sdk.metrics.count("test.counter1", 1, sample_rate=0.0) # <= 0
sentry_sdk.metrics.count("test.counter2", 1, sample_rate=-0.5) # < 0
sentry_sdk.metrics.count("test.counter3", 1, sample_rate=0.5) # > 0 but < 1.0
sentry_sdk.metrics.count("test.counter4", 1, sample_rate=1.0) # = 1.0
sentry_sdk.metrics.count("test.counter4", 1, sample_rate=1.5) # > 1.0

client.flush()
metrics = envelopes_to_metrics(envelopes)

assert len(metrics) == 2

# No sentry.client_sample_rate when there's no trace context
assert "sentry.client_sample_rate" not in metrics[0]["attributes"]
assert (
"sentry.client_sample_rate" not in metrics[1]["attributes"]
) # 1.0 does not need a sample rate, it's implied to be 1.0

assert len(lost_event_calls) == 3
assert lost_event_calls[0] == ("invalid_sample_rate", "trace_metric", 1)
assert lost_event_calls[1] == ("invalid_sample_rate", "trace_metric", 1)
assert lost_event_calls[2] == ("invalid_sample_rate", "trace_metric", 1)


def test_metrics_no_sample_rate(sentry_init, capture_envelopes):
sentry_init()
envelopes = capture_envelopes()

sentry_sdk.metrics.count("test.counter", 1)

get_client().flush()
metrics = envelopes_to_metrics(envelopes)

assert len(metrics) == 1

assert "sentry.client_sample_rate" not in metrics[0]["attributes"]


@pytest.mark.parametrize("sample_rand", (0.0, 0.25, 0.5, 0.75))
@pytest.mark.parametrize("sample_rate", (0.0, 0.25, 0.5, 0.75, 1.0))
def test_metrics_sampling_decision(
sentry_init, capture_envelopes, sample_rate, sample_rand, monkeypatch
):
sentry_init(traces_sample_rate=1.0)
envelopes = capture_envelopes()
client = sentry_sdk.get_client()

lost_event_calls = []

def record_lost_event(reason, data_category, quantity):
lost_event_calls.append((reason, data_category, quantity))

monkeypatch.setattr(client.transport, "record_lost_event", record_lost_event)

with mock.patch(
"sentry_sdk.tracing_utils.Random.randrange",
return_value=int(sample_rand * 1000000),
):
with sentry_sdk.start_transaction() as _:
sentry_sdk.metrics.count("test.counter", 1, sample_rate=sample_rate)

get_client().flush()
metrics = envelopes_to_metrics(envelopes)

should_be_sampled = sample_rand < sample_rate and sample_rate > 0.0
assert len(metrics) == int(should_be_sampled)

if sample_rate <= 0.0:
assert len(lost_event_calls) == 1
assert lost_event_calls[0] == ("invalid_sample_rate", "trace_metric", 1)
elif not should_be_sampled:
assert len(lost_event_calls) == 1
assert lost_event_calls[0] == ("sample_rate", "trace_metric", 1)
else:
assert len(lost_event_calls) == 0
Loading