1# Copyright 2024 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Classes for representing vector queries for the Google Cloud Firestore API.
16"""
17from __future__ import annotations
18
19from typing import TYPE_CHECKING, Any, Generator, Optional, TypeVar, Union
20
21from google.api_core import gapic_v1
22from google.api_core import retry as retries
23
24from google.cloud.firestore_v1.query_results import QueryResultsList
25from google.cloud.firestore_v1.base_query import (
26 BaseQuery,
27 _collection_group_query_response_to_snapshot,
28 _query_response_to_snapshot,
29)
30from google.cloud.firestore_v1.base_vector_query import BaseVectorQuery
31from google.cloud.firestore_v1.stream_generator import StreamGenerator
32
33# Types needed only for Type Hints
34if TYPE_CHECKING: # pragma: NO COVER
35 from google.cloud.firestore_v1 import transaction
36 from google.cloud.firestore_v1.base_document import DocumentSnapshot
37 from google.cloud.firestore_v1.query_profile import ExplainMetrics
38 from google.cloud.firestore_v1.query_profile import ExplainOptions
39
40
# Type variable bounded by ``VectorQuery`` (given as a forward reference,
# because the class is defined further down in this module).
TVectorQuery = TypeVar(
    "TVectorQuery",
    bound="VectorQuery",
)
42
43
class VectorQuery(BaseVectorQuery):
    """Represents a vector query to the Firestore API."""

    def __init__(
        self,
        nested_query: Union[BaseQuery, TVectorQuery],
    ) -> None:
        """Creates a vector query on top of an existing query.

        Args:
            nested_query (BaseQuery | VectorQuery): the base query to apply
                as the prefilter.
        """
        super().__init__(nested_query)

    def get(
        self,
        transaction: Optional["transaction.Transaction"] = None,
        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
    ) -> QueryResultsList[DocumentSnapshot]:
        """Runs the vector query and collects all of its results.

        This sends a ``RunQuery`` RPC (via :meth:`stream`) and returns every
        resulting document snapshot in a single list.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
                If a ``transaction`` is used and it already has write operations
                added, this method cannot be used (i.e. read-after-write is not
                allowed).
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                the explain metrics gathered from the stream are attached to
                the returned results list.

        Returns:
            QueryResultsList[DocumentSnapshot]: The vector query results,
            carrying ``explain_options`` and, when they were requested, the
            explain metrics.
        """
        result = self.stream(
            transaction=transaction,
            retry=retry,
            timeout=timeout,
            explain_options=explain_options,
        )
        # Drain the stream first. ``_make_stream`` only returns the metrics
        # as the generator's return value, i.e. once the stream is exhausted.
        result_list = list(result)

        # Metrics are only requested from the generator when profiling was
        # asked for; otherwise none are attached to the results.
        explain_metrics: ExplainMetrics | None = (
            result.get_explain_metrics() if explain_options is not None else None
        )

        return QueryResultsList(result_list, explain_options, explain_metrics)

    def _get_stream_iterator(
        self,
        transaction: Optional["transaction.Transaction"],
        retry: retries.Retry | object | None,
        timeout: Optional[float],
        explain_options: Optional[ExplainOptions] = None,
    ):
        """Prepares the ``RunQuery`` request and opens its response stream.

        Helper for :meth:`_make_stream`.

        Returns:
            A ``(response_iterator, expected_prefix)`` pair, where
            ``response_iterator`` is the result of the ``run_query`` RPC and
            ``expected_prefix`` is the value produced by ``_prep_stream``.
        """
        request, expected_prefix, kwargs = self._prep_stream(
            transaction,
            retry,
            timeout,
            explain_options,
        )

        response_iterator = self._client._firestore_api.run_query(
            request=request,
            metadata=self._client._rpc_metadata,
            **kwargs,
        )

        return response_iterator, expected_prefix

    def _make_stream(
        self,
        transaction: Optional["transaction.Transaction"] = None,
        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        explain_options: Optional[ExplainOptions] = None,
    ) -> Generator[DocumentSnapshot, Any, Optional[ExplainMetrics]]:
        """Reads the documents in the collection that match this query.

        This sends a ``RunQuery`` RPC and then returns a generator which
        consumes each document returned in the stream of ``RunQueryResponse``
        messages.

        If a ``transaction`` is used and it already has write operations
        added, this method cannot be used (i.e. read-after-write is not
        allowed).

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
            retry (Optional[google.api_core.retry.Retry]): Designation of what
                errors, if any, should be retried. Defaults to a
                system-specified policy.
            timeout (Optional[float]): The timeout for this request. Defaults
                to a system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.

        Yields:
            DocumentSnapshot:
            The next document that fulfills the query.

        Returns:
            ([google.cloud.firestore_v1.types.query_profile.ExplainMetrics | None]):
            The first explain metrics received from the service, or ``None``
            if no response carried any.
        """
        metrics: ExplainMetrics | None = None

        response_iterator, expected_prefix = self._get_stream_iterator(
            transaction,
            retry,
            timeout,
            explain_options,
        )

        for response in response_iterator:
            # Keep the first non-empty explain_metrics payload in the stream.
            if metrics is None and response.explain_metrics:
                metrics = response.explain_metrics

            # Collection-group queries resolve document paths differently
            # from queries scoped to a single parent.
            if self._nested_query._all_descendants:
                snapshot = _collection_group_query_response_to_snapshot(
                    response, self._nested_query._parent
                )
            else:
                snapshot = _query_response_to_snapshot(
                    response, self._nested_query._parent, expected_prefix
                )
            # Responses that do not map to a snapshot (the helpers return
            # None for them) are skipped.
            if snapshot is not None:
                yield snapshot

        return metrics

    def stream(
        self,
        transaction: Optional["transaction.Transaction"] = None,
        retry: retries.Retry | object | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
    ) -> StreamGenerator[DocumentSnapshot]:
        """Reads the documents in the collection that match this query.

        This sends a ``RunQuery`` RPC and then returns a generator which
        consumes each document returned in the stream of ``RunQueryResponse``
        messages.

        If a ``transaction`` is used and it already has write operations
        added, this method cannot be used (i.e. read-after-write is not
        allowed).

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
            retry (Optional[google.api_core.retry.Retry]): Designation of what
                errors, if any, should be retried. Defaults to a
                system-specified policy.
            timeout (Optional[float]): The timeout for this request. Defaults
                to a system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.

        Returns:
            `StreamGenerator[DocumentSnapshot]`: A generator of the query results.
        """
        inner_generator = self._make_stream(
            transaction=transaction,
            retry=retry,
            timeout=timeout,
            explain_options=explain_options,
        )
        return StreamGenerator(inner_generator, explain_options)