1# Copyright 2017 Google LLC All rights reserved. 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#     http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15"""Classes for representing collections for the Google Cloud Firestore API.""" 
    16from __future__ import annotations 
    17 
    18import random 
    19 
    20from typing import ( 
    21    TYPE_CHECKING, 
    22    Any, 
    23    AsyncGenerator, 
    24    AsyncIterator, 
    25    Coroutine, 
    26    Generator, 
    27    Generic, 
    28    Iterable, 
    29    Sequence, 
    30    Tuple, 
    31    Union, 
    32    Optional, 
    33) 
    34 
    35from google.api_core import retry as retries 
    36 
    37from google.cloud.firestore_v1 import _helpers 
    38from google.cloud.firestore_v1.base_query import QueryType 
    39 
    40if TYPE_CHECKING:  # pragma: NO COVER 
    41    # Types needed only for Type Hints 
    42    from google.cloud.firestore_v1.base_aggregation import BaseAggregationQuery 
    43    from google.cloud.firestore_v1.base_document import DocumentSnapshot 
    44    from google.cloud.firestore_v1.base_vector_query import ( 
    45        BaseVectorQuery, 
    46        DistanceMeasure, 
    47    ) 
    48    from google.cloud.firestore_v1.async_document import AsyncDocumentReference 
    49    from google.cloud.firestore_v1.document import DocumentReference 
    50    from google.cloud.firestore_v1.field_path import FieldPath 
    51    from google.cloud.firestore_v1.query_profile import ExplainOptions 
    52    from google.cloud.firestore_v1.query_results import QueryResultsList 
    53    from google.cloud.firestore_v1.stream_generator import StreamGenerator 
    54    from google.cloud.firestore_v1.transaction import Transaction 
    55    from google.cloud.firestore_v1.vector import Vector 
    56    from google.cloud.firestore_v1.vector_query import VectorQuery 
    57 
    58    import datetime 
    59 
    60_AUTO_ID_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789" 
    61 
    62 
    63class BaseCollectionReference(Generic[QueryType]): 
    64    """A reference to a collection in a Firestore database. 
    65 
    66    The collection may already exist or this class can facilitate creation 
    67    of documents within the collection. 
    68 
    69    Args: 
    70        path (Tuple[str, ...]): The components in the collection path. 
    71            This is a series of strings representing each collection and 
    72            sub-collection ID, as well as the document IDs for any documents 
    73            that contain a sub-collection. 
    74        kwargs (dict): The keyword arguments for the constructor. The only 
    75            supported keyword is ``client`` and it must be a 
    76            :class:`~google.cloud.firestore_v1.client.Client` if provided. It 
    77            represents the client that created this collection reference. 
    78 
    79    Raises: 
    80        ValueError: if 
    81 
    82            * the ``path`` is empty 
    83            * there are an even number of elements 
    84            * a collection ID in ``path`` is not a string 
    85            * a document ID in ``path`` is not a string 
    86        TypeError: If a keyword other than ``client`` is used. 
    87    """ 
    88 
    89    def __init__(self, *path, **kwargs) -> None: 
    90        _helpers.verify_path(path, is_collection=True) 
    91        self._path = path 
    92        self._client = kwargs.pop("client", None) 
    93        if kwargs: 
    94            raise TypeError( 
    95                "Received unexpected arguments", kwargs, "Only `client` is supported" 
    96            ) 
    97 
    98    def __eq__(self, other): 
    99        if not isinstance(other, self.__class__): 
    100            return NotImplemented 
    101        return self._path == other._path and self._client == other._client 
    102 
    103    @property 
    104    def id(self): 
    105        """The collection identifier. 
    106 
    107        Returns: 
    108            str: The last component of the path. 
    109        """ 
    110        return self._path[-1] 
    111 
    112    @property 
    113    def parent(self): 
    114        """Document that owns the current collection. 
    115 
    116        Returns: 
    117            Optional[:class:`~google.cloud.firestore_v1.document.DocumentReference`]: 
    118            The parent document, if the current collection is not a 
    119            top-level collection. 
    120        """ 
    121        if len(self._path) == 1: 
    122            return None 
    123        else: 
    124            parent_path = self._path[:-1] 
    125        return self._client.document(*parent_path) 
    126 
    127    def _query(self) -> QueryType: 
    128        raise NotImplementedError 
    129 
    130    def _aggregation_query(self) -> BaseAggregationQuery: 
    131        raise NotImplementedError 
    132 
    133    def _vector_query(self) -> BaseVectorQuery: 
    134        raise NotImplementedError 
    135 
    136    def document(self, document_id: Optional[str] = None): 
    137        """Create a sub-document underneath the current collection. 
    138 
    139        Args: 
    140            document_id (Optional[str]): The document identifier 
    141                within the current collection. If not provided, will default 
    142                to a random 20 character string composed of digits, 
    143                uppercase and lowercase and letters. 
    144 
    145        Returns: 
    146            :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: 
    147            The child document. 
    148        """ 
    149        if document_id is None: 
    150            document_id = _auto_id() 
    151 
    152        # Append `self._path` and the passed document's ID as long as the first 
    153        # element in the path is not an empty string, which comes from setting the 
    154        # parent to "" for recursive queries. 
    155        child_path = self._path + (document_id,) if self._path[0] else (document_id,) 
    156        return self._client.document(*child_path) 
    157 
    158    def _parent_info(self) -> Tuple[Any, str]: 
    159        """Get fully-qualified parent path and prefix for this collection. 
    160 
    161        Returns: 
    162            Tuple[str, str]: Pair of 
    163 
    164            * the fully-qualified (with database and project) path to the 
    165              parent of this collection (will either be the database path 
    166              or a document path). 
    167            * the prefix to a document in this collection. 
    168        """ 
    169        parent_doc = self.parent 
    170        if parent_doc is None: 
    171            parent_path = _helpers.DOCUMENT_PATH_DELIMITER.join( 
    172                (self._client._database_string, "documents") 
    173            ) 
    174        else: 
    175            parent_path = parent_doc._document_path 
    176 
    177        expected_prefix = _helpers.DOCUMENT_PATH_DELIMITER.join((parent_path, self.id)) 
    178        return parent_path, expected_prefix 
    179 
    180    def _prep_add( 
    181        self, 
    182        document_data: dict, 
    183        document_id: Optional[str] = None, 
    184        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    185        timeout: Optional[float] = None, 
    186    ): 
    187        """Shared setup for async / sync :method:`add`""" 
    188        if document_id is None: 
    189            document_id = _auto_id() 
    190 
    191        document_ref = self.document(document_id) 
    192        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    193 
    194        return document_ref, kwargs 
    195 
    196    def add( 
    197        self, 
    198        document_data: dict, 
    199        document_id: Optional[str] = None, 
    200        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    201        timeout: Optional[float] = None, 
    202    ) -> Union[Tuple[Any, Any], Coroutine[Any, Any, Tuple[Any, Any]]]: 
    203        raise NotImplementedError 
    204 
    205    def _prep_list_documents( 
    206        self, 
    207        page_size: Optional[int] = None, 
    208        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    209        timeout: Optional[float] = None, 
    210        read_time: Optional[datetime.datetime] = None, 
    211    ) -> Tuple[dict, dict]: 
    212        """Shared setup for async / sync :method:`list_documents`""" 
    213        parent, _ = self._parent_info() 
    214        request = { 
    215            "parent": parent, 
    216            "collection_id": self.id, 
    217            "page_size": page_size, 
    218            "show_missing": True, 
    219            # list_documents returns an iterator of document references, which do not 
    220            # include any fields. To save on data transfer, we can set a field_path mask 
    221            # to include no fields 
    222            "mask": {"field_paths": None}, 
    223        } 
    224        if read_time is not None: 
    225            request["read_time"] = read_time 
    226        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    227 
    228        return request, kwargs 
    229 
    230    def list_documents( 
    231        self, 
    232        page_size: Optional[int] = None, 
    233        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    234        timeout: Optional[float] = None, 
    235        *, 
    236        read_time: Optional[datetime.datetime] = None, 
    237    ) -> Union[ 
    238        Generator[DocumentReference, Any, Any], 
    239        AsyncGenerator[AsyncDocumentReference, Any], 
    240    ]: 
    241        raise NotImplementedError 
    242 
    243    def recursive(self) -> QueryType: 
    244        return self._query().recursive() 
    245 
    246    def select(self, field_paths: Iterable[str]) -> QueryType: 
    247        """Create a "select" query with this collection as parent. 
    248 
    249        See 
    250        :meth:`~google.cloud.firestore_v1.query.Query.select` for 
    251        more information on this method. 
    252 
    253        Args: 
    254            field_paths (Iterable[str, ...]): An iterable of field paths 
    255                (``.``-delimited list of field names) to use as a projection 
    256                of document fields in the query results. 
    257 
    258        Returns: 
    259            :class:`~google.cloud.firestore_v1.query.Query`: 
    260            A "projected" query. 
    261        """ 
    262        query = self._query() 
    263        return query.select(field_paths) 
    264 
    265    def where( 
    266        self, 
    267        field_path: Optional[str] = None, 
    268        op_string: Optional[str] = None, 
    269        value=None, 
    270        *, 
    271        filter=None, 
    272    ) -> QueryType: 
    273        """Create a "where" query with this collection as parent. 
    274 
    275        See 
    276        :meth:`~google.cloud.firestore_v1.query.Query.where` for 
    277        more information on this method. 
    278 
    279        Args: 
    280            field_path (str): A field path (``.``-delimited list of 
    281                field names) for the field to filter on. Optional. 
    282            op_string (str): A comparison operation in the form of a string. 
    283                Acceptable values are ``<``, ``<=``, ``==``, ``>=``, ``>``, 
    284                and ``in``. Optional. 
    285            value (Any): The value to compare the field against in the filter. 
    286                If ``value`` is :data:`None` or a NaN, then ``==`` is the only 
    287                allowed operation.  If ``op_string`` is ``in``, ``value`` 
    288                must be a sequence of values. Optional. 
    289            filter (class:`~google.cloud.firestore_v1.base_query.BaseFilter`): an instance of a Filter. 
    290                Either a FieldFilter or a CompositeFilter. 
    291        Returns: 
    292            :class:`~google.cloud.firestore_v1.query.Query`: 
    293            A filtered query. 
    294        Raises: 
    295            ValueError, if both the positional arguments (field_path, op_string, value) 
    296                and the filter keyword argument are passed at the same time. 
    297        """ 
    298        query = self._query() 
    299        if field_path and op_string: 
    300            if filter is not None: 
    301                raise ValueError( 
    302                    "Can't pass in both the positional arguments and 'filter' at the same time" 
    303                ) 
    304            if field_path == "__name__" and op_string == "in": 
    305                wrapped_names = [] 
    306 
    307                for name in value: 
    308                    if isinstance(name, str): 
    309                        name = self.document(name) 
    310 
    311                    wrapped_names.append(name) 
    312 
    313                value = wrapped_names 
    314            return query.where(field_path, op_string, value) 
    315        else: 
    316            return query.where(filter=filter) 
    317 
    318    def order_by(self, field_path: str, **kwargs) -> QueryType: 
    319        """Create an "order by" query with this collection as parent. 
    320 
    321        See 
    322        :meth:`~google.cloud.firestore_v1.query.Query.order_by` for 
    323        more information on this method. 
    324 
    325        Args: 
    326            field_path (str): A field path (``.``-delimited list of 
    327                field names) on which to order the query results. 
    328            kwargs (Dict[str, Any]): The keyword arguments to pass along 
    329                to the query. The only supported keyword is ``direction``, 
    330                see :meth:`~google.cloud.firestore_v1.query.Query.order_by` 
    331                for more information. 
    332 
    333        Returns: 
    334            :class:`~google.cloud.firestore_v1.query.Query`: 
    335            An "order by" query. 
    336        """ 
    337        query = self._query() 
    338        return query.order_by(field_path, **kwargs) 
    339 
    340    def limit(self, count: int) -> QueryType: 
    341        """Create a limited query with this collection as parent. 
    342 
    343        .. note:: 
    344           `limit` and `limit_to_last` are mutually exclusive. 
    345           Setting `limit` will drop previously set `limit_to_last`. 
    346 
    347        See 
    348        :meth:`~google.cloud.firestore_v1.query.Query.limit` for 
    349        more information on this method. 
    350 
    351        Args: 
    352            count (int): Maximum number of documents to return that match 
    353                the query. 
    354 
    355        Returns: 
    356            :class:`~google.cloud.firestore_v1.query.Query`: 
    357            A limited query. 
    358        """ 
    359        query = self._query() 
    360        return query.limit(count) 
    361 
    362    def limit_to_last(self, count: int): 
    363        """Create a limited to last query with this collection as parent. 
    364 
    365        .. note:: 
    366           `limit` and `limit_to_last` are mutually exclusive. 
    367           Setting `limit_to_last` will drop previously set `limit`. 
    368 
    369        See 
    370        :meth:`~google.cloud.firestore_v1.query.Query.limit_to_last` 
    371        for more information on this method. 
    372 
    373        Args: 
    374            count (int): Maximum number of documents to return that 
    375                match the query. 
    376        Returns: 
    377            :class:`~google.cloud.firestore_v1.query.Query`: 
    378            A limited to last query. 
    379        """ 
    380        query = self._query() 
    381        return query.limit_to_last(count) 
    382 
    383    def offset(self, num_to_skip: int) -> QueryType: 
    384        """Skip to an offset in a query with this collection as parent. 
    385 
    386        See 
    387        :meth:`~google.cloud.firestore_v1.query.Query.offset` for 
    388        more information on this method. 
    389 
    390        Args: 
    391            num_to_skip (int): The number of results to skip at the beginning 
    392                of query results. (Must be non-negative.) 
    393 
    394        Returns: 
    395            :class:`~google.cloud.firestore_v1.query.Query`: 
    396            An offset query. 
    397        """ 
    398        query = self._query() 
    399        return query.offset(num_to_skip) 
    400 
    401    def start_at( 
    402        self, document_fields: Union[DocumentSnapshot, dict, list, tuple] 
    403    ) -> QueryType: 
    404        """Start query at a cursor with this collection as parent. 
    405 
    406        See 
    407        :meth:`~google.cloud.firestore_v1.query.Query.start_at` for 
    408        more information on this method. 
    409 
    410        Args: 
    411            document_fields (Union[:class:`~google.cloud.firestore_v1.\ 
    412                document.DocumentSnapshot`, dict, list, tuple]): 
    413                A document snapshot or a dictionary/list/tuple of fields 
    414                representing a query results cursor. A cursor is a collection 
    415                of values that represent a position in a query result set. 
    416 
    417        Returns: 
    418            :class:`~google.cloud.firestore_v1.query.Query`: 
    419            A query with cursor. 
    420        """ 
    421        query = self._query() 
    422        return query.start_at(document_fields) 
    423 
    424    def start_after( 
    425        self, document_fields: Union[DocumentSnapshot, dict, list, tuple] 
    426    ) -> QueryType: 
    427        """Start query after a cursor with this collection as parent. 
    428 
    429        See 
    430        :meth:`~google.cloud.firestore_v1.query.Query.start_after` for 
    431        more information on this method. 
    432 
    433        Args: 
    434            document_fields (Union[:class:`~google.cloud.firestore_v1.\ 
    435                document.DocumentSnapshot`, dict, list, tuple]): 
    436                A document snapshot or a dictionary/list/tuple of fields 
    437                representing a query results cursor. A cursor is a collection 
    438                of values that represent a position in a query result set. 
    439 
    440        Returns: 
    441            :class:`~google.cloud.firestore_v1.query.Query`: 
    442            A query with cursor. 
    443        """ 
    444        query = self._query() 
    445        return query.start_after(document_fields) 
    446 
    447    def end_before( 
    448        self, document_fields: Union[DocumentSnapshot, dict, list, tuple] 
    449    ) -> QueryType: 
    450        """End query before a cursor with this collection as parent. 
    451 
    452        See 
    453        :meth:`~google.cloud.firestore_v1.query.Query.end_before` for 
    454        more information on this method. 
    455 
    456        Args: 
    457            document_fields (Union[:class:`~google.cloud.firestore_v1.\ 
    458                document.DocumentSnapshot`, dict, list, tuple]): 
    459                A document snapshot or a dictionary/list/tuple of fields 
    460                representing a query results cursor. A cursor is a collection 
    461                of values that represent a position in a query result set. 
    462 
    463        Returns: 
    464            :class:`~google.cloud.firestore_v1.query.Query`: 
    465            A query with cursor. 
    466        """ 
    467        query = self._query() 
    468        return query.end_before(document_fields) 
    469 
    470    def end_at( 
    471        self, document_fields: Union[DocumentSnapshot, dict, list, tuple] 
    472    ) -> QueryType: 
    473        """End query at a cursor with this collection as parent. 
    474 
    475        See 
    476        :meth:`~google.cloud.firestore_v1.query.Query.end_at` for 
    477        more information on this method. 
    478 
    479        Args: 
    480            document_fields (Union[:class:`~google.cloud.firestore_v1.\ 
    481                document.DocumentSnapshot`, dict, list, tuple]): 
    482                A document snapshot or a dictionary/list/tuple of fields 
    483                representing a query results cursor. A cursor is a collection 
    484                of values that represent a position in a query result set. 
    485 
    486        Returns: 
    487            :class:`~google.cloud.firestore_v1.query.Query`: 
    488            A query with cursor. 
    489        """ 
    490        query = self._query() 
    491        return query.end_at(document_fields) 
    492 
    493    def _prep_get_or_stream( 
    494        self, 
    495        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    496        timeout: Optional[float] = None, 
    497    ) -> Tuple[Any, dict]: 
    498        """Shared setup for async / sync :meth:`get` / :meth:`stream`""" 
    499        query = self._query() 
    500        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout) 
    501 
    502        return query, kwargs 
    503 
    504    def get( 
    505        self, 
    506        transaction: Optional[Transaction] = None, 
    507        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    508        timeout: Optional[float] = None, 
    509        *, 
    510        explain_options: Optional[ExplainOptions] = None, 
    511        read_time: Optional[datetime.datetime] = None, 
    512    ) -> ( 
    513        QueryResultsList[DocumentSnapshot] 
    514        | Coroutine[Any, Any, QueryResultsList[DocumentSnapshot]] 
    515    ): 
    516        raise NotImplementedError 
    517 
    518    def stream( 
    519        self, 
    520        transaction: Optional[Transaction] = None, 
    521        retry: retries.Retry | retries.AsyncRetry | object | None = None, 
    522        timeout: Optional[float] = None, 
    523        *, 
    524        explain_options: Optional[ExplainOptions] = None, 
    525        read_time: Optional[datetime.datetime] = None, 
    526    ) -> StreamGenerator[DocumentSnapshot] | AsyncIterator[DocumentSnapshot]: 
    527        raise NotImplementedError 
    528 
    529    def on_snapshot(self, callback): 
    530        raise NotImplementedError 
    531 
    532    def count(self, alias=None): 
    533        """ 
    534        Adds a count over the nested query. 
    535 
    536        :type alias: str 
    537        :param alias: (Optional) The alias for the count 
    538        """ 
    539        return self._aggregation_query().count(alias=alias) 
    540 
    541    def sum(self, field_ref: str | FieldPath, alias=None): 
    542        """ 
    543        Adds a sum over the nested query. 
    544 
    545        :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] 
    546        :param field_ref: The field to aggregate across. 
    547 
    548        :type alias: Optional[str] 
    549        :param alias: Optional name of the field to store the result of the aggregation into. 
    550            If not provided, Firestore will pick a default name following the format field_<incremental_id++>. 
    551 
    552        """ 
    553        return self._aggregation_query().sum(field_ref, alias=alias) 
    554 
    555    def avg(self, field_ref: str | FieldPath, alias=None): 
    556        """ 
    557        Adds an avg over the nested query. 
    558 
    559        :type field_ref: Union[str, google.cloud.firestore_v1.field_path.FieldPath] 
    560        :param field_ref: The field to aggregate across. 
    561 
    562        :type alias: Optional[str] 
    563        :param alias: Optional name of the field to store the result of the aggregation into. 
    564            If not provided, Firestore will pick a default name following the format field_<incremental_id++>. 
    565        """ 
    566        return self._aggregation_query().avg(field_ref, alias=alias) 
    567 
    568    def find_nearest( 
    569        self, 
    570        vector_field: str, 
    571        query_vector: Union[Vector, Sequence[float]], 
    572        limit: int, 
    573        distance_measure: DistanceMeasure, 
    574        *, 
    575        distance_result_field: Optional[str] = None, 
    576        distance_threshold: Optional[float] = None, 
    577    ) -> VectorQuery: 
    578        """ 
    579        Finds the closest vector embeddings to the given query vector. 
    580 
    581        Args: 
    582            vector_field (str): An indexed vector field to search upon. Only documents which contain 
    583                vectors whose dimensionality match the query_vector can be returned. 
    584            query_vector(Union[Vector, Sequence[float]]): The query vector that we are searching on. Must be a vector of no more 
    585                than 2048 dimensions. 
    586            limit (int): The number of nearest neighbors to return. Must be a positive integer of no more than 1000. 
    587            distance_measure (:class:`DistanceMeasure`): The Distance Measure to use. 
    588            distance_result_field (Optional[str]): 
    589                Name of the field to output the result of the vector distance calculation 
    590            distance_threshold (Optional[float]): 
    591                A threshold for which no less similar documents will be returned. 
    592 
    593        Returns: 
    594            :class`~firestore_v1.vector_query.VectorQuery`: the vector query. 
    595        """ 
    596        return self._vector_query().find_nearest( 
    597            vector_field, 
    598            query_vector, 
    599            limit, 
    600            distance_measure, 
    601            distance_result_field=distance_result_field, 
    602            distance_threshold=distance_threshold, 
    603        ) 
    604 
    605 
    606def _auto_id() -> str: 
    607    """Generate a "random" automatically generated ID. 
    608 
    609    Returns: 
    610        str: A 20 character string composed of digits, uppercase and 
    611        lowercase and letters. 
    612    """ 
    613 
    614    return "".join(random.choice(_AUTO_ID_CHARS) for _ in range(20)) 
    615 
    616 
    617def _item_to_document_ref(collection_reference, item): 
    618    """Convert Document resource to document ref. 
    619 
    620    Args: 
    621        collection_reference (google.api_core.page_iterator.GRPCIterator): 
    622            iterator response 
    623        item (dict): document resource 
    624 
    625    Returns: 
    626            :class:`~google.cloud.firestore_v1.base_document.BaseDocumentReference`: 
    627            The child document 
    628    """ 
    629    document_id = item.name.split(_helpers.DOCUMENT_PATH_DELIMITER)[-1] 
    630    return collection_reference.document(document_id)