1"""
2Async-compatible API for recording observability metrics.
3
4This module provides an async-safe interface for Redis async client code to record
5metrics without needing to know about OpenTelemetry internals. It reuses the same
6RedisMetricsCollector and configuration as the sync recorder.
7
8Usage in Redis async client code:
9 from redis.asyncio.observability.recorder import record_operation_duration
10
11 start_time = time.monotonic()
12 # ... execute Redis command ...
13 await record_operation_duration(
14 command_name='SET',
15 duration_seconds=time.monotonic() - start_time,
16 server_address='localhost',
17 server_port=6379,
18 db_namespace='0',
19 error=None
20 )
21"""

from datetime import datetime
from typing import TYPE_CHECKING, List, Optional

from redis.observability.attributes import (
    ConnectionState,
    GeoFailoverReason,
    PubSubDirection,
)
from redis.observability.metrics import CloseReason, RedisMetricsCollector
from redis.observability.providers import get_observability_instance
from redis.observability.registry import get_observables_registry_instance
from redis.utils import deprecated_function, str_if_bytes

if TYPE_CHECKING:
    from redis.asyncio.connection import ConnectionPool
    from redis.asyncio.multidb.database import AsyncDatabase
    from redis.observability.config import OTelConfig

# Global metrics collector instance (lazy-initialized)
_async_metrics_collector: Optional[RedisMetricsCollector] = None

CONNECTION_COUNT_REGISTRY_KEY = "connection_count"


def _get_or_create_collector() -> Optional[RedisMetricsCollector]:
    """
    Get or create the global metrics collector.

    Returns:
        RedisMetricsCollector instance if observability is enabled, None otherwise
    """
    global _async_metrics_collector

    if _async_metrics_collector is not None:
        return _async_metrics_collector

    try:
        manager = get_observability_instance().get_provider_manager()
        if manager is None or not manager.config.enabled_telemetry:
            return None

        # Get meter from the global MeterProvider
        meter = manager.get_meter_provider().get_meter(
            RedisMetricsCollector.METER_NAME, RedisMetricsCollector.METER_VERSION
        )

        _async_metrics_collector = RedisMetricsCollector(meter, manager.config)
        return _async_metrics_collector

    except ImportError:
        # Observability module not available
        return None
    except Exception:
        # Any other error - don't break Redis operations
        return None


async def _get_config() -> Optional["OTelConfig"]:
    """
    Get the OTel configuration from the observability manager.

    Returns:
        OTelConfig instance if observability is enabled, None otherwise
    """
    try:
        manager = get_observability_instance().get_provider_manager()
        if manager is None:
            return None
        return manager.config
    except Exception:
        return None


async def record_operation_duration(
    command_name: str,
    duration_seconds: float,
    server_address: Optional[str] = None,
    server_port: Optional[int] = None,
    db_namespace: Optional[str] = None,
    error: Optional[Exception] = None,
    is_blocking: Optional[bool] = None,
    retry_attempts: Optional[int] = None,
) -> None:
    """
    Record a Redis command execution duration.

    This is an async-safe API that Redis async client code can call directly.
    If observability is not enabled, this returns immediately with negligible
    overhead.

    Args:
        command_name: Redis command name (e.g., 'GET', 'SET')
        duration_seconds: Command execution time in seconds
        server_address: Redis server address
        server_port: Redis server port
        db_namespace: Redis database index
        error: Exception if command failed, None if successful
        is_blocking: Whether the operation is a blocking command
        retry_attempts: Number of retry attempts made

    Example:
        >>> start = time.monotonic()
        >>> # ... execute command ...
        >>> await record_operation_duration('SET', time.monotonic() - start, 'localhost', 6379, '0')
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_operation_duration(
            command_name=command_name,
            duration_seconds=duration_seconds,
            server_address=server_address,
            server_port=server_port,
            db_namespace=db_namespace,
            error_type=error,
            network_peer_address=server_address,
            network_peer_port=server_port,
            is_blocking=is_blocking,
            retry_attempts=retry_attempts,
        )
    except Exception:
        # Metrics recording must never break Redis operations
        pass


async def record_connection_create_time(
    connection_pool: "ConnectionPool",
    duration_seconds: float,
) -> None:
    """
    Record connection creation time.

    Args:
        connection_pool: Connection pool implementation
        duration_seconds: Time taken to create connection in seconds
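
    Example:
        >>> # Illustrative; 'pool' is the ConnectionPool that created the connection
        >>> start = time.monotonic()
        >>> # ... establish connection ...
        >>> await record_connection_create_time(pool, time.monotonic() - start)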
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_create_time(
            connection_pool=connection_pool,
            duration_seconds=duration_seconds,
        )
    except Exception:
        pass


async def record_connection_count(
    pool_name: str,
    connection_state: ConnectionState,
    counter: int = 1,
) -> None:
    """
    Record a connection count change for a single state.

    Args:
        pool_name: Connection pool identifier
        connection_state: State to update (IDLE or USED)
        counter: Delta to apply (positive to increment, negative to decrement)
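
    Example:
        >>> # A connection moved from the pool into active use ('my-pool' is illustrative)
        >>> await record_connection_count('my-pool', ConnectionState.USED, 1)
        >>> await record_connection_count('my-pool', ConnectionState.IDLE, -1)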
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_count(
            pool_name=pool_name,
            connection_state=connection_state,
            counter=counter,
        )
    except Exception:
        pass


@deprecated_function(
    reason="Connection count is now tracked via record_connection_count(). "
    "This functionality will be removed in the next major version.",
    version="7.4.0",
)
async def init_connection_count() -> None:
    """
    Initialize the observable gauge for the connection count metric.
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    def observable_callback(_options):
        # OpenTelemetry passes CallbackOptions to observable callbacks; unused here.
        observables_registry = get_observables_registry_instance()
        callbacks = observables_registry.get(CONNECTION_COUNT_REGISTRY_KEY)
        observations = []

        for callback in callbacks:
            observations.extend(callback())

        return observations

    try:
        collector.init_connection_count(
            callback=observable_callback,
        )
    except Exception:
        pass


@deprecated_function(
    reason="Connection count is now tracked via record_connection_count(). "
    "This functionality will be removed in the next major version.",
    version="7.4.0",
)
async def register_pools_connection_count(
    connection_pools: List["ConnectionPool"],
) -> None:
    """
    Add connection pools to the connection count observable registry.
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        # Lazy import
        from opentelemetry.metrics import Observation

        def connection_count_callback():
            observations = []
            for connection_pool in connection_pools:
                for count, attributes in connection_pool.get_connection_count():
                    observations.append(Observation(count, attributes=attributes))
            return observations

        observables_registry = get_observables_registry_instance()
        observables_registry.register(
            CONNECTION_COUNT_REGISTRY_KEY, connection_count_callback
        )
    except Exception:
        pass


async def record_connection_timeout(
    pool_name: str,
) -> None:
    """
    Record a connection timeout event.

    Args:
        pool_name: Connection pool identifier
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_timeout(
            pool_name=pool_name,
        )
    except Exception:
        pass


async def record_connection_wait_time(
    pool_name: str,
    duration_seconds: float,
) -> None:
    """
    Record time taken to obtain a connection from the pool.

    Args:
        pool_name: Connection pool identifier
        duration_seconds: Wait time in seconds
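
    Example:
        >>> # 'my-pool' is an illustrative pool name
        >>> start = time.monotonic()
        >>> # ... acquire connection from the pool ...
        >>> await record_connection_wait_time('my-pool', time.monotonic() - start)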
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_wait_time(
            pool_name=pool_name,
            duration_seconds=duration_seconds,
        )
    except Exception:
        pass


async def record_connection_closed(
    close_reason: Optional[CloseReason] = None,
    error_type: Optional[Exception] = None,
) -> None:
    """
    Record a connection closed event.

    Args:
        close_reason: Reason for closing (e.g. 'error', 'application_close')
        error_type: Error type if closed due to error
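
    Example:
        >>> # Assumes CloseReason defines an ERROR member
        >>> await record_connection_closed(CloseReason.ERROR, ConnectionError('reset'))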
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_closed(
            close_reason=close_reason,
            error_type=error_type,
        )
    except Exception:
        pass


async def record_connection_relaxed_timeout(
    connection_name: str,
    maint_notification: str,
    relaxed: bool,
) -> None:
    """
    Record a connection timeout relaxation event.

    Args:
        connection_name: Connection identifier (pool name)
        maint_notification: Maintenance notification type
        relaxed: True to count up (relaxed), False to count down (unrelaxed)
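
    Example:
        >>> # 'my-pool' is an illustrative pool name
        >>> await record_connection_relaxed_timeout('my-pool', 'MIGRATING', relaxed=True)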
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_relaxed_timeout(
            connection_name=connection_name,
            maint_notification=maint_notification,
            relaxed=relaxed,
        )
    except Exception:
        pass


async def record_connection_handoff(
    pool_name: str,
) -> None:
    """
    Record a connection handoff event (e.g., after MOVING notification).

    Args:
        pool_name: Connection pool identifier
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_connection_handoff(
            pool_name=pool_name,
        )
    except Exception:
        pass


async def record_error_count(
    server_address: str,
    server_port: int,
    network_peer_address: str,
    network_peer_port: int,
    error_type: Exception,
    retry_attempts: int,
    is_internal: bool = True,
) -> None:
    """
    Record error count.

    Args:
        server_address: Server address
        server_port: Server port
        network_peer_address: Network peer address
        network_peer_port: Network peer port
        error_type: Error type (Exception)
        retry_attempts: Retry attempts
        is_internal: Whether the error is internal (e.g., timeout, network error)
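
    Example:
        >>> # Addresses are illustrative
        >>> await record_error_count(
        ...     'localhost', 6379, '10.0.0.5', 6379,
        ...     error_type=ConnectionError('connection lost'),
        ...     retry_attempts=2,
        ... )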
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_error_count(
            server_address=server_address,
            server_port=server_port,
            network_peer_address=network_peer_address,
            network_peer_port=network_peer_port,
            error_type=error_type,
            retry_attempts=retry_attempts,
            is_internal=is_internal,
        )
    except Exception:
        pass


async def record_pubsub_message(
    direction: PubSubDirection,
    channel: Optional[str] = None,
    sharded: Optional[bool] = None,
) -> None:
    """
    Record a PubSub message (published or received).

    Args:
        direction: Message direction ('publish' or 'receive')
        channel: Pub/Sub channel name
        sharded: True if sharded Pub/Sub channel
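
    Example:
        >>> # Assumes PubSubDirection exposes a PUBLISH member; 'news' is illustrative
        >>> await record_pubsub_message(PubSubDirection.PUBLISH, channel='news')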
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    # Check if channel names should be hidden
    effective_channel = channel
    if channel is not None:
        config = await _get_config()
        if config is not None and config.hide_pubsub_channel_names:
            effective_channel = None
        else:
            # Normalize bytes to str for OTel attributes
            effective_channel = str_if_bytes(channel)

    try:
        collector.record_pubsub_message(
            direction=direction,
            channel=effective_channel,
            sharded=sharded,
        )
    except Exception:
        pass


async def record_streaming_lag(
    lag_seconds: float,
    stream_name: Optional[str] = None,
    consumer_group: Optional[str] = None,
) -> None:
    """
    Record the lag of a streaming message.

    Args:
        lag_seconds: Lag in seconds
        stream_name: Stream name
        consumer_group: Consumer group name
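
    Example:
        >>> # Stream and group names are illustrative
        >>> await record_streaming_lag(0.25, stream_name='events', consumer_group='workers')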
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    # Check if stream names should be hidden
    effective_stream_name = stream_name
    if stream_name is not None:
        config = await _get_config()
        if config is not None and config.hide_stream_names:
            effective_stream_name = None

    try:
        collector.record_streaming_lag(
            lag_seconds=lag_seconds,
            stream_name=effective_stream_name,
            consumer_group=consumer_group,
        )
    except Exception:
        pass


async def record_streaming_lag_from_response(
    response,
    consumer_group: Optional[str] = None,
) -> None:
    """
    Record streaming lag from an XREAD/XREADGROUP response.

    Parses the response and calculates the lag for each message from the
    timestamp embedded in its message ID.

    Args:
        response: Response from XREAD/XREADGROUP command
        consumer_group: Consumer group name (for XREADGROUP)
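
    Example:
        >>> # 'r' is an async Redis client; stream and group names are illustrative
        >>> response = await r.xreadgroup('mygroup', 'consumer-1', {'events': '>'})
        >>> await record_streaming_lag_from_response(response, consumer_group='mygroup')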
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    if not response:
        return

    try:
        now = datetime.now().timestamp()

        # Check if stream names should be hidden
        config = await _get_config()
        hide_stream_names = config is not None and config.hide_stream_names

        # RESP3 format: dict
        if isinstance(response, dict):
            for stream_name, stream_messages in response.items():
                effective_stream_name = (
                    None if hide_stream_names else str_if_bytes(stream_name)
                )
                for messages in stream_messages:
                    for message in messages:
                        message_id, _ = message
                        message_id = str_if_bytes(message_id)
                        timestamp, _ = message_id.split("-")
                        # Ensure lag is non-negative (clock skew can cause negative values)
                        lag_seconds = max(0.0, now - int(timestamp) / 1000)

                        collector.record_streaming_lag(
                            lag_seconds=lag_seconds,
                            stream_name=effective_stream_name,
                            consumer_group=consumer_group,
                        )
        else:
            # RESP2 format: list
            for stream_entry in response:
                stream_name = str_if_bytes(stream_entry[0])
                effective_stream_name = None if hide_stream_names else stream_name

                for message in stream_entry[1]:
                    message_id, _ = message
                    message_id = str_if_bytes(message_id)
                    timestamp, _ = message_id.split("-")
                    # Ensure lag is non-negative (clock skew can cause negative values)
                    lag_seconds = max(0.0, now - int(timestamp) / 1000)

                    collector.record_streaming_lag(
                        lag_seconds=lag_seconds,
                        stream_name=effective_stream_name,
                        consumer_group=consumer_group,
                    )
    except Exception:
        pass


async def record_maint_notification_count(
    server_address: str,
    server_port: int,
    network_peer_address: str,
    network_peer_port: int,
    maint_notification: str,
) -> None:
    """
    Record a maintenance notification count.

    Args:
        server_address: Server address
        server_port: Server port
        network_peer_address: Network peer address
        network_peer_port: Network peer port
        maint_notification: Maintenance notification type (e.g., 'MOVING', 'MIGRATING')
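
    Example:
        >>> # Addresses are illustrative
        >>> await record_maint_notification_count(
        ...     'localhost', 6379, '10.0.0.5', 6379, 'MOVING'
        ... )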
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_maint_notification_count(
            server_address=server_address,
            server_port=server_port,
            network_peer_address=network_peer_address,
            network_peer_port=network_peer_port,
            maint_notification=maint_notification,
        )
    except Exception:
        pass


async def record_geo_failover(
    fail_from: "AsyncDatabase",
    fail_to: "AsyncDatabase",
    reason: GeoFailoverReason,
) -> None:
    """
    Record a geo failover.

    Args:
        fail_from: Database failed over from
        fail_to: Database failed over to
        reason: Reason for the failover
    """
    collector = _get_or_create_collector()
    if collector is None:
        return

    try:
        collector.record_geo_failover(
            fail_from=fail_from,
            fail_to=fail_to,
            reason=reason,
        )
    except Exception:
        pass


def reset_collector() -> None:
    """
    Reset the global async collector (used for testing or re-initialization).
    """
    global _async_metrics_collector
    _async_metrics_collector = None


async def is_enabled() -> bool:
    """
    Check if observability is enabled.

    Returns:
        True if metrics are being collected, False otherwise
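
    Example:
        >>> # 'my-pool' is an illustrative pool name
        >>> if await is_enabled():
        ...     await record_connection_timeout('my-pool')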
    """
    collector = _get_or_create_collector()
    return collector is not None