1# Copyright 2020 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Helpers for applying Google Cloud Firestore changes in a transaction.""" 
    16from __future__ import annotations 
    17 
    18from typing import ( 
    19    TYPE_CHECKING, 
    20    Any, 
    21    AsyncGenerator, 
    22    Awaitable, 
    23    Callable, 
    24    Generic, 
    25    Optional, 
    26) 
    27from typing_extensions import Concatenate, ParamSpec, TypeVar 
    28 
    29from google.api_core import exceptions, gapic_v1 
    30from google.api_core import retry_async as retries 
    31 
    32from google.cloud.firestore_v1 import _helpers, async_batch 
    33from google.cloud.firestore_v1.async_document import AsyncDocumentReference 
    34from google.cloud.firestore_v1.async_query import AsyncQuery 
    35from google.cloud.firestore_v1.base_transaction import ( 
    36    _CANT_BEGIN, 
    37    _CANT_COMMIT, 
    38    _CANT_ROLLBACK, 
    39    _EXCEED_ATTEMPTS_TEMPLATE, 
    40    _WRITE_READ_ONLY, 
    41    MAX_ATTEMPTS, 
    42    BaseTransaction, 
    43    _BaseTransactional, 
    44) 
    45 
    46# Types needed only for Type Hints 
    47if TYPE_CHECKING:  # pragma: NO COVER 
    48    import datetime 
    49 
    50    from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 
    51    from google.cloud.firestore_v1.base_document import DocumentSnapshot 
    52    from google.cloud.firestore_v1.query_profile import ExplainOptions 
    53 
    54 
