1# Copyright 2017 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Helpers for applying Google Cloud Firestore changes in a transaction.""" 
    16from __future__ import annotations 
    17 
    18from typing import TYPE_CHECKING, Any, Callable, Generator, Optional 
    19 
    20from google.api_core import exceptions, gapic_v1 
    21from google.api_core import retry as retries 
    22 
    23from google.cloud.firestore_v1 import _helpers, batch 
    24from google.cloud.firestore_v1.base_transaction import ( 
    25    _CANT_BEGIN, 
    26    _CANT_COMMIT, 
    27    _CANT_ROLLBACK, 
    28    _EXCEED_ATTEMPTS_TEMPLATE, 
    29    _WRITE_READ_ONLY, 
    30    MAX_ATTEMPTS, 
    31    BaseTransaction, 
    32    _BaseTransactional, 
    33) 
    34from google.cloud.firestore_v1.document import DocumentReference 
    35from google.cloud.firestore_v1.query import Query 
    36 
    37# Types needed only for Type Hints 
    38if TYPE_CHECKING:  # pragma: NO COVER 
    39    from google.cloud.firestore_v1.base_document import DocumentSnapshot 
    40    from google.cloud.firestore_v1.query_profile import ExplainOptions 
    41    from google.cloud.firestore_v1.stream_generator import StreamGenerator 
    42 
    43    import datetime 
    44 
    45 
    46class Transaction(batch.WriteBatch, BaseTransaction): 
    47    """Accumulate read-and-write operations to be sent in a transaction. 
    48 
    49    Args: 
    50        client (:class:`~google.cloud.firestore_v1.client.Client`): 
    51            The client that created this transaction. 
    52        max_attempts (Optional[int]): The maximum number of attempts for 
    53            the transaction (i.e. allowing retries). Defaults to 
    54            :attr:`~google.cloud.firestore_v1.transaction.MAX_ATTEMPTS`. 
    55        read_only (Optional[bool]): Flag indicating if the transaction 
    56            should be read-only or should allow writes. Defaults to 
    57            :data:`False`. 
    58    """ 
    59 
    60    def __init__(self, client, max_attempts=MAX_ATTEMPTS, read_only=False) -> None: 
    61        super(Transaction, self).__init__(client) 
    62        BaseTransaction.__init__(self, max_attempts, read_only) 
    63 
    64    def _add_write_pbs(self, write_pbs: list) -> None: 
    65        """Add `Write`` protobufs to this transaction. 
    66 
    67        Args: 
    68            write_pbs (List[google.cloud.firestore_v1.\ 
    69                write.Write]): A list of write protobufs to be added. 
    70 
    71        Raises: 
    72            ValueError: If this transaction is read-only. 
    73        """ 
    74        if self._read_only: 
    75            raise ValueError(_WRITE_READ_ONLY) 
    76 
    77        super(Transaction, self)._add_write_pbs(write_pbs) 
    78 
    79    def _begin(self, retry_id: bytes | None = None) -> None: 
    80        """Begin the transaction. 
    81 
    82        Args: 
    83            retry_id (Optional[bytes]): Transaction ID of a transaction to be 
    84                retried. 
    85 
    86        Raises: 
    87            ValueError: If the current transaction has already begun. 
    88        """ 
    89        if self.in_progress: 
    90            msg = _CANT_BEGIN.format(self._id) 
    91            raise ValueError(msg) 
    92 
    93        transaction_response = self._client._firestore_api.begin_transaction( 
    94            request={ 
    95                "database": self._client._database_string, 
    96                "options": self._options_protobuf(retry_id), 
    97            }, 
    98            metadata=self._client._rpc_metadata, 
    99        ) 
    100        self._id = transaction_response.transaction 
    101 
    102    def _rollback(self) -> None: 
    103        """Roll back the transaction. 
    104 
    105        Raises: 
    106            ValueError: If no transaction is in progress. 
    107            google.api_core.exceptions.GoogleAPICallError: If the rollback fails. 
    108        """ 
    109        if not self.in_progress: 
    110            raise ValueError(_CANT_ROLLBACK) 
    111 
    112        try: 
    113            # NOTE: The response is just ``google.protobuf.Empty``. 
    114            self._client._firestore_api.rollback( 
    115                request={ 
    116                    "database": self._client._database_string, 
    117                    "transaction": self._id, 
    118                }, 
    119                metadata=self._client._rpc_metadata, 
    120            ) 
    121        finally: 
    122            # clean up, even if rollback fails 
    123            self._clean_up() 
    124 
    125    def _commit(self) -> list: 
    126        """Transactionally commit the changes accumulated. 
    127 
    128        Returns: 
    129            List[:class:`google.cloud.firestore_v1.write.WriteResult`, ...]: 
    130            The write results corresponding to the changes committed, returned 
    131            in the same order as the changes were applied to this transaction. 
    132            A write result contains an ``update_time`` field. 
    133 
    134        Raises: 
    135            ValueError: If no transaction is in progress. 
    136        """ 
    137        if not self.in_progress: 
    138            raise ValueError(_CANT_COMMIT) 
    139 
    140        commit_response = self._client._firestore_api.commit( 
    141            request={ 
    142                "database": self._client._database_string, 
    143                "writes": self._write_pbs, 
    144                "transaction": self._id, 
    145            }, 
    146            metadata=self._client._rpc_metadata, 
    147        ) 
    148 
    149        self._clean_up() 
    150        self.write_results = list(commit_response.write_results) 
    151        self.commit_time = commit_response.commit_time 
    152        return self.write_results 
    153 
    154    def get_all( 
    155        self, 
    156        references: list, 
    157        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, 
    158        timeout: float | None = None, 
    159        *, 
    160        read_time: datetime.datetime | None = None, 
    161    ) -> Generator[DocumentSnapshot, Any, None]: 
    162        """Retrieves multiple documents from Firestore. 
    163 
    164        Args: 
    165            references (List[.DocumentReference, ...]): Iterable of document 
    166                references to be retrieved. 
    167            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    168                should be retried.  Defaults to a system-specified policy. 
    169            timeout (float): The timeout for this request.  Defaults to a 
    170                system-specified value. 
    171            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given 
    172                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery 
    173                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no 
    174                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
    175 
    176        Yields: 
    177            .DocumentSnapshot: The next document snapshot that fulfills the 
    178            query, or :data:`None` if the document does not exist. 
    179        """ 
    180        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    181        if read_time is not None: 
    182            kwargs["read_time"] = read_time 
    183        return self._client.get_all(references, transaction=self, **kwargs) 
    184 
    185    def get( 
    186        self, 
    187        ref_or_query: DocumentReference | Query, 
    188        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, 
    189        timeout: Optional[float] = None, 
    190        *, 
    191        explain_options: Optional[ExplainOptions] = None, 
    192        read_time: Optional[datetime.datetime] = None, 
    193    ) -> StreamGenerator[DocumentSnapshot] | Generator[DocumentSnapshot, Any, None]: 
    194        """Retrieve a document or a query result from the database. 
    195 
    196        Args: 
    197            ref_or_query (DocumentReference | Query): 
    198                The document references or query object to return. 
    199            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    200                should be retried.  Defaults to a system-specified policy. 
    201            timeout (float): The timeout for this request.  Defaults to a 
    202                system-specified value. 
    203            explain_options 
    204                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    205                Options to enable query profiling for this query. When set, 
    206                explain_metrics will be available on the returned generator. 
    207                Can only be used when running a query, not a document reference. 
    208            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given 
    209                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery 
    210                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no 
    211                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
    212 
    213        Yields: 
    214            .DocumentSnapshot: The next document snapshot that fulfills the 
    215            query, or :data:`None` if the document does not exist. 
    216 
    217        Raises: 
    218            ValueError: if `ref_or_query` is not one of the supported types, or 
    219            explain_options is provided when `ref_or_query` is a document 
    220            reference. 
    221        """ 
    222        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    223        if read_time is not None: 
    224            kwargs["read_time"] = read_time 
    225        if isinstance(ref_or_query, DocumentReference): 
    226            if explain_options is not None: 
    227                raise ValueError( 
    228                    "When type of `ref_or_query` is `AsyncDocumentReference`, " 
    229                    "`explain_options` cannot be provided." 
    230                ) 
    231            return self._client.get_all([ref_or_query], transaction=self, **kwargs) 
    232        elif isinstance(ref_or_query, Query): 
    233            if explain_options is not None: 
    234                kwargs["explain_options"] = explain_options 
    235            return ref_or_query.stream(transaction=self, **kwargs) 
    236        else: 
    237            raise ValueError( 
    238                'Value for argument "ref_or_query" must be a DocumentReference or a Query.' 
    239            ) 
    240 
    241 
    242class _Transactional(_BaseTransactional): 
    243    """Provide a callable object to use as a transactional decorater. 
    244 
    245    This is surfaced via 
    246    :func:`~google.cloud.firestore_v1.transaction.transactional`. 
    247 
    248    Args: 
    249        to_wrap (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): 
    250            A callable that should be run (and retried) in a transaction. 
    251    """ 
    252 
    253    def __init__(self, to_wrap) -> None: 
    254        super(_Transactional, self).__init__(to_wrap) 
    255 
    256    def _pre_commit(self, transaction: Transaction, *args, **kwargs) -> Any: 
    257        """Begin transaction and call the wrapped callable. 
    258 
    259        Args: 
    260            transaction 
    261                (:class:`~google.cloud.firestore_v1.transaction.Transaction`): 
    262                A transaction to execute the callable within. 
    263            args (Tuple[Any, ...]): The extra positional arguments to pass 
    264                along to the wrapped callable. 
    265            kwargs (Dict[str, Any]): The extra keyword arguments to pass 
    266                along to the wrapped callable. 
    267 
    268        Returns: 
    269            Any: result of the wrapped callable. 
    270 
    271        Raises: 
    272            Exception: Any failure caused by ``to_wrap``. 
    273        """ 
    274        # Force the ``transaction`` to be not "in progress". 
    275        transaction._clean_up() 
    276        transaction._begin(retry_id=self.retry_id) 
    277 
    278        # Update the stored transaction IDs. 
    279        self.current_id = transaction._id 
    280        if self.retry_id is None: 
    281            self.retry_id = self.current_id 
    282        return self.to_wrap(transaction, *args, **kwargs) 
    283 
    284    def __call__(self, transaction: Transaction, *args, **kwargs): 
    285        """Execute the wrapped callable within a transaction. 
    286 
    287        Args: 
    288            transaction 
    289                (:class:`~google.cloud.firestore_v1.transaction.Transaction`): 
    290                A transaction to execute the callable within. 
    291            args (Tuple[Any, ...]): The extra positional arguments to pass 
    292                along to the wrapped callable. 
    293            kwargs (Dict[str, Any]): The extra keyword arguments to pass 
    294                along to the wrapped callable. 
    295 
    296        Returns: 
    297            Any: The result of the wrapped callable. 
    298 
    299        Raises: 
    300            ValueError: If the transaction does not succeed in 
    301                ``max_attempts``. 
    302        """ 
    303        self._reset() 
    304        retryable_exceptions = ( 
    305            (exceptions.Aborted) if not transaction._read_only else () 
    306        ) 
    307        last_exc = None 
    308 
    309        try: 
    310            for attempt in range(transaction._max_attempts): 
    311                result = self._pre_commit(transaction, *args, **kwargs) 
    312                try: 
    313                    transaction._commit() 
    314                    return result 
    315                except retryable_exceptions as exc: 
    316                    last_exc = exc 
    317                # Retry attempts that result in retryable exceptions 
    318                # Subsequent requests will use the failed transaction ID as part of 
    319                # the ``BeginTransactionRequest`` when restarting this transaction 
    320                # (via ``options.retry_transaction``). This preserves the "spot in 
    321                # line" of the transaction, so exponential backoff is not required 
    322                # in this case. 
    323            # retries exhausted 
    324            # wrap the last exception in a ValueError before raising 
    325            msg = _EXCEED_ATTEMPTS_TEMPLATE.format(transaction._max_attempts) 
    326            raise ValueError(msg) from last_exc 
    327        except BaseException:  # noqa: B901 
    328            # rollback the transaction on any error 
    329            # errors raised during _rollback will be chained to the original error through __context__ 
    330            transaction._rollback() 
    331            raise 
    332 
    333 
    334def transactional(to_wrap: Callable) -> _Transactional: 
    335    """Decorate a callable so that it runs in a transaction. 
    336 
    337    Args: 
    338        to_wrap 
    339            (Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]): 
    340            A callable that should be run (and retried) in a transaction. 
    341 
    342    Returns: 
    343        Callable[[:class:`~google.cloud.firestore_v1.transaction.Transaction`, ...], Any]: 
    344        the wrapped callable. 
    345    """ 
    346    return _Transactional(to_wrap)