1# Copyright 2023 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Classes for representing aggregation queries for the Google Cloud Firestore API. 
    16 
A :class:`~google.cloud.firestore_v1.aggregation.AggregationQuery` can be created directly from
a :class:`~google.cloud.firestore_v1.collection.Collection`, and doing so can be
a more common way to create an aggregation query than using the constructor directly.
    20""" 
    21from __future__ import annotations 
    22 
    23import abc 
    24 
    25from abc import ABC 
    26from typing import TYPE_CHECKING, Any, Coroutine, List, Optional, Tuple, Union 
    27 
    28from google.api_core import gapic_v1 
    29from google.api_core import retry as retries 
    30 
    31from google.cloud.firestore_v1 import _helpers 
    32from google.cloud.firestore_v1.field_path import FieldPath 
    33from google.cloud.firestore_v1.types import ( 
    34    StructuredAggregationQuery, 
    35) 
    36 
    37# Types needed only for Type Hints 
    38if TYPE_CHECKING:  # pragma: NO COVER 
    39    from google.cloud.firestore_v1 import transaction 
    40    from google.cloud.firestore_v1.async_stream_generator import AsyncStreamGenerator 
    41    from google.cloud.firestore_v1.query_profile import ExplainOptions 
    42    from google.cloud.firestore_v1.query_results import QueryResultsList 
    43    from google.cloud.firestore_v1.stream_generator import ( 
    44        StreamGenerator, 
    45    ) 
    46 
    47    import datetime 
    48 
    49 
    50class AggregationResult(object): 
    51    """ 
    52    A class representing result from Aggregation Query 
    53    :type alias: str 
    54    :param alias: The alias for the aggregation. 
    55    :type value: int 
    56    :param value: The resulting value from the aggregation. 
    57    :type read_time: 
    58    :param value: The resulting read_time 
    59    """ 
    60 
    61    def __init__(self, alias: str, value: float, read_time=None): 
    62        self.alias = alias 
    63        self.value = value 
    64        self.read_time = read_time 
    65 
    66    def __repr__(self): 
    67        return f"<Aggregation alias={self.alias}, value={self.value}, readtime={self.read_time}>" 
    68 
    69 
    70class BaseAggregation(ABC): 
    71    def __init__(self, alias: str | None = None): 
    72        self.alias = alias 
    73 
    74    @abc.abstractmethod 
    75    def _to_protobuf(self): 
    76        """Convert this instance to the protobuf representation""" 
    77 
    78 
    79class CountAggregation(BaseAggregation): 
    80    def __init__(self, alias: str | None = None): 
    81        super(CountAggregation, self).__init__(alias=alias) 
    82 
    83    def _to_protobuf(self): 
    84        """Convert this instance to the protobuf representation""" 
    85        aggregation_pb = StructuredAggregationQuery.Aggregation() 
    86        if self.alias: 
    87            aggregation_pb.alias = self.alias 
    88        aggregation_pb.count = StructuredAggregationQuery.Aggregation.Count() 
    89        return aggregation_pb 
    90 
    91 
    92class SumAggregation(BaseAggregation): 
    93    def __init__(self, field_ref: str | FieldPath, alias: str | None = None): 
    94        # convert field path to string if needed 
    95        field_str = ( 
    96            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref 
    97        ) 
    98        self.field_ref: str = field_str 
    99        super(SumAggregation, self).__init__(alias=alias) 
    100 
    101    def _to_protobuf(self): 
    102        """Convert this instance to the protobuf representation""" 
    103        aggregation_pb = StructuredAggregationQuery.Aggregation() 
    104        if self.alias: 
    105            aggregation_pb.alias = self.alias 
    106        aggregation_pb.sum = StructuredAggregationQuery.Aggregation.Sum() 
    107        aggregation_pb.sum.field.field_path = self.field_ref 
    108        return aggregation_pb 
    109 
    110 
    111class AvgAggregation(BaseAggregation): 
    112    def __init__(self, field_ref: str | FieldPath, alias: str | None = None): 
    113        # convert field path to string if needed 
    114        field_str = ( 
    115            field_ref.to_api_repr() if isinstance(field_ref, FieldPath) else field_ref 
    116        ) 
    117        self.field_ref: str = field_str 
    118        super(AvgAggregation, self).__init__(alias=alias) 
    119 
    120    def _to_protobuf(self): 
    121        """Convert this instance to the protobuf representation""" 
    122        aggregation_pb = StructuredAggregationQuery.Aggregation() 
    123        if self.alias: 
    124            aggregation_pb.alias = self.alias 
    125        aggregation_pb.avg = StructuredAggregationQuery.Aggregation.Avg() 
    126        aggregation_pb.avg.field.field_path = self.field_ref 
    127        return aggregation_pb 
    128 
    129 
    130def _query_response_to_result( 
    131    response_pb, 
    132) -> List[AggregationResult]: 
    133    results = [ 
    134        AggregationResult( 
    135            alias=key, 
    136            value=response_pb.result.aggregate_fields[key].integer_value 
    137            or response_pb.result.aggregate_fields[key].double_value, 
    138            read_time=response_pb.read_time, 
    139        ) 
    140        for key in response_pb.result.aggregate_fields.pb.keys() 
    141    ] 
    142 
    143    return results 
    144 
    145 
    146class BaseAggregationQuery(ABC): 
    147    """Represents an aggregation query to the Firestore API.""" 
    148 
    149    def __init__(self, nested_query, alias: str | None = None) -> None: 
    150        self._nested_query = nested_query 
    151        self._alias = alias 
    152        self._collection_ref = nested_query._parent 
    153        self._aggregations: List[BaseAggregation] = [] 
    154 
    155    @property 
    156    def _client(self): 
    157        return self._collection_ref._client 
    158 
    159    def count(self, alias: str | None = None): 
    160        """ 
    161        Adds a count over the nested query 
    162        """ 
    163        count_aggregation = CountAggregation(alias=alias) 
    164        self._aggregations.append(count_aggregation) 
    165        return self 
    166 
    167    def sum(self, field_ref: str | FieldPath, alias: str | None = None): 
    168        """ 
    169        Adds a sum over the nested query 
    170        """ 
    171        sum_aggregation = SumAggregation(field_ref, alias=alias) 
    172        self._aggregations.append(sum_aggregation) 
    173        return self 
    174 
    175    def avg(self, field_ref: str | FieldPath, alias: str | None = None): 
    176        """ 
    177        Adds an avg over the nested query 
    178        """ 
    179        avg_aggregation = AvgAggregation(field_ref, alias=alias) 
    180        self._aggregations.append(avg_aggregation) 
    181        return self 
    182 
    183    def add_aggregation(self, aggregation: BaseAggregation) -> None: 
    184        """ 
    185        Adds an aggregation operation to the nested query 
    186 
    187        :type aggregation: :class:`google.cloud.firestore_v1.aggregation.BaseAggregation` 
    188        :param aggregation: An aggregation operation, e.g. a CountAggregation 
    189        """ 
    190        self._aggregations.append(aggregation) 
    191 
    192    def add_aggregations(self, aggregations: List[BaseAggregation]) -> None: 
    193        """ 
    194        Adds a list of aggregations to the nested query 
    195 
    196        :type aggregations: list 
    197        :param aggregations: a list of aggregation operations 
    198        """ 
    199        self._aggregations.extend(aggregations) 
    200 
    201    def _to_protobuf(self) -> StructuredAggregationQuery: 
    202        pb = StructuredAggregationQuery() 
    203        pb.structured_query = self._nested_query._to_protobuf() 
    204 
    205        for aggregation in self._aggregations: 
    206            aggregation_pb = aggregation._to_protobuf() 
    207            pb.aggregations.append(aggregation_pb) 
    208        return pb 
    209 
    210    def _prep_stream( 
    211        self, 
    212        transaction=None, 
    213        retry: Union[retries.Retry, retries.AsyncRetry, None, object] = None, 
    214        timeout: float | None = None, 
    215        explain_options: Optional[ExplainOptions] = None, 
    216        read_time: Optional[datetime.datetime] = None, 
    217    ) -> Tuple[dict, dict]: 
    218        parent_path, expected_prefix = self._collection_ref._parent_info() 
    219        request = { 
    220            "parent": parent_path, 
    221            "structured_aggregation_query": self._to_protobuf(), 
    222            "transaction": _helpers.get_transaction_id(transaction), 
    223        } 
    224        if explain_options: 
    225            request["explain_options"] = explain_options._to_dict() 
    226        if read_time is not None: 
    227            request["read_time"] = read_time 
    228        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    229 
    230        return request, kwargs 
    231 
    232    @abc.abstractmethod 
    233    def get( 
    234        self, 
    235        transaction=None, 
    236        retry: Union[ 
    237            retries.Retry, retries.AsyncRetry, None, object 
    238        ] = gapic_v1.method.DEFAULT, 
    239        timeout: float | None = None, 
    240        *, 
    241        explain_options: Optional[ExplainOptions] = None, 
    242        read_time: Optional[datetime.datetime] = None, 
    243    ) -> ( 
    244        QueryResultsList[AggregationResult] 
    245        | Coroutine[Any, Any, List[List[AggregationResult]]] 
    246    ): 
    247        """Runs the aggregation query. 
    248 
    249        This sends a ``RunAggregationQuery`` RPC and returns a list of 
    250        aggregation results in the stream of ``RunAggregationQueryResponse`` 
    251        messages. 
    252 
    253        Args: 
    254            transaction 
    255                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    256                An existing transaction that this query will run in. 
    257                If a ``transaction`` is used and it already has write operations 
    258                added, this method cannot be used (i.e. read-after-write is not 
    259                allowed). 
    260            retry (google.api_core.retry.Retry): Designation of what errors, if any, 
    261                should be retried.  Defaults to a system-specified policy. 
    262            timeout (float): The timeout for this request.  Defaults to a 
    263                system-specified value. 
    264            explain_options 
    265                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    266                Options to enable query profiling for this query. When set, 
    267                explain_metrics will be available on the returned generator. 
    268            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given 
    269                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery 
    270                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no 
    271                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
    272 
    273        Returns: 
    274            (QueryResultsList[List[AggregationResult]] | Coroutine[Any, Any, List[List[AggregationResult]]]): 
    275            The aggregation query results. 
    276        """ 
    277 
    278    @abc.abstractmethod 
    279    def stream( 
    280        self, 
    281        transaction: Optional[transaction.Transaction] = None, 
    282        retry: retries.Retry 
    283        | retries.AsyncRetry 
    284        | object 
    285        | None = gapic_v1.method.DEFAULT, 
    286        timeout: Optional[float] = None, 
    287        *, 
    288        explain_options: Optional[ExplainOptions] = None, 
    289        read_time: Optional[datetime.datetime] = None, 
    290    ) -> ( 
    291        StreamGenerator[List[AggregationResult]] 
    292        | AsyncStreamGenerator[List[AggregationResult]] 
    293    ): 
    294        """Runs the aggregation query. 
    295 
    296        This sends a``RunAggregationQuery`` RPC and returns a generator in the stream of ``RunAggregationQueryResponse`` messages. 
    297 
    298        Args: 
    299            transaction 
    300                (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]): 
    301                An existing transaction that this query will run in. 
    302            retry (Optional[google.api_core.retry.Retry]): Designation of what 
    303                errors, if any, should be retried.  Defaults to a 
    304                system-specified policy. 
    305            timeout (Optinal[float]): The timeout for this request.  Defaults 
    306                to a system-specified value. 
    307            explain_options 
    308                (Optional[:class:`~google.cloud.firestore_v1.query_profile.ExplainOptions`]): 
    309                Options to enable query profiling for this query. When set, 
    310                explain_metrics will be available on the returned generator. 
    311            read_time (Optional[datetime.datetime]): If set, reads documents as they were at the given 
    312                time. This must be a timestamp within the past one hour, or if Point-in-Time Recovery 
    313                is enabled, can additionally be a whole minute timestamp within the past 7 days. If no 
    314                timezone is specified in the :class:`datetime.datetime` object, it is assumed to be UTC. 
    315 
    316        Returns: 
    317            StreamGenerator[List[AggregationResult]] | AsyncStreamGenerator[List[AggregationResult]]: 
    318            A generator of the query results. 
    319        """