# Copyright 2023 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15"""Classes for representing aggregation queries for the Google Cloud Firestore API.
16
17A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from
18a :class:`~google.cloud.firestore_v1.collection.Collection` and that can be
19a more common way to create an aggregation query than direct usage of the constructor.
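
Example (a minimal, illustrative sketch; assumes a configured ``Client``, and uses
hypothetical ``users`` / ``points`` names)::

    from google.cloud import firestore

    client = firestore.Client()
    collection = client.collection("users")

    # count all documents, and sum the "points" field
    aggregation_query = collection.count(alias="total").sum("points", alias="sum_points")
    results = aggregation_query.get()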
20"""
21from __future__ import annotations
22
23import abc
24import itertools
25
26from abc import ABC
from typing import TYPE_CHECKING, Any, Coroutine, Iterator, List, Optional, Tuple, Union

from google.api_core import gapic_v1
from google.api_core import retry as retries

from google.cloud.firestore_v1 import _helpers
from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.types import (
    StructuredAggregationQuery,
)
from google.cloud.firestore_v1.pipeline_expressions import AggregateFunction
from google.cloud.firestore_v1.pipeline_expressions import Count
from google.cloud.firestore_v1.pipeline_expressions import AliasedExpression
from google.cloud.firestore_v1.pipeline_expressions import Field

# Types needed only for Type Hints
if TYPE_CHECKING:  # pragma: NO COVER
    from google.cloud.firestore_v1 import transaction
    from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator
    from google.cloud.firestore_v1.query_profile import ExplainOptions
    from google.cloud.firestore_v1.query_results import QueryResultsList
    from google.cloud.firestore_v1.stream_generator import (
        StreamGenerator,
    )
    from google.cloud.firestore_v1.pipeline_source import PipelineSource

    import datetime


class AggregationResult:
    """
    A class representing a result from an aggregation query.

    :type alias: str
    :param alias: The alias for the aggregation.

    :type value: float
    :param value: The resulting value from the aggregation.

    :type read_time: Optional[datetime.datetime]
    :param read_time: The time at which the result was read.
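
    Example (illustrative)::

        result = AggregationResult(alias="total", value=42)
        result._to_dict()  # -> {"total": 42}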
    """

    def __init__(self, alias: str, value: float, read_time=None):
        self.alias = alias
        self.value = value
        self.read_time = read_time

    def __repr__(self):
        return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>"

    def _to_dict(self):
        return {self.alias: self.value}


class BaseAggregation(ABC):
    def __init__(self, alias: str | None = None):
        self.alias = alias

    @abc.abstractmethod
    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""

    @abc.abstractmethod
    def _to_pipeline_expr(
        self, autoindexer: Iterator[int]
    ) -> AliasedExpression[AggregateFunction]:
        """
        Convert this instance to a pipeline expression for use with pipeline.aggregate()

        Args:
            autoindexer: If an alias isn't supplied, one should be created with the
                format "field_n". The autoindexer is an iterator that provides the
                ``n`` value to use for each expression.
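
        Example (illustrative, using the concrete subclasses below)::

            autoindexer = itertools.count(start=1)
            CountAggregation()._to_pipeline_expr(autoindexer)  # aliased "field_1"
            SumAggregation("points")._to_pipeline_expr(autoindexer)  # aliased "field_2"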
        """

    def _pipeline_alias(self, autoindexer):
        """
        Helper to build the alias for the pipeline expression
        """
        if self.alias is not None:
            return self.alias
        else:
            return f"field_{next(autoindexer)}"


class CountAggregation(BaseAggregation):
    def __init__(self, alias: str | None = None):
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            aggregation_pb.alias = self.alias
        aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count()
        return aggregation_pb

    def _to_pipeline_expr(self, autoindexer: Iterator[int]):
        return Count().as_(self._pipeline_alias(autoindexer))


class SumAggregation(BaseAggregation):
    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # convert field path to string if needed
        field_str = (
            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
        )
        self.field_ref: str = field_str
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            aggregation_pb.alias = self.alias
        aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum()
        aggregation_pb.sum.field.field_path = self.field_ref
        return aggregation_pb

    def _to_pipeline_expr(self, autoindexer: Iterator[int]):
        return Field.of(self.field_ref).sum().as_(self._pipeline_alias(autoindexer))


class AvgAggregation(BaseAggregation):
    def __init__(self, field_ref: str | FieldPath, alias: str | None = None):
        # convert field path to string if needed
        field_str = (
            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref
        )
        self.field_ref: str = field_str
        super().__init__(alias=alias)

    def _to_protobuf(self):
        """Convert this instance to the protobuf representation"""
        aggregation_pb = StructuredAggregationQuery.Aggregation()
        if self.alias:
            aggregation_pb.alias = self.alias
        aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg()
        aggregation_pb.avg.field.field_path = self.field_ref
        return aggregation_pb

    def _to_pipeline_expr(self, autoindexer: Iterator[int]):
        return Field.of(self.field_ref).average().as_(self._pipeline_alias(autoindexer))


def _query_response_to_result(
    response_pb,
) -> List[AggregationResult]:
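    # Each aggregate value is returned as either an integer_value or a
    # double_value, depending on the field type; use whichever is populated.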
    results = [
        AggregationResult(
            alias=key,
            value=response_pb.result.aggregate_fields[key].integer_value
            or response_pb.result.aggregate_fields[key].double_value,
            read_time=response_pb.read_time,
        )
        for key in response_pb.result.aggregate_fields.pb.keys()
    ]

    return results


class BaseAggregationQuery(ABC):
    """Represents an aggregation query to the Firestore API."""

    def __init__(self, nested_query, alias: str | None = None) -> None:
        self._nested_query = nested_query
        self._alias = alias
        self._collection_ref = nested_query._parent
        self._aggregations: List[BaseAggregation] = []

    @property
    def _client(self):
        return self._collection_ref._client

    def count(self, alias: str | None = None):
        """
        Adds a count over the nested query
        """
        count_aggregation = CountAggregation(alias=alias)
        self._aggregations.append(count_aggregation)
        return self

    def sum(self, field_ref: str | FieldPath, alias: str | None = None):
        """
        Adds a sum over the nested query
        """
        sum_aggregation = SumAggregation(field_ref, alias=alias)
        self._aggregations.append(sum_aggregation)
        return self

    def avg(self, field_ref: str | FieldPath, alias: str | None = None):
        """
        Adds an avg over the nested query
        """
        avg_aggregation = AvgAggregation(field_ref, alias=alias)
        self._aggregations.append(avg_aggregation)
        return self

    def add_aggregation(self, aggregation: BaseAggregation) -> None:
        """
        Adds an aggregation operation to the nested query

        :type aggregation: :class:`google.cloud.firestore_v1.aggregation.BaseAggregation`
        :param aggregation: An aggregation operation, e.g. a CountAggregation
        """
        self._aggregations.append(aggregation)

    def add_aggregations(self, aggregations: List[BaseAggregation]) -> None:
        """
        Adds a list of aggregations to the nested query

        :type aggregations: list
        :param aggregations: a list of aggregation operations
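
        Example (illustrative)::

            aggregation_query.add_aggregations(
                [CountAggregation(alias="total"), SumAggregation("points")]
            )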
        """
        self._aggregations.extend(aggregations)

    def _to_protobuf(self) -> StructuredAggregationQuery:
        pb = StructuredAggregationQuery()
        pb.structured_query = self._nested_query._to_protobuf()

        for aggregation in self._aggregations:
            aggregation_pb = aggregation._to_protobuf()
            pb.aggregations.append(aggregation_pb)
        return pb

    def _prep_stream(
        self,
        transaction=None,
        retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None,
        timeout: float | None = None,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> Tuple[dict, dict]:
        parent_path, expected_prefix = self._collection_ref._parent_info()
        request = {
            "parent": parent_path,
            "structured_aggregation_query": self._to_protobuf(),
            "transaction": _helpers.get_transaction_id(transaction),
        }
        if explain_options:
            request["explain_options"] = explain_options._to_dict()
        if read_time is not None:
            request["read_time"] = read_time
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        return request, kwargs

    @abc.abstractmethod
    def get(
        self,
        transaction=None,
        retry: Union[
            retries.Retry, retries.AsyncRetry, None, object
        ] = gapic_v1.method.DEFAULT,
        timeout: float | None = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        QueryResultsList[AggregationResult]
        | Coroutine[Any, Any, List[List[AggregationResult]]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a list of
        aggregation results in the stream of ``RunAggregationQueryResponse``
        messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
                If a ``transaction`` is used and it already has write operations
                added, this method cannot be used (i.e. read-after-write is not
                allowed).
            retry (google.api_core.retry.Retry): Designation of what errors, if any,
                should be retried. Defaults to a system-specified policy.
            timeout (float): The timeout for this request. Defaults to a
                system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            (QueryResultsList[AggregationResult] | Coroutine[Any, Any, List[List[AggregationResult]]]):
                The aggregation query results.
        """

    @abc.abstractmethod
    def stream(
        self,
        transaction: Optional[transaction.Transaction] = None,
        retry: retries.Retry
        | retries.AsyncRetry
        | object
        | None = gapic_v1.method.DEFAULT,
        timeout: Optional[float] = None,
        *,
        explain_options: Optional[ExplainOptions] = None,
        read_time: Optional[datetime.datetime] = None,
    ) -> (
        StreamGenerator[List[AggregationResult]]
        | AsyncStreamGenerator[List[AggregationResult]]
    ):
        """Runs the aggregation query.

        This sends a ``RunAggregationQuery`` RPC and returns a generator in the
        stream of ``RunAggregationQueryResponse`` messages.

        Args:
            transaction
                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
                An existing transaction that this query will run in.
            retry (Optional[google.api_core.retry.Retry]): Designation of what
                errors, if any, should be retried. Defaults to a
                system-specified policy.
            timeout (Optional[float]): The timeout for this request. Defaults
                to a system-specified value.
            explain_options
                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]):
                Options to enable query profiling for this query. When set,
                explain_metrics will be available on the returned generator.
            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given
                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery
                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no
                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC.

        Returns:
            StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]:
                A generator of the query results.
        """

    def _build_pipeline(self, source: "PipelineSource"):
        """
        Convert this query into a Pipeline

        Args:
            source: the PipelineSource to build the pipeline off of
        Returns:
            a Pipeline representing the query
        """
        # use autoindexer to keep track of which field number to use for un-aliased fields
        autoindexer = itertools.count(start=1)
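        # e.g. two un-aliased aggregations yield expressions aliased "field_1" and "field_2"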
        exprs = [a._to_pipeline_expr(autoindexer) for a in self._aggregations]
        return self._nested_query._build_pipeline(source).aggregate(*exprs)