1# Copyright 2024 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Classes for representing vector queries for the Google Cloud Firestore API. 
    16""" 
    17from __future__ import annotations 
    18 
    19from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, Union 
    20 
    21from google.api_core import gapic_v1 
    22from google.api_core import retry as retries 
    23 
    24from google.cloud.firestore_v1.query_results import QueryResultsList 
    25from google.cloud.firestore_v1.base_query import ( 
    26    BaseQuery, 
    27    _collection_group_query_response_to_snapshot, 
    28    _query_response_to_snapshot, 
    29) 
    30from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery 
    31from google.cloud.firestore_v1.stream_generator import StreamGenerator 
    32 
    33# Types needed only for Type Hints 
    34if TYPE_CHECKING:  # pragma: NO COVER 
    35    from google.cloud.firestore_v1 import transaction 
    36    from google.cloud.firestore_v1.base_document import DocumentSnapshot 
    37    from google.cloud.firestore_v1.query_profile import ExplainMetrics 
    38    from google.cloud.firestore_v1.query_profile import ExplainOptions 
    39 
    40 
# Type variable bound to ``VectorQuery``: in annotations it stands for
# ``VectorQuery`` or any subclass of it.
TVectorQuery = TypeVar("TVectorQuery", bound="VectorQuery")
    42 
    43 
    44class VectorQuery(BaseVectorQuery): 
    45    """Represents a vector query to the Firestore API.""" 
    46 
    47    def __init__( 
    48        self, 
    49        nested_query: Union[BaseQuery, TVectorQuery], 
    50    ) -> None: 
    51        """Presents the vector query. 
    52        Args: 
    53            nested_query (BaseQuery | VectorQuery): the base query to apply as the prefilter. 
    54        """ 
    55        super(VectorQuery, self).__init__(nested_query) 
    56 
    57    def get( 
    58        self, 
    59        transaction=None, 
    60        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, 
    61        timeout: Optional[float] = None, 
    62        *, 
    63        explain_options: Optional[ExplainOptions] = None, 
    64    ) -> QueryResultsList[DocumentSnapshot]: 
    65        """Runs the vector query. 
    66 
    67        This sends a ``RunQuery`` RPC and returns a list of document messages. 
    68 
    69        Args: 
    70            transaction 
    71                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    72                An existing transaction that this query will run in. 
    73                If a ``transaction`` is used and it already has write operations 
    74                added, this method cannot be used (i.e. read-after-write is not 
    75                allowed). 
    76            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    77                should be retried.  Defaults to a system-specified policy. 
    78            timeout (float): The timeout for this request.  Defaults to a 
    79                system-specified value. 
    80            explain_options 
    81                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    82                Options to enable query profiling for this query. When set, 
    83                explain_metrics will be available on the returned generator. 
    84 
    85        Returns: 
    86            QueryResultsList[DocumentSnapshot]: The vector query results. 
    87        """ 
    88        explain_metrics: ExplainMetrics | None = None 
    89 
    90        result = self.stream( 
    91            transaction=transaction, 
    92            retry=retry, 
    93            timeout=timeout, 
    94            explain_options=explain_options, 
    95        ) 
    96        result_list = list(result) 
    97 
    98        if explain_options is None: 
    99            explain_metrics = None 
    100        else: 
    101            explain_metrics = result.get_explain_metrics() 
    102 
    103        return QueryResultsList(result_list, explain_options, explain_metrics) 
    104 
    105    def _get_stream_iterator(self, transaction, retry, timeout, explain_options=None): 
    106        """Helper method for :meth:`stream`.""" 
    107        request, expected_prefix, kwargs = self._prep_stream( 
    108            transaction, 
    109            retry, 
    110            timeout, 
    111            explain_options, 
    112        ) 
    113 
    114        response_iterator = self._client._firestore_api.run_query( 
    115            request=request, 
    116            metadata=self._client._rpc_metadata, 
    117            **kwargs, 
    118        ) 
    119 
    120        return response_iterator, expected_prefix 
    121 
    122    def _make_stream( 
    123        self, 
    124        transaction: Optional["transaction.Transaction"] = None, 
    125        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, 
    126        timeout: Optional[float] = None, 
    127        explain_options: Optional[ExplainOptions] = None, 
    128    ) -> Generator[DocumentSnapshot, Any, Optional[ExplainMetrics]]: 
    129        """Reads the documents in the collection that match this query. 
    130 
    131        This sends a ``RunQuery`` RPC and then returns a generator which 
    132        consumes each document returned in the stream of ``RunQueryResponse`` 
    133        messages. 
    134 
    135        If a ``transaction`` is used and it already has write operations 
    136        added, this method cannot be used (i.e. read-after-write is not 
    137        allowed). 
    138 
    139        Args: 
    140            transaction 
    141                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    142                An existing transaction that this query will run in. 
    143            retry (Optional[google.api_core.retry.Retry]): Designation of what 
    144                errors, if any, should be retried.  Defaults to a 
    145                system-specified policy. 
    146            timeout (Optional[float]): The timeout for this request.  Defaults 
    147            to a system-specified value. 
    148            explain_options 
    149                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    150                Options to enable query profiling for this query. When set, 
    151                explain_metrics will be available on the returned generator. 
    152 
    153        Yields: 
    154            DocumentSnapshot: 
    155            The next document that fulfills the query. 
    156 
    157        Returns: 
    158            ([google.cloud.firestore_v1.types.query_profile.ExplainMetrtics | None]): 
    159            The results of query profiling, if received from the service. 
    160        """ 
    161        metrics: ExplainMetrics | None = None 
    162 
    163        response_iterator, expected_prefix = self._get_stream_iterator( 
    164            transaction, 
    165            retry, 
    166            timeout, 
    167            explain_options, 
    168        ) 
    169 
    170        while True: 
    171            response = next(response_iterator, None) 
    172 
    173            if response is None:  # EOI 
    174                break 
    175 
    176            if metrics is None and response.explain_metrics: 
    177                metrics = response.explain_metrics 
    178 
    179            if self._nested_query._all_descendants: 
    180                snapshot = _collection_group_query_response_to_snapshot( 
    181                    response, self._nested_query._parent 
    182                ) 
    183            else: 
    184                snapshot = _query_response_to_snapshot( 
    185                    response, self._nested_query._parent, expected_prefix 
    186                ) 
    187            if snapshot is not None: 
    188                yield snapshot 
    189 
    190        return metrics 
    191 
    192    def stream( 
    193        self, 
    194        transaction: Optional["transaction.Transaction"] = None, 
    195        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT, 
    196        timeout: Optional[float] = None, 
    197        *, 
    198        explain_options: Optional[ExplainOptions] = None, 
    199    ) -> StreamGenerator[DocumentSnapshot]: 
    200        """Reads the documents in the collection that match this query. 
    201 
    202        This sends a ``RunQuery`` RPC and then returns a generator which 
    203        consumes each document returned in the stream of ``RunQueryResponse`` 
    204        messages. 
    205 
    206        If a ``transaction`` is used and it already has write operations 
    207        added, this method cannot be used (i.e. read-after-write is not 
    208        allowed). 
    209 
    210        Args: 
    211            transaction 
    212                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    213                An existing transaction that this query will run in. 
    214            retry (Optional[google.api_core.retry.Retry]): Designation of what 
    215                errors, if any, should be retried.  Defaults to a 
    216                system-specified policy. 
    217            timeout (Optinal[float]): The timeout for this request.  Defaults 
    218            to a system-specified value. 
    219            explain_options 
    220                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    221                Options to enable query profiling for this query. When set, 
    222                explain_metrics will be available on the returned generator. 
    223 
    224        Returns: 
    225            `StreamGenerator[DocumentSnapshot]`: A generator of the query results. 
    226        """ 
    227        inner_generator = self._make_stream( 
    228            transaction=transaction, 
    229            retry=retry, 
    230            timeout=timeout, 
    231            explain_options=explain_options, 
    232        ) 
    233        return StreamGenerator(inner_generator, explain_options)