1# Copyright 2020 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for applying Google Cloud Firestore changes in a transaction."""
16from __future__ import annotations
17
18from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Coroutine, Optional
19
20from google.api_core import exceptions, gapic_v1
21from google.api_core import retry_async as retries
22
23from google.cloud.firestore_v1 import _helpers, async_batch
24from google.cloud.firestore_v1.async_document import AsyncDocumentReference
25from google.cloud.firestore_v1.async_query import AsyncQuery
26from google.cloud.firestore_v1.base_transaction import (
27 _CANT_BEGIN,
28 _CANT_COMMIT,
29 _CANT_ROLLBACK,
30 _EXCEED_ATTEMPTS_TEMPLATE,
31 _WRITE_READ_ONLY,
32 MAX_ATTEMPTS,
33 BaseTransaction,
34 _BaseTransactional,
35)
36
37# Types needed only for Type Hints
38if TYPE_CHECKING: # pragma: NO COVER
39 import datetime
40
41 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
42 from google.cloud.firestore_v1.base_document import DocumentSnapshot
43 from google.cloud.firestore_v1.query_profile import ExplainOptions
44
45
class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction):
    """Collect reads and writes that are applied together as one transaction.

    Args:
        client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`):
            The client that owns this transaction.
        max_attempts (Optional[int]): Upper bound on how many times the
            transaction may be attempted (the first try plus any retries).
            Defaults to
            :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`.
        read_only (Optional[bool]): When :data:`True`, queuing a write on
            this transaction raises :exc:`ValueError`. Defaults to
            :data:`False`.
    """

    def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None:
        # Set up the write-batch half first, then the transaction bookkeeping.
        super().__init__(client)
        BaseTransaction.__init__(self, max_attempts, read_only)

    def _add_write_pbs(self, write_pbs: list) -> None:
        """Queue ``Write`` protobufs on this transaction.

        Args:
            write_pbs (List[google.cloud.firestore_v1.write.Write]): The
                write protobufs to append.

        Raises:
            ValueError: If the transaction was created as read-only.
        """
        if self._read_only:
            raise ValueError(_WRITE_READ_ONLY)
        super()._add_write_pbs(write_pbs)

    async def _begin(self, retry_id: bytes | None = None) -> None:
        """Open the transaction on the server and remember its ID.

        Args:
            retry_id (Optional[bytes]): ID of an earlier transaction that
                this one retries, if any.

        Raises:
            ValueError: If this transaction is already in progress.
        """
        if self.in_progress:
            raise ValueError(_CANT_BEGIN.format(self._id))

        client = self._client
        begin_request = {
            "database": client._database_string,
            "options": self._options_protobuf(retry_id),
        }
        response = await client._firestore_api.begin_transaction(
            request=begin_request,
            metadata=client._rpc_metadata,
        )
        self._id = response.transaction

    async def _rollback(self) -> None:
        """Abandon the transaction on the server.

        Local transaction state is cleared whether or not the RPC succeeds.

        Raises:
            ValueError: If no transaction is in progress.
            google.api_core.exceptions.GoogleAPICallError: If the rollback
                RPC fails.
        """
        if not self.in_progress:
            raise ValueError(_CANT_ROLLBACK)

        client = self._client
        try:
            rollback_request = {
                "database": client._database_string,
                "transaction": self._id,
            }
            # The RPC answers with ``google.protobuf.Empty``; nothing to keep.
            await client._firestore_api.rollback(
                request=rollback_request,
                metadata=client._rpc_metadata,
            )
        finally:
            # Reset local state even when the rollback RPC raises.
            self._clean_up()

    async def _commit(self) -> list:
        """Send every queued write to the server in a single atomic commit.

        Returns:
            List[:class:`google.cloud.firestore_v1.write.WriteResult`, ...]:
            One result per queued change, in the order the changes were
            added; each result carries an ``update_time`` field.

        Raises:
            ValueError: If no transaction is in progress.
        """
        if not self.in_progress:
            raise ValueError(_CANT_COMMIT)

        client = self._client
        response = await client._firestore_api.commit(
            request={
                "database": client._database_string,
                "writes": self._write_pbs,
                "transaction": self._id,
            },
            metadata=client._rpc_metadata,
        )

        # ``_clean_up`` runs only once the RPC has returned, so a commit that
        # raises leaves the transaction in progress (and still rollback-able).
        self._clean_up()
        self.write_results = list(response.write_results)
        self.commit_time = response.commit_time
        return self.write_results

    async def get_all(
        self,
        references: list,
        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ) -> AsyncGenerator[DocumentSnapshot, Any]:
        """Fetch several documents as part of this transaction.

        Args:
            references (List[.AsyncDocumentReference, ...]): The document
                references to fetch.
            retry (google.api_core.retry.Retry): Which errors, if any, should
                be retried. Defaults to a system-specified policy.
            timeout (float): Timeout for the request. Defaults to a
                system-specified value.
            read_time (Optional[datetime.datetime]): When given, documents are
                read as they existed at this time. It must lie within the past
                hour, or, with Point-in-Time Recovery enabled, may also be a
                whole-minute timestamp within the past 7 days. A naive
                :class:`datetime.datetime` is interpreted as UTC.

        Yields:
            .DocumentSnapshot: The next requested document snapshot, or
            :data:`None` if the document does not exist.
        """
        call_options = _helpers.make_retry_timeout_kwargs(retry, timeout)
        if read_time is not None:
            call_options["read_time"] = read_time
        return await self._client.get_all(
            references, transaction=self, **call_options
        )

    async def get(
        self,
        ref_or_query: AsyncDocumentReference | AsyncQuery,
        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> AsyncGenerator[DocumentSnapshot, Any] | AsyncStreamGenerator[DocumentSnapshot]:
        """Read a single document or run a query as part of this transaction.

        Args:
            ref_or_query (AsyncDocumentReference | AsyncQuery):
                The document reference to fetch or the query to run.
            retry (google.api_core.retry.Retry): Which errors, if any, should
                be retried. Defaults to a system-specified policy.
            timeout (float): Timeout for the request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Query-profiling options. When set, explain_metrics become
                available on the returned generator. Only valid for queries,
                not for document references.
            read_time (Optional[datetime.datetime]): When given, data is read
                as it existed at this time. It must lie within the past hour,
                or, with Point-in-Time Recovery enabled, may also be a
                whole-minute timestamp within the past 7 days. A naive
                :class:`datetime.datetime` is interpreted as UTC.

        Yields:
            DocumentSnapshot: The next document snapshot that matches, or
            :data:`None` if the document does not exist.

        Raises:
            ValueError: If `ref_or_query` has an unsupported type, or if
                explain_options is given together with a document reference.
        """
        call_options = _helpers.make_retry_timeout_kwargs(retry, timeout)
        if read_time is not None:
            call_options["read_time"] = read_time

        if isinstance(ref_or_query, AsyncDocumentReference):
            # Query profiling has no meaning for a single-document lookup.
            if explain_options is not None:
                raise ValueError(
                    "When type of `ref_or_query` is `AsyncDocumentReference`, "
                    "`explain_options` cannot be provided."
                )
            return await self._client.get_all(
                [ref_or_query], transaction=self, **call_options
            )

        if not isinstance(ref_or_query, AsyncQuery):
            raise ValueError(
                'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.'
            )

        if explain_options is not None:
            call_options["explain_options"] = explain_options
        return ref_or_query.stream(transaction=self, **call_options)
243
244
class _AsyncTransactional(_BaseTransactional):
    """Provide a callable object to use as a transactional decorator.

    This is surfaced via
    :func:`~google.cloud.firestore_v1.async_transaction.async_transactional`.

    Args:
        to_wrap (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Coroutine]):
            A coroutine function that should be run (and retried) in a
            transaction.
    """

    def __init__(self, to_wrap) -> None:
        super().__init__(to_wrap)

    async def _pre_commit(
        self, transaction: AsyncTransaction, *args, **kwargs
    ) -> Any:
        """Begin the transaction and await the wrapped coroutine.

        Args:
            transaction
                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
                A transaction to execute the coroutine within.
            args (Tuple[Any, ...]): The extra positional arguments to pass
                along to the wrapped coroutine.
            kwargs (Dict[str, Any]): The extra keyword arguments to pass
                along to the wrapped coroutine.

        Returns:
            Any: The result of the wrapped coroutine.

        Raises:
            Exception: Any failure caused by ``to_wrap``.
        """
        # Force the ``transaction`` to be not "in progress".
        transaction._clean_up()
        await transaction._begin(retry_id=self.retry_id)

        # Update the stored transaction IDs. The first ID seen is kept as
        # ``retry_id`` and passed to ``_begin`` on later attempts.
        self.current_id = transaction._id
        if self.retry_id is None:
            self.retry_id = self.current_id
        return await self.to_wrap(transaction, *args, **kwargs)

    async def __call__(self, transaction, *args, **kwargs):
        """Execute the wrapped coroutine within a transaction.

        Args:
            transaction
                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`):
                A transaction to execute the coroutine within.
            args (Tuple[Any, ...]): The extra positional arguments to pass
                along to the wrapped coroutine.
            kwargs (Dict[str, Any]): The extra keyword arguments to pass
                along to the wrapped coroutine.

        Returns:
            Any: The result of the wrapped coroutine.

        Raises:
            ValueError: If the transaction does not succeed in
                ``max_attempts``.
            Exception: Any other error raised while beginning, running or
                committing the transaction is re-raised after rolling back
                the open transaction, if one is open.
        """
        self._reset()
        # Only ``Aborted`` commits are retried, and never for read-only
        # transactions (``except ()`` matches nothing).
        retryable_exceptions = (
            (exceptions.Aborted,) if not transaction._read_only else ()
        )
        last_exc = None

        try:
            for _ in range(transaction._max_attempts):
                result = await self._pre_commit(transaction, *args, **kwargs)
                try:
                    await transaction._commit()
                    return result
                except retryable_exceptions as exc:
                    last_exc = exc
                # Retry attempts that result in retryable exceptions.
                # Subsequent requests will use the failed transaction ID as part of
                # the ``BeginTransactionRequest`` when restarting this transaction
                # (via ``options.retry_transaction``). This preserves the "spot in
                # line" of the transaction, so exponential backoff is not required
                # in this case.
            # Retries exhausted: wrap the last exception in a ValueError.
            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts)
            raise ValueError(msg) from last_exc

        except BaseException:
            # Roll back on any error, but only when a transaction is actually
            # open. If ``_begin`` failed (or was cancelled), ``_clean_up`` has
            # already run, so ``_rollback`` would raise
            # ValueError(_CANT_ROLLBACK) and hide the real error.
            # Errors raised during _rollback are chained to the original error
            # through __context__.
            if transaction.in_progress:
                await transaction._rollback()
            raise
338
339
def async_transactional(
    to_wrap: Callable[..., Any]
) -> _AsyncTransactional:
    """Decorate a coroutine function so that it runs in a transaction.

    The wrapped coroutine function receives the transaction as its first
    argument, followed by any extra positional and keyword arguments given
    when the decorated object is awaited.

    Args:
        to_wrap
            (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Coroutine]):
            A coroutine function that should be run (and retried) in a
            transaction.

    Returns:
        :class:`_AsyncTransactional`: An awaitable-returning callable that
        takes an
        :class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`
        plus the wrapped function's extra arguments.
    """
    return _AsyncTransactional(to_wrap)