1from __future__ import annotations
2
3import datetime as dt
4import functools
5import inspect
6import os
7import sys
8import time as time_module
9import uuid
10from collections.abc import Generator
11from time import gmtime as orig_gmtime
12from time import struct_time
13from types import TracebackType
14from typing import Any
15from typing import Awaitable
16from typing import Callable
17from typing import Generator as TypingGenerator
18from typing import Tuple
19from typing import Type
20from typing import TypeVar
21from typing import Union
22from typing import cast
23from typing import overload
24from unittest import TestCase
25from unittest import mock
26
27import _time_machine
28from dateutil.parser import parse as parse_datetime
29
# time.clock_gettime and time.CLOCK_REALTIME are not always available, e.g. on
# builds targeting old macOS versions such as the official python.org installer.
try:
    from time import CLOCK_REALTIME
except ImportError:
    # Placeholder clock id that no real clk_id is expected to equal, so the
    # patched clock_gettime()/clock_gettime_ns() never treat a call as
    # CLOCK_REALTIME on such builds.
    CLOCK_REALTIME = sys.maxsize

# tzset() is needed to apply a destination's time zone through $TZ.
try:
    from time import tzset

    HAVE_TZSET = True
except ImportError:  # pragma: no cover
    # Windows
    HAVE_TZSET = False

# ZoneInfo is used to detect destinations carrying an IANA zone key; on
# Python < 3.9 it comes from the optional backports.zoneinfo package.
if sys.version_info >= (3, 9):
    from zoneinfo import ZoneInfo

    HAVE_ZONEINFO = True
else:
    try:
        from backports.zoneinfo import ZoneInfo

        HAVE_ZONEINFO = True
    except ImportError:  # pragma: no cover
        HAVE_ZONEINFO = False


# pytest is optional; the time_machine fixture is only defined when present.
try:
    import pytest
except ImportError:  # pragma: no cover
    HAVE_PYTEST = False
else:
    HAVE_PYTEST = True
65
NANOSECONDS_PER_SECOND = 10**9

# Nanosecond timestamp of whatever the platform's gmtime(0) reports as its
# epoch. The author's intent was to translate a non-Unix system epoch
# (Windows was cited as using 1601); NOTE(review): CPython documents gmtime(0)
# as 1970-01-01 UTC on all platforms, so this is expected to evaluate to 0.
_system_epoch = orig_gmtime(0)
SYSTEM_EPOCH_TIMESTAMP_NS = int(
    # struct_time[:6] is (year, month, day, hour, minute, second).
    dt.datetime(*_system_epoch[:6], tzinfo=dt.timezone.utc).timestamp()
    * NANOSECONDS_PER_SECOND
)

# Concrete values a travel destination may resolve to.
DestinationBaseType = Union[
    int,
    float,
    dt.datetime,
    dt.timedelta,
    dt.date,
    str,
]
# A destination may also be produced lazily by a callable or a generator.
DestinationType = Union[
    DestinationBaseType,
    Callable[[], DestinationBaseType],
    TypingGenerator[DestinationBaseType, None, None],
]

# Type variables used by travel.__call__'s overloads.
_F = TypeVar("_F", bound=Callable[..., Any])
_AF = TypeVar("_AF", bound=Callable[..., Awaitable[Any]])
TestCaseType = TypeVar("TestCaseType", bound=Type[TestCase])

