Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/httpcore/_synchronization.py: 32%

100 statements  

« prev     ^ index     » next       coverage.py v7.2.2, created at 2023-03-26 06:12 +0000

1import threading 

2from types import TracebackType 

3from typing import Optional, Type 

4 

5import sniffio 

6 

7from ._exceptions import ExceptionMapping, PoolTimeout, map_exceptions 

8 

# Our async synchronization primitives use either 'anyio' or 'trio', depending
# on whether they're running under asyncio or trio.
#
# We take care to lazily import only whichever of these two we need.

13 

14 

class AsyncLock:
    """
    Async mutual-exclusion lock, usable with ``async with``.

    The concrete lock (``trio.Lock`` or ``anyio.Lock``) is created lazily the
    first time the lock is entered, once the running async library is known.
    """

    def __init__(self) -> None:
        # An empty string marks "not set up yet"; setup() fills it in.
        self._backend = ""

    def setup(self) -> None:
        """
        Record the running async library (as reported by sniffio) and build
        the matching lock: ``trio.Lock`` under trio, ``anyio.Lock`` otherwise.
        """
        self._backend = sniffio.current_async_library()
        if self._backend != "trio":
            import anyio

            self._anyio_lock = anyio.Lock()
            return

        import trio

        self._trio_lock = trio.Lock()

    async def __aenter__(self) -> "AsyncLock":
        # Lazily choose the implementation on first use.
        if not self._backend:
            self.setup()

        lock = self._trio_lock if self._backend == "trio" else self._anyio_lock
        await lock.acquire()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        lock = self._trio_lock if self._backend == "trio" else self._anyio_lock
        lock.release()

55 

56 

class AsyncEvent:
    """
    Async one-shot event that works under both trio and asyncio.

    The concrete event (``trio.Event`` or ``anyio.Event``) is created lazily
    on the first call to ``set()`` or ``wait()``.
    """

    def __init__(self) -> None:
        # An empty string marks "not set up yet"; setup() fills it in.
        self._backend = ""

    def setup(self) -> None:
        """
        Detect if we're running under 'asyncio' or 'trio' and create
        an event with the correct implementation.
        """
        self._backend = sniffio.current_async_library()
        if self._backend == "trio":
            import trio

            self._trio_event = trio.Event()
        else:
            import anyio

            self._anyio_event = anyio.Event()

    def set(self) -> None:
        """Set the event, waking any tasks blocked in ``wait()``."""
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            self._trio_event.set()
        else:
            self._anyio_event.set()

    async def wait(self, timeout: Optional[float] = None) -> None:
        """
        Wait until the event has been set.

        :param timeout: Maximum number of seconds to wait, or ``None`` to
            wait without a deadline.
        :raises PoolTimeout: If the timeout elapses before the event is set.
        """
        if not self._backend:
            self.setup()

        if self._backend == "trio":
            import trio

            trio_exc_map: ExceptionMapping = {trio.TooSlowError: PoolTimeout}
            # trio.fail_after needs a number, so "no timeout" becomes +inf.
            timeout_or_inf = float("inf") if timeout is None else timeout
            with map_exceptions(trio_exc_map):
                with trio.fail_after(timeout_or_inf):
                    await self._trio_event.wait()
        else:
            import anyio

            # anyio.fail_after accepts None directly and raises the builtin
            # TimeoutError on expiry, which is translated to PoolTimeout.
            anyio_exc_map: ExceptionMapping = {TimeoutError: PoolTimeout}
            with map_exceptions(anyio_exc_map):
                with anyio.fail_after(timeout):
                    await self._anyio_event.wait()

104 

105 

class AsyncSemaphore:
    """
    Async counting semaphore holding at most ``bound`` permits.

    The concrete semaphore (``trio.Semaphore`` or ``anyio.Semaphore``) is
    created lazily on the first ``acquire()``.
    """

    def __init__(self, bound: int) -> None:
        self._bound = bound
        # An empty string marks "not set up yet"; setup() fills it in.
        self._backend = ""

    def setup(self) -> None:
        """
        Record the running async library (as reported by sniffio) and build
        the matching semaphore, starting full and capped at ``bound``.
        """
        self._backend = sniffio.current_async_library()
        if self._backend != "trio":
            import anyio

            self._anyio_semaphore = anyio.Semaphore(
                initial_value=self._bound, max_value=self._bound
            )
            return

        import trio

        self._trio_semaphore = trio.Semaphore(
            initial_value=self._bound, max_value=self._bound
        )

    async def acquire(self) -> None:
        """Take one permit, waiting until one is available."""
        if not self._backend:
            self.setup()

        semaphore = (
            self._trio_semaphore if self._backend == "trio" else self._anyio_semaphore
        )
        await semaphore.acquire()

    async def release(self) -> None:
        """Return one permit to the semaphore."""
        semaphore = (
            self._trio_semaphore if self._backend == "trio" else self._anyio_semaphore
        )
        semaphore.release()

144 

145 

146# Our thread-based synchronization primitives... 

147 

148 

class Lock:
    """Thread-based mutual-exclusion lock, usable as a ``with`` context manager."""

    def __init__(self) -> None:
        self._lock = threading.Lock()

    def __enter__(self) -> "Lock":
        # Blocks until the underlying threading.Lock is acquired.
        self._lock.acquire()
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        traceback: Optional[TracebackType] = None,
    ) -> None:
        # Released unconditionally; exceptions are not suppressed (returns None).
        self._lock.release()

164 

165 

class Event:
    """Thread-based one-shot event with a timeout-aware ``wait()``."""

    def __init__(self) -> None:
        self._event = threading.Event()

    def set(self) -> None:
        """Set the event, waking any threads blocked in ``wait()``."""
        self._event.set()

    def wait(self, timeout: Optional[float] = None) -> None:
        """
        Block until the event is set.

        :param timeout: Maximum number of seconds to wait, or ``None`` to
            wait without a deadline.
        :raises PoolTimeout: If the timeout elapses before the event is set.
        """
        signalled = self._event.wait(timeout=timeout)
        if signalled:
            return
        raise PoolTimeout()  # pragma: nocover

176 

177 

class Semaphore:
    """Thread-based counting semaphore holding at most ``bound`` permits."""

    def __init__(self, bound: int) -> None:
        # BoundedSemaphore raises ValueError if released more times than it
        # was acquired, instead of silently growing capacity beyond ``bound``.
        # This mirrors the async variant, which caps at max_value=bound.
        self._semaphore = threading.BoundedSemaphore(value=bound)

    def acquire(self) -> None:
        """Take one permit, blocking until one is available."""
        self._semaphore.acquire()

    def release(self) -> None:
        """
        Return one permit.

        :raises ValueError: If this would raise the permit count above ``bound``.
        """
        self._semaphore.release()