# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15"""Classes for representing aggregation queries for the Google Cloud Firestore API.
16
17A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from
18a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be
19a more common way to create an aggregation query than direct usage of the constructor.
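
Example (a minimal sketch, not an exhaustive guide; it assumes credentials are
configured for a default project and that a ``users`` collection exists)::

    from google.cloud import firestore

    client = firestore.Client()

    # ``count`` here is the collection-level shortcut for building an
    # aggregation query; the alias names the result field.
    aggregation_query = client.collection("users").count(alias="total")
    results = aggregation_query.get()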
20"""
21from __future__ import annotations
22
23import abc
24
25from abc import ABC
26from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union
27
28from google.api_core import gapic_v1
29from google.api_core import retry as retries
30
31from google.cloud.firestore_v1 import _helpers
32from google.cloud.firestore_v1.field_path import FieldPath
33from google.cloud.firestore_v1.types import (
34 StructuredAggregationQuery,
35)
36
37# Types needed only for Type Hints
38if TYPE_CHECKING: # pragma: NO COVER
39 from google.cloud.firestore_v1 import transaction
40 from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
41 from google.cloud.firestore_v1.query_profile import ExplainOptions
42 from google.cloud.firestore_v1.query_results import QueryResultsList
43 from google.cloud.firestore_v1.stream_generator import (
44 StreamGenerator,
45 )
46
47 import datetime
48
49
class AggregationResult(object):
    """
    A class representing a single result from an aggregation query.

    :type alias: str
    :param alias: The alias for the aggregation.

    :type value: int | float
    :param value: The resulting value from the aggregation.

    :type read_time: Optional[datetime.datetime]
    :param read_time: The time at which the result was read.
    """

    def __init__(self, alias: str, value: float, read_time=None):
        self.alias = alias
        self.value = value
        self.read_time = read_time

    def __repr__(self):
        return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>"


class BaseAggregation(ABC):
    """Base class for an aggregation (e.g. count, sum, avg) applied to a query."""

    def __init__(self, alias: str | None = None):
        self.alias = alias

    @abc.abstractmethod
    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""


class CountAggregation(BaseAggregation):
    """Aggregation that counts the documents matched by the nested query."""

    def __init__(self, alias: str | None = None):
        super(CountAggregation, self).__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            aggregation_pb.alias = self.alias
        aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
        return aggregation_pb


class SumAggregation(BaseAggregation):
    """Aggregation that sums a field over the documents matched by the nested query."""

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # convert field path to string if needed
        field_str = (
            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
        )
        self.field_ref: str = field_str
        super(SumAggregation, self).__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            aggregation_pb.alias = self.alias
        aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
        aggregation_pb.sum.field.field_path = self.field_ref
        return aggregation_pb


class AvgAggregation(BaseAggregation):
    """Aggregation that averages a field over the documents matched by the nested query."""

    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # convert field path to string if needed
        field_str = (
            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
        )
        self.field_ref: str = field_str
        super(AvgAggregation, self).__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            aggregation_pb.alias = self.alias
        aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
        aggregation_pb.avg.field.field_path = self.field_ref
        return aggregation_pb


def _query_response_to_result(
    response_pb,
) -> List[AggregationResult]:
    """Convert a ``RunAggregationQueryResponse`` message into ``AggregationResult`` objects."""
    results = [
        AggregationResult(
            alias=key,
            value=response_pb.result.aggregate_fields[key].integer_value
            or response_pb.result.aggregate_fields[key].double_value,
            read_time=response_pb.read_time,
        )
        for key in response_pb.result.aggregate_fields.pb.keys()
    ]

    return results


class BaseAggregationQuery(ABC):
    """Represents an aggregation query to the Firestore API."""

    def __init__(self, nested_query, alias: str | None = None) -> None:
        self._nested_query = nested_query
        self._alias = alias
        self._collection_ref = nested_query._parent
        self._aggregations: List[BaseAggregation] = []

    @property
    def _client(self):
        return self._collection_ref._client

    def count(self, alias: str | None = None):
        """
        Adds a count over the nested query
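
        Example (illustrative sketch; ``aggregation_query`` is assumed to be an
        existing aggregation query built from a query or collection)::

            aggregation_query.count(alias="total")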
162 """
163 count_aggregation = CountAggregation(alias=alias)
164 self._aggregations.append(count_aggregation)
165 return self
166
167 def sum(self, field_ref: str | FieldPath, alias: str | None = None):
168 """
169 Adds a sum over the nested query
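
        Example (illustrative sketch; the ``"population"`` field name is an
        assumption, not part of this API)::

            aggregation_query.sum("population", alias="total_population")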
170 """
171 sum_aggregation = SumAggregation(field_ref, alias=alias)
172 self._aggregations.append(sum_aggregation)
173 return self
174
175 def avg(self, field_ref: str | FieldPath, alias: str | None = None):
176 """
177 Adds an avg over the nested query
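
        Example (illustrative sketch; the ``"population"`` field name is an
        assumption, not part of this API)::

            aggregation_query.avg("population", alias="avg_population")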
178 """
179 avg_aggregation = AvgAggregation(field_ref, alias=alias)
180 self._aggregations.append(avg_aggregation)
181 return self
182
183 def add_aggregation(self, aggregation: BaseAggregation) -> None:
184 """
185 Adds an aggregation operation to the nested query
186
187 :type aggregation: :class:`google.cloud.firestore_v1.aggregation.BaseAggregation`
188 :param aggregation: An aggregation operation, e.g. a CountAggregation
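
        Example (illustrative sketch)::

            aggregation_query.add_aggregation(CountAggregation(alias="all"))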
189 """
190 self._aggregations.append(aggregation)
191
192 def add_aggregations(self, aggregations: List[BaseAggregation]) -> None:
193 """
194 Adds a list of aggregations to the nested query
195
196 :type aggregations: list
197 :param aggregations: a list of aggregation operations
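
        Example (illustrative sketch; the ``"population"`` field name is an
        assumption)::

            aggregation_query.add_aggregations(
                [CountAggregation(alias="all"), SumAggregation("population")]
            )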
198 """
199 self._aggregations.extend(aggregations)
200
201 def _to_protobuf(self) -> StructuredAggregationQuery:
202 pb = StructuredAggregationQuery()
203 pb.structured_query = self._nested_query._to_protobuf()
204
205 for aggregation in self._aggregations:
206 aggregation_pb = aggregation._to_protobuf()
207 pb.aggregations.append(aggregation_pb)
208 return pb
209
210 def _prep_stream(
211 self,
212 transaction=None,
213 retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None,
214 timeout: float | None = None,
215 explain_options: Optional[ExplainOptions] = None,
216 read_time: Optional[datetime.datetime] = None,
217 ) -> Tuple[dict, dict]:
218 parent_path, expected_prefix = self._collection_ref._parent_info()
219 request = {
220 "parent": parent_path,
221 "structured_aggregation_query": self._to_protobuf(),
222 "transaction": _helpers.get_transaction_id(transaction),
223 }
224 if explain_options:
225 request["explain_options"] = explain_options._to_dict()
226 if read_time is not None:
227 request["read_time"] = read_time
228 kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
229
230 return request, kwargs
231
    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: Union[
            retries.Retry, retries.AsyncRetry, None, object
        ] = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        QueryResultsList[AggregationResult]
        | Coroutine[Any, Any, List[List[AggregationResult]]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a list of
        aggregation results from the stream of ``RunAggregationQueryResponse``
        messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
                If a ``transaction`` is used and it already has write operations
                added, this method cannot be used (i.e. read-after-write is not
                allowed).
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned results.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            (QueryResultsList[AggregationResult] | Coroutine[Any, Any, List[List[AggregationResult]]]):
                The aggregation query results.
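
        Example (synchronous sketch; an async implementation returns a coroutine
        that must be awaited instead)::

            results = aggregation_query.get()
            for result in results:
                print(result)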
276 """
277
278 @abc.abstractmethod
279 def stream(
280 self,
281 transaction: Optional[transaction.Transaction] = None,
282 retry: retries.Retry
283 | retries.AsyncRetry
284 | object
285 | None = gapic_v1.method.DEFAULT,
286 timeout: Optional[float] = None,
287 *,
288 explain_options: Optional[ExplainOptions] = None,
289 read_time: Optional[datetime.datetime] = None,
290 ) -> (
291 StreamGenerator[List[AggregationResult]]
292 | AsyncStreamGenerator[List[AggregationResult]]
293 ):
294 """Runs the aggregation query.
295
296 This sends a``RunAggregationQuery`` RPC and returns a generator in the stream of ``RunAggregationQueryResponse`` messages.
297
298 Args:
299 transaction
300 (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
301 An existing transaction that this query will run in.
302 retry (Optional[google.api_core.retry.Retry]): Designation of what
303 errors, if any, should be retried. Defaults to a
304 system-specified policy.
305 timeout (Optinal[float]): The timeout for this request. Defaults
306 to a system-specified value.
307 explain_options
308 (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
309 Options to enable query profiling for this query. When set,
310 explain_metrics will be available on the returned generator.
311 read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
312 time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
313 is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
314 timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.
315
316 Returns:
317 StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
318 A generator of the query results.
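
        Example (synchronous sketch; an async implementation would be iterated
        with ``async for`` instead)::

            for result_list in aggregation_query.stream():
                for aggregation_result in result_list:
                    print(aggregation_result.alias, aggregation_result.value)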
319 """