# Nine-int time tuple accepted by time.strftime (copied from typeshed).
_TimeTuple = Tuple[int, int, int, int, int, int, int, int, int]
104
105
def extract_timestamp_tzname(
    destination: DestinationType,
) -> tuple[float, str | None]:
    """Resolve a travel destination to ``(epoch_seconds, zone_key)``.

    A generator destination contributes its next value and a callable its
    return value; the resolved value is then converted:

    * ``int`` / ``float``: taken as epoch seconds.
    * ``datetime``: naive values are treated as UTC; a ``ZoneInfo`` tzinfo
      also yields its key as the zone name.
    * ``timedelta``: offset from the current ``time.time()``.
    * ``date``: midnight UTC on that day.
    * ``str``: parsed with dateutil and converted with ``.timestamp()``.

    The zone name is ``None`` unless a ``ZoneInfo`` datetime was supplied.

    Raises:
        TypeError: if the resolved value is none of the above.
    """
    resolved: DestinationBaseType
    if isinstance(destination, Generator):
        resolved = next(destination)
    elif callable(destination):
        resolved = destination()
    else:
        resolved = destination

    if isinstance(resolved, int):
        return float(resolved), None
    if isinstance(resolved, float):
        return resolved, None
    # datetime must be tested before date, since it is a date subclass.
    if isinstance(resolved, dt.datetime):
        zone_key: str | None = None
        if HAVE_ZONEINFO and isinstance(resolved.tzinfo, ZoneInfo):
            zone_key = resolved.tzinfo.key
        if resolved.tzinfo is None:
            resolved = resolved.replace(tzinfo=dt.timezone.utc)
        return resolved.timestamp(), zone_key
    if isinstance(resolved, dt.timedelta):
        return time_module.time() + resolved.total_seconds(), None
    if isinstance(resolved, dt.date):
        midnight_utc = dt.datetime.combine(
            resolved, dt.time(0, 0), tzinfo=dt.timezone.utc
        )
        return midnight_utc.timestamp(), None
    if isinstance(resolved, str):
        return parse_datetime(resolved).timestamp(), None
    raise TypeError(f"Unsupported destination {resolved!r}")
141
142
class Coordinates:
    """One active time-travel destination, as kept on ``coordinates_stack``.

    The destination is stored as integer nanoseconds. Optionally there is also
    a time zone key, which is exported as ``$TZ`` while these coordinates are
    active (only where ``tzset`` exists).

    Args:
        destination_timestamp: destination in seconds since the epoch.
        destination_tzname: zone key to apply through ``$TZ``, or None to
            leave the process time zone untouched.
        tick: if True, the travelled clock advances with real time after its
            first reading; if False, it stays at the destination.
    """

    def __init__(
        self,
        destination_timestamp: float,
        destination_tzname: str | None,
        tick: bool,
    ) -> None:
        self._destination_timestamp_ns = int(
            destination_timestamp * NANOSECONDS_PER_SECOND
        )
        self._destination_tzname = destination_tzname
        self._tick = tick
        # Becomes True on the first ticking time_ns() call. That call records
        # the real clock reading from which later elapsed time is measured.
        self._requested = False

    def time(self) -> float:
        """Return the travelled time in seconds since the epoch."""
        return self.time_ns() / NANOSECONDS_PER_SECOND

    def time_ns(self) -> int:
        """Return the travelled time in nanoseconds since the epoch.

        When not ticking, this is always the destination. When ticking, the
        first call returns the destination (offset by
        ``SYSTEM_EPOCH_TIMESTAMP_NS``). Later calls add the real time elapsed
        since that first call.
        """
        if not self._tick:
            return self._destination_timestamp_ns

        base = SYSTEM_EPOCH_TIMESTAMP_NS + self._destination_timestamp_ns
        now_ns: int = _time_machine.original_time_ns()

        if not self._requested:
            self._requested = True
            self._real_start_timestamp_ns = now_ns
            return base

        return base + (now_ns - self._real_start_timestamp_ns)

    def shift(self, delta: dt.timedelta | int | float) -> None:
        """Move the destination by ``delta`` (a timedelta or seconds).

        Raises:
            TypeError: if ``delta`` is neither a timedelta nor a number.
        """
        if isinstance(delta, dt.timedelta):
            total_seconds = delta.total_seconds()
        elif isinstance(delta, (int, float)):
            total_seconds = delta
        else:
            raise TypeError(f"Unsupported type for delta argument: {delta!r}")

        self._destination_timestamp_ns += int(total_seconds * NANOSECONDS_PER_SECOND)

    def move_to(
        self,
        destination: DestinationType,
        tick: bool | None = None,
    ) -> None:
        """Jump to a new destination.

        The current ``$TZ`` override is undone before the destination is
        resolved, so anything that depends on the local zone is resolved
        against the original one. The new destination's zone, if any, is
        applied afterwards. A ticking clock restarts from the new destination.

        Args:
            destination: any form accepted by ``extract_timestamp_tzname``.
            tick: new tick setting, or None to keep the current one.

        Raises:
            TypeError: if the destination is unsupported. In that case the
                coordinates, including their time zone, are left unchanged.
        """
        self._stop()
        try:
            timestamp, tzname = extract_timestamp_tzname(destination)
        except BaseException:
            # Re-apply the current zone so a rejected destination does not
            # leave us travelling with the zone override silently removed.
            self._start()
            raise
        self._destination_tzname = tzname
        self._destination_timestamp_ns = int(timestamp * NANOSECONDS_PER_SECOND)
        self._requested = False
        self._start()
        if tick is not None:
            self._tick = tick

    def _start(self) -> None:
        """Apply the destination zone through ``$TZ``, saving the old value."""
        if HAVE_TZSET and self._destination_tzname is not None:
            self._orig_tz = os.environ.get("TZ")
            os.environ["TZ"] = self._destination_tzname
            tzset()

    def _stop(self) -> None:
        """Restore the ``$TZ`` value saved by ``_start()``."""
        if HAVE_TZSET and self._destination_tzname is not None:
            if self._orig_tz is None:
                # pop() rather than del: code under test may already have
                # removed TZ, and a KeyError here would abort travel.stop()
                # before the uuid patches are undone.
                os.environ.pop("TZ", None)
            else:
                os.environ["TZ"] = self._orig_tz
            tzset()
