1# Copyright 2024 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Classes for representing vector queries for the Google Cloud Firestore API.
16"""
17from __future__ import annotations
18
19import abc
20from abc import ABC
21from enum import Enum
22from typing import TYPE_CHECKING, Any, Coroutine, Optional, Sequence, Tuple, Union
23
24from google.api_core import gapic_v1
25from google.api_core import retry as retries
26
27from google.cloud.firestore_v1 import _helpers
28from google.cloud.firestore_v1.types import query
29from google.cloud.firestore_v1.vector import Vector
30
31if TYPE_CHECKING: # pragma: NO COVER
32 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
33 from google.cloud.firestore_v1.base_document import DocumentSnapshot
34 from google.cloud.firestore_v1.query_profile import ExplainOptions
35 from google.cloud.firestore_v1.query_results import QueryResultsList
36 from google.cloud.firestore_v1.stream_generator import StreamGenerator
37
38
class DistanceMeasure(Enum):
    """Distance functions that a vector (nearest-neighbor) query can use.

    Vector queries translate each member into the matching
    ``StructuredQuery.FindNearest.DistanceMeasure`` protobuf value when
    they are serialized.
    """

    EUCLIDEAN = 1  # Euclidean distance.
    COSINE = 2  # Cosine distance.
    DOT_PRODUCT = 3  # Dot product.
43
44
class BaseVectorQuery(ABC):
    """Represents a vector query to the Firestore API.

    A vector query wraps an ordinary ("nested") query and extends its
    ``StructuredQuery`` protobuf with a ``find_nearest`` clause. Subclasses
    implement :meth:`get` and :meth:`stream`; their annotations allow either
    synchronous results or coroutines / async generators.
    """

    def __init__(self, nested_query) -> None:
        """Wrap ``nested_query`` as the base of a vector query.

        Args:
            nested_query: The query to extend. Its ``_parent`` is kept as the
                collection reference, and its ``_to_protobuf()`` output is
                the starting point for this query's protobuf.
        """
        self._nested_query = nested_query
        self._collection_ref = nested_query._parent
        # Vector-search parameters; all stay unset until ``find_nearest``
        # records them.
        self._vector_field: Optional[str] = None
        self._query_vector: Optional[Vector] = None
        self._limit: Optional[int] = None
        self._distance_measure: Optional[DistanceMeasure] = None
        self._distance_result_field: Optional[str] = None
        self._distance_threshold: Optional[float] = None

    @property
    def _client(self):
        """The client of the collection reference this query targets."""
        return self._collection_ref._client

    def _to_protobuf(self) -> query.StructuredQuery:
        """Build the ``StructuredQuery`` protobuf for this vector query.

        Returns:
            query.StructuredQuery: The nested query's protobuf, with its
            ``find_nearest`` field populated from the parameters recorded by
            :meth:`find_nearest`.

        Raises:
            ValueError: If the stored distance measure is not a
                :class:`DistanceMeasure` member (for example, when
                ``find_nearest`` has not been called yet).
        """
        proto_measures = query.StructuredQuery.FindNearest.DistanceMeasure
        measure_to_proto = {
            DistanceMeasure.EUCLIDEAN: proto_measures.EUCLIDEAN,
            DistanceMeasure.COSINE: proto_measures.COSINE,
            DistanceMeasure.DOT_PRODUCT: proto_measures.DOT_PRODUCT,
        }
        # Validate before touching the nested query, as before. The
        # isinstance guard also turns unhashable or foreign values into the
        # documented ValueError, instead of a TypeError/KeyError from the
        # dict lookup.
        if not isinstance(self._distance_measure, DistanceMeasure):
            raise ValueError("Invalid distance_measure")
        distance_measure_proto = measure_to_proto[self._distance_measure]

        # Coerce ints to floats as required by the protobuf.
        distance_threshold_proto = None
        if self._distance_threshold is not None:
            distance_threshold_proto = float(self._distance_threshold)

        # Start from the nested query's protobuf; the previous throwaway
        # ``query.StructuredQuery()`` was immediately overwritten here.
        pb = self._nested_query._to_protobuf()
        pb.find_nearest = query.StructuredQuery.FindNearest(
            vector_field=query.StructuredQuery.FieldReference(
                field_path=self._vector_field
            ),
            query_vector=_helpers.encode_value(self._query_vector),
            distance_measure=distance_measure_proto,
            limit=self._limit,
            distance_result_field=self._distance_result_field,
            distance_threshold=distance_threshold_proto,
        )
        return pb

    def _prep_stream(
        self,
        transaction=None,
        retry: Union[retries.Retry, retries.AsyncRetry, object, None] = None,
        timeout: Optional[float] = None,
        explain_options: Optional[ExplainOptions] = None,
    ) -> Tuple[dict, str, dict]:
        """Assemble the request and call arguments for running this query.

        Args:
            transaction: Optional transaction; its id (via
                ``_helpers.get_transaction_id``) goes into the request.
            retry: Retry setting forwarded to
                ``_helpers.make_retry_timeout_kwargs``.
            timeout: Timeout forwarded to
                ``_helpers.make_retry_timeout_kwargs``.
            explain_options: If given, serialized with ``_to_dict()`` into
                the request's ``explain_options`` entry.

        Returns:
            Tuple[dict, str, dict]: The request dict (``parent``,
            ``structured_query``, ``transaction`` and optionally
            ``explain_options``), the expected prefix returned by the
            collection's ``_parent_info()``, and the retry/timeout kwargs.
        """
        parent_path, expected_prefix = self._collection_ref._parent_info()
        request = {
            "parent": parent_path,
            "structured_query": self._to_protobuf(),
            "transaction": _helpers.get_transaction_id(transaction),
        }
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        if explain_options is not None:
            request["explain_options"] = explain_options._to_dict()

        return request, expected_prefix, kwargs

    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
    ) -> (
        QueryResultsList[DocumentSnapshot]
        | Coroutine[Any, Any, QueryResultsList[DocumentSnapshot]]
    ):
        """Runs the vector query.

        Abstract: concrete subclasses must override this.
        """
        raise NotImplementedError

    def find_nearest(
        self,
        vector_field: str,
        query_vector: Union[Vector, Sequence[float]],
        limit: int,
        distance_measure: DistanceMeasure,
        *,
        distance_result_field: Optional[str] = None,
        distance_threshold: Optional[float] = None,
    ):
        """Finds the closest vector embeddings to the given query vector.

        Records the vector-search parameters on this object; they are
        serialized later by ``_to_protobuf``. Calling this again overwrites
        the previously recorded parameters.

        Args:
            vector_field: Field path used as the ``vector_field`` reference
                of the ``find_nearest`` clause.
            query_vector: The vector to search near. A value that is not
                already a :class:`Vector` is wrapped with ``Vector(...)``.
            limit: Value sent as the ``limit`` of the ``find_nearest``
                clause.
            distance_measure: The :class:`DistanceMeasure` to use.
            distance_result_field: Optional value sent as
                ``distance_result_field``.
            distance_threshold: Optional value sent as
                ``distance_threshold`` (ints are coerced to float when
                serialized).

        Returns:
            This vector query (``self``), allowing call chaining.
        """
        if isinstance(query_vector, Vector):
            self._query_vector = query_vector
        else:
            self._query_vector = Vector(query_vector)
        self._vector_field = vector_field
        self._limit = limit
        self._distance_measure = distance_measure
        self._distance_result_field = distance_result_field
        self._distance_threshold = distance_threshold
        return self

    def stream(
        self,
        transaction=None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
    ) -> StreamGenerator[DocumentSnapshot] | AsyncStreamGenerator[DocumentSnapshot]:
        """Reads the documents in the collection that match this query.

        Subclasses are expected to override this; the base implementation
        only raises ``NotImplementedError``. It is deliberately not marked
        abstract, so subclasses that do not override it can still be
        instantiated.
        """
        raise NotImplementedError