1# Copyright 2020 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for applying Google Cloud Firestore changes in a transaction."""
16from __future__ import annotations
17
18from typing import (
19 TYPE_CHECKING,
20 Any,
21 AsyncGenerator,
22 Awaitable,
23 Callable,
24 Generic,
25 Optional,
26)
27from typing_extensions import Concatenate, ParamSpec, TypeVar
28
29from google.api_core import exceptions, gapic_v1
30from google.api_core import retry_async as retries
31
32from google.cloud.firestore_v1 import _helpers, async_batch
33from google.cloud.firestore_v1.async_document import AsyncDocumentReference
34from google.cloud.firestore_v1.async_query import AsyncQuery
35from google.cloud.firestore_v1.base_transaction import (
36 _CANT_BEGIN,
37 _CANT_COMMIT,
38 _CANT_ROLLBACK,
39 _EXCEED_ATTEMPTS_TEMPLATE,
40 _WRITE_READ_ONLY,
41 MAX_ATTEMPTS,
42 BaseTransaction,
43 _BaseTransactional,
44)
45
46# Types needed only for Type Hints
47if TYPE_CHECKING: # pragma: NO COVER
48 import datetime
49
50 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
51 from google.cloud.firestore_v1.base_document import DocumentSnapshot
52 from google.cloud.firestore_v1.query_profile import ExplainOptions
53
54
# Result type of the coroutine wrapped by ``_AsyncTransactional`` /
# ``async_transactional`` (the ``Awaitable[T]`` in their signatures below).
T = TypeVar("T")
# Parameters the wrapped coroutine accepts after its leading
# ``AsyncTransaction`` argument (the ``P`` in ``Concatenate[AsyncTransaction, P]``).
P = ParamSpec("P")
57
58
class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
    """Accumulate read-and-write operations to be sent in a transaction.

    Args:
        client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`):
            The client that created this transaction.
        max_attempts (Optional[int]): The maximum number of attempts for
            the transaction (i.e. allowing retries). Defaults to
            :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`.
        read_only (Optional[bool]): Flag indicating if the transaction
            should be read-only or should allow writes. Defaults to
            :data:`False`.
    """

    def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None:
        # The write-batch side of the hierarchy is initialised with just the
        # client; the transaction-specific settings go to ``BaseTransaction``
        # explicitly.
        super().__init__(client)
        BaseTransaction.__init__(self, max_attempts, read_only)

    def _add_write_pbs(self, write_pbs: list) -> None:
        """Add ``Write`` protobufs to this transaction.

        Args:
            write_pbs (List[google.cloud.firestore_v1.write.Write]): The
                write protobufs to add.

        Raises:
            ValueError: If this transaction is read-only.
        """
        if self._read_only:
            raise ValueError(_WRITE_READ_ONLY)
        super()._add_write_pbs(write_pbs)

    async def _begin(self, retry_id: bytes | None = None) -> None:
        """Start the transaction on the server and remember its ID.

        Args:
            retry_id (Optional[bytes]): ID of an earlier transaction that
                this one retries; forwarded to ``_options_protobuf``.

        Raises:
            ValueError: If this transaction is already in progress.
        """
        if self.in_progress:
            raise ValueError(_CANT_BEGIN.format(self._id))

        client = self._client
        response = await client._firestore_api.begin_transaction(
            request={
                "database": client._database_string,
                "options": self._options_protobuf(retry_id),
            },
            metadata=client._rpc_metadata,
        )
        self._id = response.transaction

    async def _rollback(self) -> None:
        """Abandon the in-progress transaction on the server.

        Local state is cleaned up even when the rollback RPC fails.

        Raises:
            ValueError: If no transaction is in progress.
            google.api_core.exceptions.GoogleAPICallError: If the rollback fails.
        """
        if not self.in_progress:
            raise ValueError(_CANT_ROLLBACK)

        try:
            client = self._client
            # The RPC answers with ``google.protobuf.Empty``; nothing to keep.
            await client._firestore_api.rollback(
                request={
                    "database": client._database_string,
                    "transaction": self._id,
                },
                metadata=client._rpc_metadata,
            )
        finally:
            self._clean_up()

    async def _commit(self) -> list:
        """Send every accumulated write in one transactional commit.

        If the commit RPC raises, the transaction is left in progress so the
        caller can still roll it back.

        Returns:
            List[:class:`google.cloud.firestore_v1.write.WriteResult`, ...]:
            One write result per change, in the order the changes were added
            to this transaction. Each result has an ``update_time`` field.

        Raises:
            ValueError: If no transaction is in progress.
        """
        if not self.in_progress:
            raise ValueError(_CANT_COMMIT)

        client = self._client
        response = await client._firestore_api.commit(
            request={
                "database": client._database_string,
                "writes": self._write_pbs,
                "transaction": self._id,
            },
            metadata=client._rpc_metadata,
        )

        # Clean up first, then record the outcome on the instance.
        self._clean_up()
        self.write_results = list(response.write_results)
        self.commit_time = response.commit_time
        return self.write_results

    async def get_all(
        self,
        references: list,
        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ) -> AsyncGenerator[DocumentSnapshot, Any]:
        """Fetch several documents as part of this transaction.

        Args:
            references (List[.AsyncDocumentReference, ...]): Iterable of document
                references to be retrieved.
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            AsyncGenerator[DocumentSnapshot, Any]: The awaited result of the
            client's ``get_all`` for ``references``, run in this transaction.
        """
        call_kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
        if read_time is not None:
            call_kwargs["read_time"] = read_time
        # NOTE(review): this awaits the value returned by the client's
        # ``get_all``; confirm that call returns an awaitable rather than an
        # async generator.
        return await self._client.get_all(references, transaction=self, **call_kwargs)

    async def get(
        self,
        ref_or_query: AsyncDocumentReference | AsyncQuery,
        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> AsyncGenerator[DocumentSnapshot, Any] | AsyncStreamGenerator[DocumentSnapshot]:
        """
        Read a single document or run a query inside this transaction.

        Args:
            ref_or_query (AsyncDocumentReference | AsyncQuery):
                The document reference or query to read.
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
                Can only be used when running a query, not a document reference.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            For a document reference, the awaited result of the client's
            ``get_all`` for that single reference; for a query, the generator
            returned by ``ref_or_query.stream(transaction=self, ...)``.

        Raises:
            ValueError: if `ref_or_query` is not one of the supported types, or
                explain_options is provided when `ref_or_query` is a document
                reference.
        """
        call_kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
        if read_time is not None:
            call_kwargs["read_time"] = read_time

        if isinstance(ref_or_query, AsyncDocumentReference):
            # Query profiling has no meaning for a plain document lookup.
            if explain_options is not None:
                raise ValueError(
                    "When type of `ref_or_query` is `AsyncDocumentReference`, "
                    "`explain_options` cannot be provided."
                )
            return await self._client.get_all(
                [ref_or_query], transaction=self, **call_kwargs
            )

        if isinstance(ref_or_query, AsyncQuery):
            if explain_options is not None:
                call_kwargs["explain_options"] = explain_options
            return ref_or_query.stream(transaction=self, **call_kwargs)

        raise ValueError(
            'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.'
        )
256
257
class _AsyncTransactional(_BaseTransactional, Generic[T, P]):
    """Provide a callable object to use as a transactional decorator.

    This is surfaced via
    :func:`~google.cloud.firestore_v1.async_transaction.async_transactional`.

    Args:
        to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]):
            A coroutine function that should be run (and retried) in a
            transaction. It receives the transaction as its first argument.
    """

    def __init__(
        self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
    ) -> None:
        super().__init__(to_wrap)

    async def _pre_commit(
        self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
    ) -> T:
        """Begin transaction and call the wrapped coroutine.

        Args:
            transaction
                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
                A transaction to execute the coroutine within.
            args (Tuple[Any, ...]): The extra positional arguments to pass
                along to the wrapped coroutine.
            kwargs (Dict[str, Any]): The extra keyword arguments to pass
                along to the wrapped coroutine.

        Returns:
            T: result of the wrapped coroutine.

        Raises:
            Exception: Any failure caused by ``_begin`` or ``to_wrap``.
        """
        # Force the ``transaction`` to be not "in progress". A failed commit
        # does not clean up (``_commit`` only calls ``_clean_up`` after the
        # RPC returns), so the previous attempt's ID may still be set here.
        transaction._clean_up()
        await transaction._begin(retry_id=self.retry_id)

        # Update the stored transaction IDs: the first attempt's ID is kept
        # as ``retry_id`` and passed to ``_begin`` on later attempts.
        self.current_id = transaction._id
        if self.retry_id is None:
            self.retry_id = self.current_id
        return await self.to_wrap(transaction, *args, **kwargs)

    async def __call__(
        self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs
    ) -> T:
        """Execute the wrapped callable within a transaction.

        The callable is re-run in a new transaction each time the commit
        fails with :class:`~google.api_core.exceptions.Aborted`, up to
        ``transaction._max_attempts`` attempts in total. Read-only
        transactions are never retried. On any error, a transaction that is
        still open is rolled back before the error propagates.

        Args:
            transaction
                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
                A transaction to execute the callable within.
            args (Tuple[Any, ...]): The extra positional arguments to pass
                along to the wrapped callable.
            kwargs (Dict[str, Any]): The extra keyword arguments to pass
                along to the wrapped callable.

        Returns:
            T: The result of the wrapped callable from the attempt whose
            commit succeeded.

        Raises:
            ValueError: If the transaction does not succeed in
                ``max_attempts``. The last retryable error is chained as
                ``__cause__``.
            Exception: Any other error from ``_begin``, the wrapped callable,
                or a non-retryable commit failure. It is re-raised unchanged
                after any needed rollback.
        """
        self._reset()
        # A real one-element tuple for ``except``. An empty tuple matches no
        # exception, so read-only transactions are never retried.
        retryable_exceptions = (
            () if transaction._read_only else (exceptions.Aborted,)
        )
        last_exc: Optional[BaseException] = None

        try:
            for _ in range(transaction._max_attempts):
                result: T = await self._pre_commit(transaction, *args, **kwargs)
                try:
                    await transaction._commit()
                except retryable_exceptions as exc:
                    # Retry attempts that result in retryable exceptions.
                    # Subsequent requests will use the failed transaction ID as part of
                    # the ``BeginTransactionRequest`` when restarting this transaction
                    # (via ``options.retry_transaction``). This preserves the "spot in
                    # line" of the transaction, so exponential backoff is not required
                    # in this case.
                    last_exc = exc
                else:
                    return result
            # Retries exhausted: wrap the last exception in a ValueError.
            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
            raise ValueError(msg) from last_exc

        except BaseException:
            # Roll back on any error, including cancellation. Errors raised
            # during ``_rollback`` are chained to the original error through
            # ``__context__``. Skip the rollback when no transaction is open,
            # e.g. because ``_begin`` failed or no attempt ran. Otherwise
            # ``_rollback`` would raise ``ValueError(_CANT_ROLLBACK)`` and mask
            # the real error.
            if transaction.in_progress:
                await transaction._rollback()
            raise
355
356
def async_transactional(
    to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]
) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]:
    """Wrap a coroutine function so that every call runs in a transaction.

    Args:
        to_wrap
            (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]):
            A coroutine function taking the transaction as its first
            argument; it is run (and retried) in a transaction.

    Returns:
        Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]:
            An ``_AsyncTransactional`` instance wrapping ``to_wrap``.
    """
    transactional_wrapper = _AsyncTransactional(to_wrap)
    return transactional_wrapper