210
211
# Active travels, innermost last; the patched clock functions read the top.
coordinates_stack: list[Coordinates] = list()

# While any travel is active, the uuid module's native time-based generators
# are patched to None. With them absent, uuid falls back to the time module
# (which is patched during travel) instead of a system call, so the times
# embedded in generated UUID1 values follow the travelled clock.
uuid_generate_time_attr = "_generate_time_safe"
uuid_generate_time_patcher = mock.patch.object(uuid, uuid_generate_time_attr, None)
uuid_uuid_create_patcher = mock.patch.object(uuid, "_UuidCreate", None)
221
222
class travel:
    """Travel to a point in time.

    Usable via explicit ``start()``/``stop()``, as a context manager, or as a
    decorator. The decorator accepts plain functions, coroutine functions and
    ``unittest.TestCase`` subclasses. The destination is resolved once, when
    this object is created.

    Args:
        destination: where to travel; see ``extract_timestamp_tzname`` for the
            accepted forms.
        tick: whether the travelled clock advances with real time.
    """

    def __init__(self, destination: DestinationType, *, tick: bool = True) -> None:
        self.destination_timestamp, self.destination_tzname = extract_timestamp_tzname(
            destination
        )
        self.tick = tick

    def start(self) -> Coordinates:
        """Begin travelling; push and return new ``Coordinates``."""
        _time_machine.patch_if_needed()

        if not coordinates_stack:
            # First active travel: patch uuid's native time helpers out.
            if sys.version_info < (3, 9):
                # We need to cause the functions to be loaded before we patch
                # them out, which is done by this internal function before:
                # https://github.com/python/cpython/pull/19948
                uuid._load_system_functions()
            uuid_generate_time_patcher.start()
            uuid_uuid_create_patcher.start()

        coordinates = Coordinates(
            destination_timestamp=self.destination_timestamp,
            destination_tzname=self.destination_tzname,
            tick=self.tick,
        )
        coordinates_stack.append(coordinates)
        coordinates._start()

        return coordinates

    def stop(self) -> None:
        """Stop the innermost active travel.

        Once no travels remain, the uuid patches are undone. This happens
        even if restoring the innermost travel's time zone fails.
        """
        coordinates = coordinates_stack.pop()
        try:
            coordinates._stop()
        finally:
            if not coordinates_stack:
                uuid_generate_time_patcher.stop()
                uuid_uuid_create_patcher.stop()

    def __enter__(self) -> Coordinates:
        return self.start()

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.stop()

    @overload
    def __call__(self, wrapped: TestCaseType) -> TestCaseType:  # pragma: no cover
        ...

    @overload
    def __call__(self, wrapped: _AF) -> _AF:  # pragma: no cover
        ...

    @overload
    def __call__(self, wrapped: _F) -> _F:  # pragma: no cover
        ...

    # 'Any' below is workaround for Mypy error:
    # Overloaded function implementation does not accept all possible arguments
    # of signature
    def __call__(
        self, wrapped: TestCaseType | _AF | _F | Any
    ) -> TestCaseType | _AF | _F | Any:
        """Decorate ``wrapped`` so it runs while travelling.

        For a ``TestCase`` subclass, travel spans ``setUpClass`` through
        ``tearDownClass``. For (async) functions, it spans each call.

        Raises:
            TypeError: if ``wrapped`` is a class that is not a ``TestCase``.
        """
        if isinstance(wrapped, type):
            # Class decorator
            if not issubclass(wrapped, TestCase):
                raise TypeError("Can only decorate unittest.TestCase subclasses.")

            # Modify the setUpClass method
            orig_setUpClass = wrapped.setUpClass.__func__  # type: ignore[attr-defined]

            @functools.wraps(orig_setUpClass)
            def setUpClass(cls: type[TestCase]) -> None:
                self.__enter__()
                try:
                    orig_setUpClass(cls)
                except BaseException:
                    # unittest does not run tearDownClass when setUpClass
                    # fails, so the travel must be undone here.
                    self.__exit__(*sys.exc_info())
                    raise

            wrapped.setUpClass = classmethod(setUpClass)  # type: ignore[assignment]

            orig_tearDownClass = (
                wrapped.tearDownClass.__func__  # type: ignore[attr-defined]
            )

            @functools.wraps(orig_tearDownClass)
            def tearDownClass(cls: type[TestCase]) -> None:
                try:
                    orig_tearDownClass(cls)
                finally:
                    # Stop travelling even if the class's own tearDownClass
                    # raises; otherwise later tests would run in the past.
                    self.__exit__(None, None, None)

            wrapped.tearDownClass = classmethod(  # type: ignore[assignment]
                tearDownClass
            )
            return cast(TestCaseType, wrapped)
        elif inspect.iscoroutinefunction(wrapped):

            @functools.wraps(wrapped)
            async def wrapper(*args: Any, **kwargs: Any) -> Any:
                with self:
                    return await wrapped(*args, **kwargs)

            return cast(_AF, wrapper)
        else:
            assert callable(wrapped)

            @functools.wraps(wrapped)
            def wrapper(*args: Any, **kwargs: Any) -> Any:
                with self:
                    return wrapped(*args, **kwargs)

            return cast(_F, wrapper)
