1# Copyright 2024 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15from __future__ import annotations 
    16 
    17from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional, TypeVar, Union 
    18 
    19from google.api_core import gapic_v1 
    20from google.api_core import retry as retries 
    21 
    22from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 
    23from google.cloud.firestore_v1.base_query import ( 
    24    BaseQuery, 
    25    _collection_group_query_response_to_snapshot, 
    26    _query_response_to_snapshot, 
    27) 
    28from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery 
    29from google.cloud.firestore_v1.query_results import QueryResultsList 
    30 
    31# Types needed only for Type Hints 
    32if TYPE_CHECKING:  # pragma: NO COVER 
    33    from google.cloud.firestore_v1.base_document import DocumentSnapshot 
    34    from google.cloud.firestore_v1.query_profile import ExplainMetrics, ExplainOptions 
    35    from google.cloud.firestore_v1 import transaction 
    36    import google.cloud.firestore_v1.types.query_profile as query_profile_pb 
    37 
# Type variable bound to AsyncVectorQuery; used in AsyncVectorQuery.__init__
# so that ``nested_query`` may be either a BaseQuery or an async vector query.
TAsyncVectorQuery = TypeVar("TAsyncVectorQuery", bound="AsyncVectorQuery")
    39 
    40 
    41class AsyncVectorQuery(BaseVectorQuery): 
    42    """Represents an async vector query to the Firestore API.""" 
    43 
    44    def __init__( 
    45        self, 
    46        nested_query: Union[BaseQuery, TAsyncVectorQuery], 
    47    ) -> None: 
    48        """Presents the vector query. 
    49        Args: 
    50            nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter. 
    51        """ 
    52        super(AsyncVectorQuery, self).__init__(nested_query) 
    53 
    54    async def get( 
    55        self, 
    56        transaction=None, 
    57        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, 
    58        timeout: Optional[float] = None, 
    59        *, 
    60        explain_options: Optional[ExplainOptions] = None, 
    61    ) -> QueryResultsList[DocumentSnapshot]: 
    62        """Runs the vector query. 
    63 
    64        This sends a ``RunQuery`` RPC and returns a list of document messages. 
    65 
    66        Args: 
    67            transaction 
    68                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    69                An existing transaction that this query will run in. 
    70                If a ``transaction`` is used and it already has write operations 
    71                added, this method cannot be used (i.e. read-after-write is not 
    72                allowed). 
    73            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    74                should be retried.  Defaults to a system-specified policy. 
    75            timeout (float): The timeout for this request.  Defaults to a 
    76                system-specified value. 
    77            explain_options 
    78                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    79                Options to enable query profiling for this query. When set, 
    80                explain_metrics will be available on the returned generator. 
    81 
    82        Returns: 
    83            QueryResultsList[DocumentSnapshot]: The documents in the collection 
    84            that match this query. 
    85        """ 
    86        explain_metrics: ExplainMetrics | None = None 
    87 
    88        stream_result = self.stream( 
    89            transaction=transaction, 
    90            retry=retry, 
    91            timeout=timeout, 
    92            explain_options=explain_options, 
    93        ) 
    94        try: 
    95            result = [snapshot async for snapshot in stream_result] 
    96 
    97            if explain_options is None: 
    98                explain_metrics = None 
    99            else: 
    100                explain_metrics = await stream_result.get_explain_metrics() 
    101        finally: 
    102            await stream_result.aclose() 
    103 
    104        return QueryResultsList(result, explain_options, explain_metrics) 
    105 
    106    async def _make_stream( 
    107        self, 
    108        transaction: Optional[transaction.Transaction] = None, 
    109        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, 
    110        timeout: Optional[float] = None, 
    111        explain_options: Optional[ExplainOptions] = None, 
    112    ) -> AsyncGenerator[DocumentSnapshot | query_profile_pb.ExplainMetrics, Any]: 
    113        """Internal method for stream(). Read the documents in the collection 
    114        that match this query. 
    115 
    116        This sends a ``RunQuery`` RPC and then returns a generator which 
    117        consumes each document returned in the stream of ``RunQueryResponse`` 
    118        messages. 
    119 
    120        If a ``transaction`` is used and it already has write operations 
    121        added, this method cannot be used (i.e. read-after-write is not 
    122        allowed). 
    123 
    124        Args: 
    125            transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\ 
    126                Transaction`]): 
    127                An existing transaction that the query will run in. 
    128            retry (Optional[google.api_core.retry.Retry]): Designation of what 
    129                errors, if any, should be retried.  Defaults to a 
    130                system-specified policy. 
    131            timeout (Optional[float]): The timeout for this request. Defaults 
    132                to a system-specified value. 
    133            explain_options 
    134                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    135                Options to enable query profiling for this query. When set, 
    136                explain_metrics will be available on the returned generator. 
    137 
    138        Yields: 
    139            [:class:`~google.cloud.firestore_v1.base_document.DocumentSnapshot` \ 
    140                | google.cloud.firestore_v1.types.query_profile.ExplainMetrtics]: 
    141            The next document that fulfills the query. Query results will be 
    142            yielded as `DocumentSnapshot`. When the result contains returned 
    143            explain metrics, yield `query_profile_pb.ExplainMetrics` individually. 
    144        """ 
    145        request, expected_prefix, kwargs = self._prep_stream( 
    146            transaction, 
    147            retry, 
    148            timeout, 
    149            explain_options, 
    150        ) 
    151 
    152        response_iterator = await self._client._firestore_api.run_query( 
    153            request=request, 
    154            metadata=self._client._rpc_metadata, 
    155            **kwargs, 
    156        ) 
    157        async for response in response_iterator: 
    158            if self._nested_query._all_descendants: 
    159                snapshot = _collection_group_query_response_to_snapshot( 
    160                    response, self._nested_query._parent 
    161                ) 
    162            else: 
    163                snapshot = _query_response_to_snapshot( 
    164                    response, self._nested_query._parent, expected_prefix 
    165                ) 
    166            if snapshot is not None: 
    167                yield snapshot 
    168 
    169            if response.explain_metrics: 
    170                metrics = response.explain_metrics 
    171                yield metrics 
    172 
    173    def stream( 
    174        self, 
    175        transaction=None, 
    176        retry: retries.AsyncRetry | object | None = gapic_v1.method.DEFAULT, 
    177        timeout: Optional[float] = None, 
    178        *, 
    179        explain_options: Optional[ExplainOptions] = None, 
    180    ) -> AsyncStreamGenerator[DocumentSnapshot]: 
    181        """Reads the documents in the collection that match this query. 
    182 
    183        This sends a ``RunQuery`` RPC and then returns an iterator which 
    184        consumes each document returned in the stream of ``RunQueryResponse`` 
    185        messages. 
    186 
    187        If a ``transaction`` is used and it already has write operations 
    188        added, this method cannot be used (i.e. read-after-write is not 
    189        allowed). 
    190 
    191        Args: 
    192            transaction 
    193                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    194                An existing transaction that this query will run in. 
    195            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    196                should be retried.  Defaults to a system-specified policy. 
    197            timeout (float): The timeout for this request.  Defaults to a 
    198                system-specified value. 
    199            explain_options 
    200                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    201                Options to enable query profiling for this query. When set, 
    202                explain_metrics will be available on the returned generator. 
    203 
    204        Returns: 
    205            `AsyncStreamGenerator[DocumentSnapshot]`: 
    206            An asynchronous generator of the queryresults. 
    207        """ 
    208 
    209        inner_generator = self._make_stream( 
    210            transaction=transaction, 
    211            retry=retry, 
    212            timeout=timeout, 
    213            explain_options=explain_options, 
    214        ) 
    215        return AsyncStreamGenerator(inner_generator, explain_options)