# Generic parameters for the transactional decorator below: ``T`` is the
# result type the wrapped coroutine function's awaitable produces, and ``P``
# captures the wrapped function's parameters that follow the leading
# ``AsyncTransaction`` argument (see ``Concatenate[AsyncTransaction, P]``).
T = TypeVar("T")
P = ParamSpec("P")
    57 
    58 
    59class AsyncTransaction(async_batch.AsyncWriteBatch, BaseTransaction): 
    60    """Accumulate read-and-write operations to be sent in a transaction. 
    61 
    62    Args: 
    63        client (:class:`~google.cloud.firestore_v1.async_client.AsyncClient`): 
    64            The client that created this transaction. 
    65        max_attempts (Optional[int]): The maximum number of attempts for 
    66            the transaction (i.e. allowing retries). Defaults to 
    67            :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`. 
    68        read_only (Optional[bool]): Flag indicating if the transaction 
    69            should be read-only or should allow writes. Defaults to 
    70            :data:`False`. 
    71    """ 
    72 
    73    def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: 
    74        super(AsyncTransaction, self).__init__(client) 
    75        BaseTransaction.__init__(self, max_attempts, read_only) 
    76 
    77    def _add_write_pbs(self, write_pbs: list) -> None: 
    78        """Add `Write`` protobufs to this transaction. 
    79 
    80        Args: 
    81            write_pbs (List[google.cloud.firestore_v1.\ 
    82                write.Write]): A list of write protobufs to be added. 
    83 
    84        Raises: 
    85            ValueError: If this transaction is read-only. 
    86        """ 
    87        if self._read_only: 
    88            raise ValueError(_WRITE_READ_ONLY) 
    89 
    90        super(AsyncTransaction, self)._add_write_pbs(write_pbs) 
    91 
    92    async def _begin(self, retry_id: bytes | None = None) -> None: 
    93        """Begin the transaction. 
    94 
    95        Args: 
    96            retry_id (Optional[bytes]): Transaction ID of a transaction to be 
    97                retried. 
    98 
    99        Raises: 
    100            ValueError: If the current transaction has already begun. 
    101        """ 
    102        if self.in_progress: 
    103            msg = _CANT_BEGIN.format(self._id) 
    104            raise ValueError(msg) 
    105 
    106        transaction_response = await self._client._firestore_api.begin_transaction( 
    107            request={ 
    108                "database": self._client._database_string, 
    109                "options": self._options_protobuf(retry_id), 
    110            }, 
    111            metadata=self._client._rpc_metadata, 
    112        ) 
    113        self._id = transaction_response.transaction 
    114 
    115    async def _rollback(self) -> None: 
    116        """Roll back the transaction. 
    117 
    118        Raises: 
    119            ValueError: If no transaction is in progress. 
    120            google.api_core.exceptions.GoogleAPICallError: If the rollback fails. 
    121        """ 
    122        if not self.in_progress: 
    123            raise ValueError(_CANT_ROLLBACK) 
    124 
    125        try: 
    126            # NOTE: The response is just ``google.protobuf.Empty``. 
    127            await self._client._firestore_api.rollback( 
    128                request={ 
    129                    "database": self._client._database_string, 
    130                    "transaction": self._id, 
    131                }, 
    132                metadata=self._client._rpc_metadata, 
    133            ) 
    134        finally: 
    135            # clean up, even if rollback fails 
    136            self._clean_up() 
    137 
    138    async def _commit(self) -> list: 
    139        """Transactionally commit the changes accumulated. 
    140 
    141        Returns: 
    142            List[:class:`google.cloud.firestore_v1.write.WriteResult`, ...]: 
    143            The write results corresponding to the changes committed, returned 
    144            in the same order as the changes were applied to this transaction. 
    145            A write result contains an ``update_time`` field. 
    146 
    147        Raises: 
    148            ValueError: If no transaction is in progress. 
    149        """ 
    150        if not self.in_progress: 
    151            raise ValueError(_CANT_COMMIT) 
    152 
    153        commit_response = await self._client._firestore_api.commit( 
    154            request={ 
    155                "database": self._client._database_string, 
    156                "writes": self._write_pbs, 
    157                "transaction": self._id, 
    158            }, 
    159            metadata=self._client._rpc_metadata, 
    160        ) 
    161 
    162        self._clean_up() 
    163        self.write_results = list(commit_response.write_results) 
    164        self.commit_time = commit_response.commit_time 
    165        return self.write_results 
    166 
    167    async def get_all( 
    168        self, 
    169        references: list, 
    170        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, 
    171        timeout: float | None = None, 
    172        *, 
    173        read_time: datetime.datetime | None = None, 
    174    ) -> AsyncGenerator[DocumentSnapshot, Any]: 
    175        """Retrieves multiple documents from Firestore. 
    176 
    177        Args: 
    178            references (List[.AsyncDocumentReference, ...]): Iterable of document 
    179                references to be retrieved. 
    180            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    181                should be retried.  Defaults to a system-specified policy. 
    182            timeout (float): The timeout for this request.  Defaults to a 
    183                system-specified value. 
    184            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given 
    185                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery 
    186                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no 
    187                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
    188 
    189        Yields: 
    190            .DocumentSnapshot: The next document snapshot that fulfills the 
    191            query, or :data:`None` if the document does not exist. 
    192        """ 
    193        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    194        if read_time is not None: 
    195            kwargs["read_time"] = read_time 
    196        return await self._client.get_all(references, transaction=self, **kwargs) 
    197 
    198    async def get( 
    199        self, 
    200        ref_or_query: AsyncDocumentReference | AsyncQuery, 
    201        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, 
    202        timeout: Optional[float] = None, 
    203        *, 
    204        explain_options: Optional[ExplainOptions] = None, 
    205        read_time: Optional[datetime.datetime] = None, 
    206    ) -> AsyncGenerator[DocumentSnapshot, Any] | AsyncStreamGenerator[DocumentSnapshot]: 
    207        """ 
    208        Retrieve a document or a query result from the database. 
    209 
    210        Args: 
    211            ref_or_query (AsyncDocumentReference | AsyncQuery): 
    212                The document references or query object to return. 
    213            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    214                should be retried.  Defaults to a system-specified policy. 
    215            timeout (float): The timeout for this request.  Defaults to a 
    216                system-specified value. 
    217            explain_options 
    218                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    219                Options to enable query profiling for this query. When set, 
    220                explain_metrics will be available on the returned generator. 
    221                Can only be used when running a query, not a document reference. 
    222            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given 
    223                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery 
    224                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no 
    225                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
    226 
    227        Yields: 
    228            DocumentSnapshot: The next document snapshot that fulfills the query, 
    229            or :data:`None` if the document does not exist. 
    230 
    231        Raises: 
    232            ValueError: if `ref_or_query` is not one of the supported types, or 
    233            explain_options is provided when `ref_or_query` is a document 
    234            reference. 
    235        """ 
    236        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    237        if read_time is not None: 
    238            kwargs["read_time"] = read_time 
    239        if isinstance(ref_or_query, AsyncDocumentReference): 
    240            if explain_options is not None: 
    241                raise ValueError( 
    242                    "When type of `ref_or_query` is `AsyncDocumentReference`, " 
    243                    "`explain_options` cannot be provided." 
    244                ) 
    245            return await self._client.get_all( 
    246                [ref_or_query], transaction=self, **kwargs 
    247            ) 
    248        elif isinstance(ref_or_query, AsyncQuery): 
    249            if explain_options is not None: 
    250                kwargs["explain_options"] = explain_options 
    251            return ref_or_query.stream(transaction=self, **kwargs) 
    252        else: 
    253            raise ValueError( 
    254                'Value for argument "ref_or_query" must be a AsyncDocumentReference or a AsyncQuery.' 
    255            ) 
    256 
    257 
    258class _AsyncTransactional(_BaseTransactional, Generic[T, P]): 
    259    """Provide a callable object to use as a transactional decorater. 
    260 
    261    This is surfaced via 
    262    :func:`~google.cloud.firestore_v1.async_transaction.transactional`. 
    263 
    264    Args: 
    265        to_wrap (Coroutine[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Any]): 
    266            A coroutine that should be run (and retried) in a transaction. 
    267    """ 
    268 
    269    def __init__( 
    270        self, to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]] 
    271    ) -> None: 
    272        super(_AsyncTransactional, self).__init__(to_wrap) 
    273 
    274    async def _pre_commit( 
    275        self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs 
    276    ) -> T: 
    277        """Begin transaction and call the wrapped coroutine. 
    278 
    279        Args: 
    280            transaction 
    281                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`): 
    282                A transaction to execute the coroutine within. 
    283            args (Tuple[Any, ...]): The extra positional arguments to pass 
    284                along to the wrapped coroutine. 
    285            kwargs (Dict[str, Any]): The extra keyword arguments to pass 
    286                along to the wrapped coroutine. 
    287 
    288        Returns: 
    289            T: result of the wrapped coroutine. 
    290 
    291        Raises: 
    292            Exception: Any failure caused by ``to_wrap``. 
    293        """ 
    294        # Force the ``transaction`` to be not "in progress". 
    295        transaction._clean_up() 
    296        await transaction._begin(retry_id=self.retry_id) 
    297 
    298        # Update the stored transaction IDs. 
    299        self.current_id = transaction._id 
    300        if self.retry_id is None: 
    301            self.retry_id = self.current_id 
    302        return await self.to_wrap(transaction, *args, **kwargs) 
    303 
    304    async def __call__( 
    305        self, transaction: AsyncTransaction, *args: P.args, **kwargs: P.kwargs 
    306    ) -> T: 
    307        """Execute the wrapped callable within a transaction. 
    308 
    309        Args: 
    310            transaction 
    311                (:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`): 
    312                A transaction to execute the callable within. 
    313            args (Tuple[Any, ...]): The extra positional arguments to pass 
    314                along to the wrapped callable. 
    315            kwargs (Dict[str, Any]): The extra keyword arguments to pass 
    316                along to the wrapped callable. 
    317 
    318        Returns: 
    319            T: The result of the wrapped callable. 
    320 
    321        Raises: 
    322            ValueError: If the transaction does not succeed in 
    323                ``max_attempts``. 
    324        """ 
    325        self._reset() 
    326        retryable_exceptions = ( 
    327            (exceptions.Aborted) if not transaction._read_only else () 
    328        ) 
    329        last_exc = None 
    330 
    331        try: 
    332            for attempt in range(transaction._max_attempts): 
    333                result: T = await self._pre_commit(transaction, *args, **kwargs) 
    334                try: 
    335                    await transaction._commit() 
    336                    return result 
    337                except retryable_exceptions as exc: 
    338                    last_exc = exc 
    339                # Retry attempts that result in retryable exceptions 
    340                # Subsequent requests will use the failed transaction ID as part of 
    341                # the ``BeginTransactionRequest`` when restarting this transaction 
    342                # (via ``options.retry_transaction``). This preserves the "spot in 
    343                # line" of the transaction, so exponential backoff is not required 
    344                # in this case. 
    345            # retries exhausted 
    346            # wrap the last exception in a ValueError before raising 
    347            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) 
    348            raise ValueError(msg) from last_exc 
    349 
    350        except BaseException: 
    351            # rollback the transaction on any error 
    352            # errors raised during _rollback will be chained to the original error through __context__ 
    353            await transaction._rollback() 
    354            raise 
    355 
    356 
    357def async_transactional( 
    358    to_wrap: Callable[Concatenate[AsyncTransaction, P], Awaitable[T]] 
    359) -> Callable[Concatenate[AsyncTransaction, P], Awaitable[T]]: 
    360    """Decorate a callable so that it runs in a transaction. 
    361 
    362    Args: 
    363        to_wrap 
    364            (Callable[[:class:`~google.cloud.firestore_v1.async_transaction.AsyncTransaction`, ...], Awaitable[Any]]): 
    365            A callable that should be run (and retried) in a transaction. 
    366 
    367    Returns: 
    368        Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Awaitable[Any]]: 
    369        the wrapped callable. 
    370    """ 
    371    return _AsyncTransactional(to_wrap)