340
341
342# datetime module
343
344
def now(tz: dt.tzinfo | None = None) -> dt.datetime:
    """Travel-aware ``now``.

    While travelling, the result is built from the travelled ``time()``.
    Otherwise the original implementation is called.
    """
    if coordinates_stack:
        return dt.datetime.fromtimestamp(time(), tz)
    real_now: dt.datetime = _time_machine.original_now(tz)
    return real_now
350
351
def utcnow() -> dt.datetime:
    """Travel-aware ``utcnow``: a naive datetime in UTC.

    While travelling, the travelled ``time()`` is converted to UTC and the
    tzinfo is dropped. Otherwise the original implementation is called.
    """
    if coordinates_stack:
        aware = dt.datetime.fromtimestamp(time(), dt.timezone.utc)
        return aware.replace(tzinfo=None)
    real_utcnow: dt.datetime = _time_machine.original_utcnow()
    return real_utcnow
357
358
359# time module
360
361
def clock_gettime(clk_id: int) -> float:
    """Travel-aware ``clock_gettime``.

    Only ``CLOCK_REALTIME`` is redirected to the travelled ``time()``, and
    only while travelling. Every other clock id goes to the original.
    """
    if coordinates_stack and clk_id == CLOCK_REALTIME:
        return time()
    reading: float = _time_machine.original_clock_gettime(clk_id)
    return reading
