1# Copyright 2024 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Classes for representing vector queries for the Google Cloud Firestore API. 
    16""" 
    17from __future__ import annotations 
    18 
    19import abc 
    20from abc import ABC 
    21from enum import Enum 
    22from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union 
    23 
    24from google.api_core import gapic_v1 
    25from google.api_core import retry as retries 
    26 
    27from google.cloud.firestore_v1 import _helpers 
    28from google.cloud.firestore_v1.types import query 
    29from google.cloud.firestore_v1.vector import Vector 
    30 
    31if TYPE_CHECKING:  # pragma: NO COVER 
    32    from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 
    33    from google.cloud.firestore_v1.base_document import DocumentSnapshot 
    34    from google.cloud.firestore_v1.query_profile import ExplainOptions 
    35    from google.cloud.firestore_v1.query_results import QueryResultsList 
    36    from google.cloud.firestore_v1.stream_generator import StreamGenerator 
    37 
    38 
    39class DistanceMeasure(Enum): 
    40    EUCLIDEAN = 1 
    41    COSINE = 2 
    42    DOT_PRODUCT = 3 
    43 
    44 
    45class BaseVectorQuery(ABC): 
    46    """Represents a vector query to the Firestore API.""" 
    47 
    48    def __init__(self, nested_query) -> None: 
    49        self._nested_query = nested_query 
    50        self._collection_ref = nested_query._parent 
    51        self._vector_field: Optional[str] = None 
    52        self._query_vector: Optional[Vector] = None 
    53        self._limit: Optional[int] = None 
    54        self._distance_measure: Optional[DistanceMeasure] = None 
    55        self._distance_result_field: Optional[str] = None 
    56        self._distance_threshold: Optional[float] = None 
    57 
    58    @property 
    59    def _client(self): 
    60        return self._collection_ref._client 
    61 
    62    def _to_protobuf(self) -> query.StructuredQuery: 
    63        pb = query.StructuredQuery() 
    64 
    65        distance_measure_proto = None 
    66        if self._distance_measure == DistanceMeasure.EUCLIDEAN: 
    67            distance_measure_proto = ( 
    68                query.StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN 
    69            ) 
    70        elif self._distance_measure == DistanceMeasure.COSINE: 
    71            distance_measure_proto = ( 
    72                query.StructuredQuery.FindNearest.DistanceMeasure.COSINE 
    73            ) 
    74        elif self._distance_measure == DistanceMeasure.DOT_PRODUCT: 
    75            distance_measure_proto = ( 
    76                query.StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT 
    77            ) 
    78        else: 
    79            raise ValueError("Invalid distance_measure") 
    80 
    81        # Coerce ints to floats as required by the protobuf. 
    82        distance_threshold_proto = None 
    83        if self._distance_threshold is not None: 
    84            distance_threshold_proto = float(self._distance_threshold) 
    85 
    86        pb = self._nested_query._to_protobuf() 
    87        pb.find_nearest = query.StructuredQuery.FindNearest( 
    88            vector_field=query.StructuredQuery.FieldReference( 
    89                field_path=self._vector_field 
    90            ), 
    91            query_vector=_helpers.encode_value(self._query_vector), 
    92            distance_measure=distance_measure_proto, 
    93            limit=self._limit, 
    94            distance_result_field=self._distance_result_field, 
    95            distance_threshold=distance_threshold_proto, 
    96        ) 
    97        return pb 
    98 
    99    def _prep_stream( 
    100        self, 
    101        transaction=None, 
    102        retry: Union[retries.Retry, retries.AsyncRetry, object, None] = None, 
    103        timeout: Optional[float] = None, 
    104        explain_options: Optional[ExplainOptions] = None, 
    105    ) -> Tuple[dict, str, dict]: 
    106        parent_path, expected_prefix = self._collection_ref._parent_info() 
    107        request = { 
    108            "parent": parent_path, 
    109            "structured_query": self._to_protobuf(), 
    110            "transaction": _helpers.get_transaction_id(transaction), 
    111        } 
    112        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    113 
    114        if explain_options is not None: 
    115            request["explain_options"] = explain_options._to_dict() 
    116 
    117        return request, expected_prefix, kwargs 
    118 
    119    @abc.abstractmethod 
    120    def get( 
    121        self, 
    122        transaction=None, 
    123        retry: retries.Retry 
    124        | retries.AsyncRetry 
    125        | object 
    126        | None = gapic_v1.method.DEFAULT, 
    127        timeout: Optional[float] = None, 
    128        *, 
    129        explain_options: Optional[ExplainOptions] = None, 
    130    ) -> ( 
    131        QueryResultsList[DocumentSnapshot] 
    132        | Coroutine[Any, Any, QueryResultsList[DocumentSnapshot]] 
    133    ): 
    134        """Runs the vector query.""" 
    135        raise NotImplementedError 
    136 
    137    def find_nearest( 
    138        self, 
    139        vector_field: str, 
    140        query_vector: Union[Vector, Sequence[float]], 
    141        limit: int, 
    142        distance_measure: DistanceMeasure, 
    143        *, 
    144        distance_result_field: Optional[str] = None, 
    145        distance_threshold: Optional[float] = None, 
    146    ): 
    147        """Finds the closest vector embeddings to the given query vector.""" 
    148        if not isinstance(query_vector, Vector): 
    149            self._query_vector = Vector(query_vector) 
    150        else: 
    151            self._query_vector = query_vector 
    152        self._vector_field = vector_field 
    153        self._limit = limit 
    154        self._distance_measure = distance_measure 
    155        self._distance_result_field = distance_result_field 
    156        self._distance_threshold = distance_threshold 
    157        return self 
    158 
    159    def stream( 
    160        self, 
    161        transaction=None, 
    162        retry: retries.Retry 
    163        | retries.AsyncRetry 
    164        | object 
    165        | None = gapic_v1.method.DEFAULT, 
    166        timeout: Optional[float] = None, 
    167        *, 
    168        explain_options: Optional[ExplainOptions] = None, 
    169    ) -> StreamGenerator[DocumentSnapshot] | AsyncStreamGenerator[DocumentSnapshot]: 
    170        """Reads the documents in the collection that match this query.""" 
    171        raise NotImplementedError