1"""
2Async-compatible API for recording observability metrics.
3
4This module provides an async-safe interface for Redis async client code to record
5metrics without needing to know about OpenTelemetry internals. It reuses the same
6RedisMetricsCollector and configuration as the sync recorder.
7
8Usage in Redis async client code:
9 from redis.asyncio.observability.recorder import record_operation_duration
10
11 start_time = time.monotonic()
12 # ... execute Redis command ...
13 await record_operation_duration(
14 command_name='SET',
15 duration_seconds=time.monotonic() - start_time,
16 server_address='localhost',
17 server_port=6379,
18 db_namespace='0',
19 error=None
20 )
21"""

from datetime import datetime
from typing import TYPE_CHECKING, List, Optional

from redis.observability.attributes import (
    GeoFailoverReason,
    PubSubDirection,
)
from redis.observability.metrics import CloseReason, RedisMetricsCollector
from redis.observability.providers import get_observability_instance
from redis.observability.registry import get_observables_registry_instance
from redis.utils import str_if_bytes

if TYPE_CHECKING:
    from redis.asyncio.connection import ConnectionPool
    from redis.asyncio.multidb.database import AsyncDatabase
    from redis.observability.config import OTelConfig

# Global metrics collector instance (lazy-initialized)
_async_metrics_collector: Optional[RedisMetricsCollector] = None

CONNECTION_COUNT_REGISTRY_KEY = "connection_count"


def _get_or_create_collector() -> Optional[RedisMetricsCollector]:
    """
    Get or create the global metrics collector.

    Returns:
        RedisMetricsCollector instance if observability is enabled, None otherwise
    """
    global _async_metrics_collector

    if _async_metrics_collector is not None:
        return _async_metrics_collector

    try:
        manager = get_observability_instance().get_provider_manager()
        if manager is None or not manager.config.enabled_telemetry:
            return None

        # Get meter from the global MeterProvider
        meter = manager.get_meter_provider().get_meter(
            RedisMetricsCollector.METER_NAME, RedisMetricsCollector.METER_VERSION
        )

        _async_metrics_collector = RedisMetricsCollector(meter, manager.config)
        return _async_metrics_collector

    except ImportError:
        # Observability module not available
        return None
    except Exception:
        # Any other error - don't break Redis operations
        return None


async def _get_config() -> Optional["OTelConfig"]:
    """
    Get the OTel configuration from the observability manager.

    Returns:
        OTelConfig instance if observability is enabled, None otherwise
    """
    try:
        manager = get_observability_instance().get_provider_manager()
        if manager is None:
            return None
        return manager.config
    except Exception:
        return None


async def record_operation_duration(
    command_name: str,
    duration_seconds: float,
    server_address: Optional[str] = None,
    server_port: Optional[int] = None,
    db_namespace: Optional[str] = None,
    error: Optional[Exception] = None,
    is_blocking: Optional[bool] = None,
    retry_attempts: Optional[int] = None,
) -> None:
    """
    Record a Redis command execution duration.

    This is an async-safe API that Redis async client code can call directly.
    If observability is not enabled, this returns immediately with negligible
    overhead.

    Args:
        command_name: Redis command name (e.g., 'GET', 'SET')
        duration_seconds: Command execution time in seconds
        server_address: Redis server address
        server_port: Redis server port
        db_namespace: Redis database index
        error: Exception if command failed, None if successful
        is_blocking: Whether the operation is a blocking command
        retry_attempts: Number of retry attempts made

    Example:
        >>> start = time.monotonic()
        >>> # ... execute command ...
        >>> await record_operation_duration(
        ...     'SET', time.monotonic() - start, 'localhost', 6379, '0'
        ... )
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_operation_duration(
            command_name=command_name,
            duration_seconds=duration_seconds,
            server_address=server_address,
            server_port=server_port,
            db_namespace=db_namespace,
            error_type=error,
            network_peer_address=server_address,
            network_peer_port=server_port,
            is_blocking=is_blocking,
            retry_attempts=retry_attempts,
        )
    except Exception:
        pass


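# A minimal sketch of the error path for record_operation_duration above:
# time the command in try/except so failures are recorded with the raised
# exception. The client construction, key, and value are illustrative only.
#
#     import time
#     import redis.asyncio as redis
#
#     client = redis.Redis(host="localhost", port=6379, db=0)
#     start = time.monotonic()
#     caught: Optional[Exception] = None
#     try:
#         await client.set("key", "value")
#     except Exception as exc:
#         caught = exc
#         raise
#     finally:
#         await record_operation_duration(
#             command_name="SET",
#             duration_seconds=time.monotonic() - start,
#             server_address="localhost",
#             server_port=6379,
#             db_namespace="0",
#             error=caught,
#         )

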
async def record_connection_create_time(
    connection_pool: "ConnectionPool",
    duration_seconds: float,
) -> None:
    """
    Record connection creation time.

    Args:
        connection_pool: Connection pool implementation
        duration_seconds: Time taken to create connection in seconds
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_create_time(
            connection_pool=connection_pool,
            duration_seconds=duration_seconds,
        )
    except Exception:
        pass


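# Sketch of a call site for record_connection_create_time: time the
# connection's establishment and attribute it to the owning pool. The
# make_connection()/connect() pairing is an assumption about the async
# ConnectionPool API; adapt to the actual pool in use.
#
#     import time
#     from redis.asyncio.connection import ConnectionPool
#
#     pool = ConnectionPool(host="localhost", port=6379)
#     connection = pool.make_connection()
#     start = time.monotonic()
#     await connection.connect()
#     await record_connection_create_time(pool, time.monotonic() - start)

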
async def init_connection_count() -> None:
    """
    Initialize observable gauge for connection count metric.
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    def observable_callback(_options):
        # The OTel SDK invokes this on each metric collection; the options
        # argument is unused. Fan out to every callback registered under
        # the connection-count key and combine their observations.
        observables_registry = get_observables_registry_instance()
        callbacks = observables_registry.get(CONNECTION_COUNT_REGISTRY_KEY)
        observations = []

        for callback in callbacks:
            observations.extend(callback())

        return observations

    try:
        collector.init_connection_count(
            callback=observable_callback,
        )
    except Exception:
        pass


async def register_pools_connection_count(
    connection_pools: List["ConnectionPool"],
) -> None:
    """
    Add connection pools to connection count observable registry.
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        # Lazy import
        from opentelemetry.metrics import Observation

        def connection_count_callback():
            observations = []
            for connection_pool in connection_pools:
                for count, attributes in connection_pool.get_connection_count():
                    observations.append(Observation(count, attributes=attributes))
            return observations

        observables_registry = get_observables_registry_instance()
        observables_registry.register(
            CONNECTION_COUNT_REGISTRY_KEY, connection_count_callback
        )
    except Exception:
        pass


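# The two functions above work together: register_pools_connection_count
# stores a per-pool callback in the observables registry, and
# init_connection_count installs a single observable gauge whose callback
# reads the registry at collection time and fans out to every entry. A
# minimal wiring sketch (pool setup is illustrative):
#
#     from redis.asyncio.connection import ConnectionPool
#
#     pools = [ConnectionPool(host="localhost", port=6379)]
#     await register_pools_connection_count(pools)
#     await init_connection_count()

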
async def record_connection_timeout(
    pool_name: str,
) -> None:
    """
    Record a connection timeout event.

    Args:
        pool_name: Connection pool identifier
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_timeout(
            pool_name=pool_name,
        )
    except Exception:
        pass


async def record_connection_wait_time(
    pool_name: str,
    duration_seconds: float,
) -> None:
    """
    Record time taken to obtain a connection from the pool.

    Args:
        pool_name: Connection pool identifier
        duration_seconds: Wait time in seconds
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_wait_time(
            pool_name=pool_name,
            duration_seconds=duration_seconds,
        )
    except Exception:
        pass


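# Sketch of pairing record_connection_wait_time with
# record_connection_timeout around a pool checkout. acquire_from_pool is a
# hypothetical helper standing in for whatever checkout call the pool
# exposes; the pool name is illustrative:
#
#     import time
#
#     start = time.monotonic()
#     try:
#         connection = await acquire_from_pool(pool)  # hypothetical helper
#     except TimeoutError:
#         await record_connection_timeout(pool_name="default")
#         raise
#     await record_connection_wait_time(
#         pool_name="default",
#         duration_seconds=time.monotonic() - start,
#     )

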
async def record_connection_closed(
    close_reason: Optional[CloseReason] = None,
    error_type: Optional[Exception] = None,
) -> None:
    """
    Record a connection closed event.

    Args:
        close_reason: Reason for closing (e.g. 'error', 'application_close')
        error_type: Error type if closed due to error
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_closed(
            close_reason=close_reason,
            error_type=error_type,
        )
    except Exception:
        pass


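# Two illustrative calls for record_connection_closed: a deliberate close
# tagged with a CloseReason, and an error-driven close carrying the
# exception. CloseReason.APPLICATION_CLOSE is an assumed member name based
# on the docstring above; use whatever members the enum actually defines.
#
#     await record_connection_closed(close_reason=CloseReason.APPLICATION_CLOSE)
#     await record_connection_closed(error_type=ConnectionError("reset by peer"))

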
async def record_connection_relaxed_timeout(
    connection_name: str,
    maint_notification: str,
    relaxed: bool,
) -> None:
    """
    Record a connection timeout relaxation event.

    Args:
        connection_name: Connection identifier
        maint_notification: Maintenance notification type
        relaxed: True to count up (relaxed), False to count down (unrelaxed)
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_relaxed_timeout(
            connection_name=connection_name,
            maint_notification=maint_notification,
            relaxed=relaxed,
        )
    except Exception:
        pass


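# record_connection_relaxed_timeout acts as an up/down counter: record
# relaxed=True when a maintenance notification relaxes a connection's
# timeouts, then relaxed=False once they are restored, so the metric
# returns to its prior level. The names below are illustrative:
#
#     await record_connection_relaxed_timeout("conn-1", "MIGRATING", relaxed=True)
#     # ... maintenance completes, timeouts restored ...
#     await record_connection_relaxed_timeout("conn-1", "MIGRATING", relaxed=False)

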
async def record_connection_handoff(
    pool_name: str,
) -> None:
    """
    Record a connection handoff event (e.g., after MOVING notification).

    Args:
        pool_name: Connection pool identifier
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_handoff(
            pool_name=pool_name,
        )
    except Exception:
        pass


async def record_error_count(
    server_address: str,
    server_port: int,
    network_peer_address: str,
    network_peer_port: int,
    error_type: Exception,
    retry_attempts: int,
    is_internal: bool = True,
) -> None:
    """
    Record error count.

    Args:
        server_address: Server address
        server_port: Server port
        network_peer_address: Network peer address
        network_peer_port: Network peer port
        error_type: Error type (Exception)
        retry_attempts: Retry attempts
        is_internal: Whether the error is internal (e.g., timeout, network error)
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_error_count(
            server_address=server_address,
            server_port=server_port,
            network_peer_address=network_peer_address,
            network_peer_port=network_peer_port,
            error_type=error_type,
            retry_attempts=retry_attempts,
            is_internal=is_internal,
        )
    except Exception:
        pass


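# Illustrative record_error_count call for a retried timeout; is_internal
# distinguishes client-internal failures (timeouts, network errors) from
# server-reported errors. The addresses and counts are example values:
#
#     await record_error_count(
#         server_address="localhost",
#         server_port=6379,
#         network_peer_address="127.0.0.1",
#         network_peer_port=6379,
#         error_type=TimeoutError("read timed out"),
#         retry_attempts=2,
#         is_internal=True,
#     )

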
async def record_pubsub_message(
    direction: PubSubDirection,
    channel: Optional[str] = None,
    sharded: Optional[bool] = None,
) -> None:
    """
    Record a PubSub message (published or received).

    Args:
        direction: Message direction ('publish' or 'receive')
        channel: Pub/Sub channel name
        sharded: True if sharded Pub/Sub channel
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    # Check if channel names should be hidden
    effective_channel = channel
    if channel is not None:
        config = await _get_config()
        if config is not None and config.hide_pubsub_channel_names:
            effective_channel = None
        else:
            # Normalize bytes to str for OTel attributes
            effective_channel = str_if_bytes(channel)

    try:
        collector.record_pubsub_message(
            direction=direction,
            channel=effective_channel,
            sharded=sharded,
        )
    except Exception:
        pass


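# Illustrative calls for record_pubsub_message. The PubSubDirection member
# names are assumptions about the enum in redis.observability.attributes;
# channel names are normalized (and optionally hidden) by the function
# itself, so bytes channels are fine:
#
#     await record_pubsub_message(PubSubDirection.PUBLISH, channel="news")
#     await record_pubsub_message(
#         PubSubDirection.RECEIVE, channel=b"news", sharded=False
#     )

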
async def record_streaming_lag(
    lag_seconds: float,
    stream_name: Optional[str] = None,
    consumer_group: Optional[str] = None,
) -> None:
    """
    Record the lag of a streaming message.

    Args:
        lag_seconds: Lag in seconds
        stream_name: Stream name
        consumer_group: Consumer group name
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    # Check if stream names should be hidden
    effective_stream_name = stream_name
    if stream_name is not None:
        config = await _get_config()
        if config is not None and config.hide_stream_names:
            effective_stream_name = None

    try:
        collector.record_streaming_lag(
            lag_seconds=lag_seconds,
            stream_name=effective_stream_name,
            consumer_group=consumer_group,
        )
    except Exception:
        pass


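# Lag for a stream message can be derived from its ID, whose first field is
# a millisecond Unix timestamp ("<ms>-<seq>"). A minimal sketch with an
# illustrative message ID and stream/group names:
#
#     from datetime import datetime
#
#     message_id = "1717000000000-0"
#     produced_ms = int(message_id.split("-")[0])
#     lag = max(0.0, datetime.now().timestamp() - produced_ms / 1000)
#     await record_streaming_lag(lag, stream_name="events", consumer_group="g1")

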
async def record_streaming_lag_from_response(
    response,
    consumer_group: Optional[str] = None,
) -> None:
    """
    Record streaming lag from XREAD/XREADGROUP response.

    Parses the response and calculates lag for each message based on the
    message ID timestamp.

    Args:
        response: Response from XREAD/XREADGROUP command
        consumer_group: Consumer group name (for XREADGROUP)
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    if not response:
        return

    try:
        now = datetime.now().timestamp()

        # Check if stream names should be hidden
        config = await _get_config()
        hide_stream_names = config is not None and config.hide_stream_names

        # RESP3 format: dict
        if isinstance(response, dict):
            for stream_name, stream_messages in response.items():
                effective_stream_name = (
                    None if hide_stream_names else str_if_bytes(stream_name)
                )
                for messages in stream_messages:
                    for message in messages:
                        message_id, _ = message
                        message_id = str_if_bytes(message_id)
                        timestamp, _ = message_id.split("-")
                        # Ensure lag is non-negative (clock skew can cause negative values)
                        lag_seconds = max(0.0, now - int(timestamp) / 1000)

                        collector.record_streaming_lag(
                            lag_seconds=lag_seconds,
                            stream_name=effective_stream_name,
                            consumer_group=consumer_group,
                        )
        else:
            # RESP2 format: list
            for stream_entry in response:
                stream_name = str_if_bytes(stream_entry[0])
                effective_stream_name = None if hide_stream_names else stream_name

                for message in stream_entry[1]:
                    message_id, _ = message
                    message_id = str_if_bytes(message_id)
                    timestamp, _ = message_id.split("-")
                    # Ensure lag is non-negative (clock skew can cause negative values)
                    lag_seconds = max(0.0, now - int(timestamp) / 1000)

                    collector.record_streaming_lag(
                        lag_seconds=lag_seconds,
                        stream_name=effective_stream_name,
                        consumer_group=consumer_group,
                    )
    except Exception:
        pass


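# The two response shapes the parser above handles, with illustrative IDs
# and payloads:
#
#     # RESP2: list of [stream_name, [(message_id, fields), ...]] entries
#     resp2 = [[b"events", [(b"1717000000000-0", {b"k": b"v"})]]]
#     await record_streaming_lag_from_response(resp2, consumer_group="g1")
#
#     # RESP3: dict of stream_name -> list of message batches
#     resp3 = {b"events": [[(b"1717000000000-0", {b"k": b"v"})]]}
#     await record_streaming_lag_from_response(resp3, consumer_group="g1")

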
async def record_maint_notification_count(
    server_address: str,
    server_port: int,
    network_peer_address: str,
    network_peer_port: int,
    maint_notification: str,
) -> None:
    """
    Record a maintenance notification count.

    Args:
        server_address: Server address
        server_port: Server port
        network_peer_address: Network peer address
        network_peer_port: Network peer port
        maint_notification: Maintenance notification type (e.g., 'MOVING', 'MIGRATING')
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_maint_notification_count(
            server_address=server_address,
            server_port=server_port,
            network_peer_address=network_peer_address,
            network_peer_port=network_peer_port,
            maint_notification=maint_notification,
        )
    except Exception:
        pass


async def record_geo_failover(
    fail_from: "AsyncDatabase",
    fail_to: "AsyncDatabase",
    reason: GeoFailoverReason,
) -> None:
    """
    Record a geo failover.

    Args:
        fail_from: Database failed from
        fail_to: Database failed to
        reason: Reason for the failover
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_geo_failover(
            fail_from=fail_from,
            fail_to=fail_to,
            reason=reason,
        )
    except Exception:
        pass


def reset_collector() -> None:
    """
    Reset the global async collector (used for testing or re-initialization).
    """
    global _async_metrics_collector
    _async_metrics_collector = None


async def is_enabled() -> bool:
    """
    Check if observability is enabled.

    Returns:
        True if metrics are being collected, False otherwise
    """
    collector = _get_or_create_collector()
    return collector is not None
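

# Sketch of how tests might exercise this module: reset the cached
# collector so a freshly configured provider is picked up, then gate
# assertions on is_enabled(). The provider setup step is illustrative:
#
#     reset_collector()
#     # ... configure an OTel MeterProvider via the observability manager ...
#     assert await is_enabled()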