367
368
def clock_gettime_ns(clk_id: int) -> int:
    """Travel-aware ``clock_gettime_ns``.

    Only ``CLOCK_REALTIME`` is redirected to the travelled ``time_ns()``,
    and only while travelling. Every other clock id goes to the original.
    """
    if coordinates_stack and clk_id == CLOCK_REALTIME:
        return time_ns()
    reading: int = _time_machine.original_clock_gettime_ns(clk_id)
    return reading
374
375
def gmtime(secs: float | None = None) -> struct_time:
    """Travel-aware ``gmtime``.

    Explicit ``secs`` are always honoured. When ``secs`` is omitted during
    travel, the innermost travelled time is used instead of the real clock.
    """
    if secs is None and coordinates_stack:
        secs = coordinates_stack[-1].time()
    converted: struct_time = _time_machine.original_gmtime(secs)
    return converted
383
384
def localtime(secs: float | None = None) -> struct_time:
    """Travel-aware ``localtime``.

    Explicit ``secs`` are always honoured. When ``secs`` is omitted during
    travel, the innermost travelled time is used instead of the real clock.
    """
    if secs is None and coordinates_stack:
        secs = coordinates_stack[-1].time()
    converted: struct_time = _time_machine.original_localtime(secs)
    return converted
392
393
def strftime(format: str, t: _TimeTuple | struct_time | None = None) -> str:
    """Travel-aware ``strftime``.

    An explicit ``t`` is always honoured. When ``t`` is omitted during
    travel, the travelled ``localtime()`` is formatted instead of the real
    current time.
    """
    if t is None and coordinates_stack:
        t = localtime()
    formatted: str
    if t is None:
        formatted = _time_machine.original_strftime(format)
    else:
        formatted = _time_machine.original_strftime(format, t)
    return formatted
403
404
def time() -> float:
    """Travel-aware ``time``.

    Returns the innermost travel's time in seconds while travelling, and the
    original clock otherwise.
    """
    if coordinates_stack:
        return coordinates_stack[-1].time()
    real_seconds: float = _time_machine.original_time()
    return real_seconds
410
411
def time_ns() -> int:
    """Travel-aware ``time_ns``.

    Returns the innermost travel's time in nanoseconds while travelling, and
    the original clock otherwise.
    """
    if coordinates_stack:
        return coordinates_stack[-1].time_ns()
    real_nanoseconds: int = _time_machine.original_time_ns()
    return real_nanoseconds
