1# Copyright 2020 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for applying Google Cloud Firestore changes in a transaction."""
16from __future__ import annotations
17
18from typing import (
19 TYPE_CHECKING,
20 Any,
21 AsyncGenerator,
22 Awaitable,
23 Callable,
24 Optional,
25)
26
27from google.api_core import exceptions, gapic_v1
28from google.api_core import retry_async as retries
29
30from google.cloud.firestore_v1 import _helpers, async_batch
31from google.cloud.firestore_v1.async_document import AsyncDocumentReference
32from google.cloud.firestore_v1.async_query import AsyncQuery
33from google.cloud.firestore_v1.base_transaction import (
34 _CANT_BEGIN,
35 _CANT_COMMIT,
36 _CANT_ROLLBACK,
37 _EXCEED_ATTEMPTS_TEMPLATE,
38 _WRITE_READ_ONLY,
39 MAX_ATTEMPTS,
40 BaseTransaction,
41 _BaseTransactional,
42)
43
44# Types needed only for Type Hints
45if TYPE_CHECKING: # pragma: NO COVER
46 import datetime
47 from typing_extensions import TypeVar, ParamSpec, Concatenate
48
49 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
50 from google.cloud.firestore_v1.base_document import DocumentSnapshot
51 from google.cloud.firestore_v1.query_profile import ExplainOptions
52
53 T = TypeVar("T")
54 P = ParamSpec("P")
55
56
class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
    """Accumulate read-and-write operations to be sent in a transaction.

    Args:
        client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`):
            The client that created this transaction.
        max_attempts (Optional[int]): The maximum number of attempts for
            the transaction (i.e. allowing retries). Defaults to
            :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`.
        read_only (Optional[bool]): Flag indicating if the transaction
            should be read-only or should allow writes. Defaults to
            :data:`False`.
    """

    def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None:
        # Initialise the write-batch side (which holds the client) first, then
        # the transaction-specific state (attempt budget and read-only flag).
        super().__init__(client)
        BaseTransaction.__init__(self, max_attempts, read_only)

    def _add_write_pbs(self, write_pbs: list) -> None:
        """Add ``Write`` protobufs to this transaction.

        Args:
            write_pbs (List[google.cloud.firestore_v1.\
                write.Write]): A list of write protobufs to be added.

        Raises:
            ValueError: If this transaction is read-only.
        """
        if self._read_only:
            raise ValueError(_WRITE_READ_ONLY)

        super()._add_write_pbs(write_pbs)

    async def _begin(self, retry_id: bytes | None = None) -> None:
        """Begin the transaction.

        Sends a ``BeginTransaction`` request and stores the returned
        transaction ID on ``self._id``.

        Args:
            retry_id (Optional[bytes]): Transaction ID of a transaction to be
                retried.

        Raises:
            ValueError: If the current transaction has already begun.
        """
        if self.in_progress:
            msg = _CANT_BEGIN.format(self._id)
            raise ValueError(msg)

        transaction_response = await self._client._firestore_api.begin_transaction(
            request={
                "database": self._client._database_string,
                "options": self._options_protobuf(retry_id),
            },
            metadata=self._client._rpc_metadata,
        )
        self._id = transaction_response.transaction

    async def _rollback(self) -> None:
        """Roll back the transaction.

        The local transaction state is cleaned up even if the rollback RPC
        fails.

        Raises:
            ValueError: If no transaction is in progress.
            google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
        """
        if not self.in_progress:
            raise ValueError(_CANT_ROLLBACK)

        try:
            # NOTE: The response is just ``google.protobuf.Empty``.
            await self._client._firestore_api.rollback(
                request={
                    "database": self._client._database_string,
                    "transaction": self._id,
                },
                metadata=self._client._rpc_metadata,
            )
        finally:
            # clean up, even if rollback fails
            self._clean_up()

    async def _commit(self) -> list:
        """Transactionally commit the changes accumulated.

        On success, the local transaction state is cleaned up, and
        ``write_results`` / ``commit_time`` are recorded from the response.
        If the commit RPC raises, the transaction is left in progress so the
        caller may roll it back.

        Returns:
            List[:class:`google.cloud.firestore_v1.write.WriteResult`, ...]:
            The write results corresponding to the changes committed, returned
            in the same order as the changes were applied to this transaction.
            A write result contains an ``update_time`` field.

        Raises:
            ValueError: If no transaction is in progress.
        """
        if not self.in_progress:
            raise ValueError(_CANT_COMMIT)

        commit_response = await self._client._firestore_api.commit(
            request={
                "database": self._client._database_string,
                "writes": self._write_pbs,
                "transaction": self._id,
            },
            metadata=self._client._rpc_metadata,
        )

        self._clean_up()
        self.write_results = list(commit_response.write_results)
        self.commit_time = commit_response.commit_time
        return self.write_results

    async def _get_all_snapshots(
        self, references: list, kwargs: dict
    ) -> AsyncGenerator[DocumentSnapshot, Any]:
        """Call the client's ``get_all`` bound to this transaction.

        ``self._client.get_all`` may be an async-generator function, whose
        return value is an async iterator and must NOT be awaited (awaiting
        an async generator object raises ``TypeError``). It may also be a
        coroutine function resolving to such an iterator. Both cases are
        handled: the result is awaited only when it is awaitable.

        Args:
            references (list): Document references to retrieve.
            kwargs (dict): Extra keyword arguments (retry / timeout /
                read_time) forwarded to ``get_all``.

        Returns:
            AsyncGenerator[DocumentSnapshot, Any]: The snapshot iterator.
        """
        import inspect

        snapshots = self._client.get_all(references, transaction=self, **kwargs)
        if inspect.isawaitable(snapshots):
            snapshots = await snapshots
        return snapshots

    async def get_all(
        self,
        references: list,
        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ) -> AsyncGenerator[DocumentSnapshot, Any]:
        """Retrieves multiple documents from Firestore.

        Args:
            references (List[.AsyncDocumentReference, ...]): Iterable of document
                references to be retrieved.
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried.  Defaults to a system-specified policy.
            timeout (float): The timeout for this request.  Defaults to a
                system-specified value.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            AsyncGenerator[DocumentSnapshot, Any]: An async iterator that
            yields the next document snapshot that fulfills the query, or
            :data:`None` if the document does not exist.
        """
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
        if read_time is not None:
            kwargs["read_time"] = read_time
        return await self._get_all_snapshots(references, kwargs)

    async def get(
        self,
        ref_or_query: AsyncDocumentReference | AsyncQuery,
        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> AsyncGenerator[DocumentSnapshot, Any] | AsyncStreamGenerator[DocumentSnapshot]:
        """
        Retrieve a document or a query result from the database.

        Args:
            ref_or_query (AsyncDocumentReference | AsyncQuery):
                The document references or query object to return.
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried.  Defaults to a system-specified policy.
            timeout (float): The timeout for this request.  Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
                Can only be used when running a query, not a document reference.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            An async iterator of :class:`DocumentSnapshot` objects. For a
            document reference, it comes from the client's ``get_all``. For
            a query, it is the query's ``stream`` generator.

        Raises:
            ValueError: if `ref_or_query` is not one of the supported types, or
                explain_options is provided when `ref_or_query` is a document
                reference.
        """
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
        if read_time is not None:
            kwargs["read_time"] = read_time
        if isinstance(ref_or_query, AsyncDocumentReference):
            if explain_options is not None:
                raise ValueError(
                    "When type of `ref_or_query` is `AsyncDocumentReference`, "
                    "`explain_options` cannot be provided."
                )
            return await self._get_all_snapshots([ref_or_query], kwargs)
        elif isinstance(ref_or_query, AsyncQuery):
            if explain_options is not None:
                kwargs["explain_options"] = explain_options
            return ref_or_query.stream(transaction=self, **kwargs)
        else:
            raise ValueError(
                'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.'
            )
254
255
class _AsyncTransactional(_BaseTransactional):
    """Provide a callable object to use as a transactional decorator.

    This is surfaced via
    :func:`~google.cloud.firestore_v1.async_transaction.async_transactional`.

    Args:
        to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]):
            A coroutine function that should be run (and retried) in a
            transaction. It receives the transaction as its first argument.
    """

    def __init__(
        self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
    ) -> None:
        super().__init__(to_wrap)

    async def _pre_commit(
        self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
    ) -> T:
        """Begin transaction and call the wrapped coroutine.

        Args:
            transaction
                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
                A transaction to execute the coroutine within.
            args (Tuple[Any, ...]): The extra positional arguments to pass
                along to the wrapped coroutine.
            kwargs (Dict[str, Any]): The extra keyword arguments to pass
                along to the wrapped coroutine.

        Returns:
            T: result of the wrapped coroutine.

        Raises:
            Exception: Any failure caused by ``to_wrap``.
        """
        # Force the ``transaction`` to be not "in progress".
        transaction._clean_up()
        await transaction._begin(retry_id=self.retry_id)

        # Update the stored transaction IDs. The first transaction ID seen
        # becomes ``retry_id``, which later attempts pass to ``_begin``.
        self.current_id = transaction._id
        if self.retry_id is None:
            self.retry_id = self.current_id
        return await self.to_wrap(transaction, *args, **kwargs)

    async def __call__(
        self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
    ) -> T:
        """Execute the wrapped callable within a transaction.

        Each attempt begins the transaction, runs the wrapped callable and
        commits. For read-write transactions, a commit that raises
        ``google.api_core.exceptions.Aborted`` is retried, up to
        ``transaction._max_attempts`` attempts in total. On any error, the
        transaction is rolled back if one is still in progress, and the
        error is re-raised.

        Args:
            transaction
                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
                A transaction to execute the callable within.
            args (Tuple[Any, ...]): The extra positional arguments to pass
                along to the wrapped callable.
            kwargs (Dict[str, Any]): The extra keyword arguments to pass
                along to the wrapped callable.

        Returns:
            T: The result of the wrapped callable.

        Raises:
            ValueError: If the transaction does not succeed in
                ``max_attempts``.
        """
        self._reset()
        # Read-only transactions are never retried. An empty tuple in an
        # ``except`` clause matches nothing.
        retryable_exceptions = (
            (exceptions.Aborted,) if not transaction._read_only else ()
        )
        last_exc = None

        try:
            for _ in range(transaction._max_attempts):
                result: T = await self._pre_commit(transaction, *args, **kwargs)
                try:
                    await transaction._commit()
                    return result
                except retryable_exceptions as exc:
                    last_exc = exc
                # Retry attempts that result in retryable exceptions
                # Subsequent requests will use the failed transaction ID as part of
                # the ``BeginTransactionRequest`` when restarting this transaction
                # (via ``options.retry_transaction``). This preserves the "spot in
                # line" of the transaction, so exponential backoff is not required
                # in this case.
            # retries exhausted
            # wrap the last exception in a ValueError before raising
            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
            raise ValueError(msg) from last_exc

        except BaseException:
            # Roll back the transaction on any error, but only if one is open.
            # If ``_begin`` itself failed, ``_pre_commit`` had already cleaned
            # the transaction up. Calling ``_rollback`` then would raise its own
            # "not in progress" ValueError and mask the real failure.
            # Errors raised during _rollback are chained to the original error
            # through __context__.
            if transaction.in_progress:
                await transaction._rollback()
            raise
353
354
def async_transactional(
    to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]:
    """Decorate a coroutine function so that it runs in a transaction.

    The returned object is an ``_AsyncTransactional`` wrapper. Awaiting a call
    to it with an :class:`AsyncTransaction` (plus any extra arguments) begins
    the transaction, runs ``to_wrap`` and commits. Failed attempts may be
    retried.

    Args:
        to_wrap
            (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]):
            A callable that should be run (and retried) in a transaction.
            Its first positional argument is the transaction.

    Returns:
        Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]:
        the wrapped callable.
    """
    return _AsyncTransactional(to_wrap)