417
418
419# pytest plugin
420
if HAVE_PYTEST:  # pragma: no branch

    class TimeMachineFixture:
        """State behind the ``time_machine`` pytest fixture.

        Travel starts lazily on the first ``move_to()`` call. The fixture
        stops it at teardown.
        """

        # The travel started by the first move_to(); None while not travelling.
        traveller: travel | None
        # Coordinates returned by traveller.start(); None while not travelling.
        coordinates: Coordinates | None

        def __init__(self) -> None:
            self.traveller = None
            self.coordinates = None

        def move_to(
            self,
            destination: DestinationType,
            tick: bool | None = None,
        ) -> None:
            """Start travelling to ``destination``, or jump there if already travelling.

            Args:
                destination: any form accepted by ``travel``.
                tick: tick setting. When starting a new travel, None means
                    True; when already travelling, None keeps the current
                    setting.
            """
            if self.traveller is None:
                if tick is None:
                    tick = True
                self.traveller = travel(destination, tick=tick)
                self.coordinates = self.traveller.start()
            else:
                assert self.coordinates is not None
                self.coordinates.move_to(destination, tick=tick)

        def shift(self, delta: dt.timedelta | int | float) -> None:
            """Shift the current destination by ``delta``.

            Raises:
                RuntimeError: if ``move_to()`` has not been called yet.
            """
            if self.traveller is None:
                raise RuntimeError(
                    "Initialize time_machine with move_to() before using shift()."
                )
            assert self.coordinates is not None
            self.coordinates.shift(delta=delta)

        def stop(self) -> None:
            """Stop travelling if a travel was started; safe to call repeatedly.

            State is cleared before stopping. A second call therefore cannot
            pop another traveller's coordinates, and a later ``move_to()``
            starts a fresh travel.
            """
            traveller = self.traveller
            if traveller is None:
                return
            self.traveller = None
            self.coordinates = None
            traveller.stop()

    @pytest.fixture(name="time_machine")
    def time_machine_fixture() -> TypingGenerator[TimeMachineFixture, None, None]:
        """Provide a ``TimeMachineFixture`` and stop any travel at teardown."""
        fixture = TimeMachineFixture()
        yield fixture
        fixture.stop()
462
463
464# escape hatch
465
466
class _EscapeHatchDatetimeDatetime:
    """Un-travelled counterparts of the patched ``datetime.datetime`` methods."""

    def now(self, tz: dt.tzinfo | None = None) -> dt.datetime:
        """Call the original ``now`` regardless of travel state."""
        real_now: dt.datetime = _time_machine.original_now(tz)
        return real_now

    def utcnow(self) -> dt.datetime:
        """Call the original ``utcnow`` regardless of travel state."""
        real_utcnow: dt.datetime = _time_machine.original_utcnow()
        return real_utcnow
475
476
class _EscapeHatchDatetime:
    """Namespace mirroring the ``datetime`` module for the escape hatch."""

    def __init__(self) -> None:
        # Mirrors ``datetime.datetime`` with un-travelled methods.
        self.datetime = _EscapeHatchDatetimeDatetime()
480
481
class _EscapeHatchTime:
    """Un-travelled counterparts of the patched ``time`` module functions.

    Each method calls the original implementation held by ``_time_machine``,
    whether or not a travel is active.
    """

    def clock_gettime(self, clk_id: int) -> float:
        real: float = _time_machine.original_clock_gettime(clk_id)
        return real

    def clock_gettime_ns(self, clk_id: int) -> int:
        real: int = _time_machine.original_clock_gettime_ns(clk_id)
        return real

    def gmtime(self, secs: float | None = None) -> struct_time:
        real: struct_time = _time_machine.original_gmtime(secs)
        return real

    def localtime(self, secs: float | None = None) -> struct_time:
        real: struct_time = _time_machine.original_localtime(secs)
        return real

    def monotonic(self) -> float:
        real: float = _time_machine.original_monotonic()
        return real

    def monotonic_ns(self) -> int:
        real: int = _time_machine.original_monotonic_ns()
        return real

    def strftime(self, format: str, t: _TimeTuple | struct_time | None = None) -> str:
        # Omit the time argument entirely when none was given, exactly as the
        # caller would have called time.strftime.
        call_args = (format,) if t is None else (format, t)
        formatted: str = _time_machine.original_strftime(*call_args)
        return formatted

    def time(self) -> float:
        real: float = _time_machine.original_time()
        return real

    def time_ns(self) -> int:
        real: int = _time_machine.original_time_ns()
        return real
522
523
class _EscapeHatch:
    """Access to the real clock while time travelling.

    ``.time`` and ``.datetime`` mirror the patched modules with un-travelled
    implementations.
    """

    def __init__(self) -> None:
        self.time = _EscapeHatchTime()
        self.datetime = _EscapeHatchDatetime()

    def is_travelling(self) -> bool:
        """Return True while at least one travel is active."""
        return coordinates_stack != []


escape_hatch = _EscapeHatch()