1# Copyright 2017 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Client for interacting with the Google Cloud Firestore API.
16
17This is the base from which all interactions with the API occur.
18
19In the hierarchy of API concepts
20
21* a :class:`~google.cloud.firestore_v1.client.Client` owns a
22 :class:`~google.cloud.firestore_v1.collection.CollectionReference`
23* a :class:`~google.cloud.firestore_v1.client.Client` owns a
24 :class:`~google.cloud.firestore_v1.document.DocumentReference`
25"""
26from __future__ import annotations
27
28import datetime
29import os
30from typing import (
31 Any,
32 AsyncGenerator,
33 Awaitable,
34 Generator,
35 Iterable,
36 List,
37 Optional,
38 Tuple,
39 Union,
40)
41
42import google.api_core.client_options
43import google.api_core.path_template
44import grpc # type: ignore
45from google.api_core import retry as retries
46from google.api_core.gapic_v1 import client_info
47from google.auth.credentials import AnonymousCredentials
48from google.cloud.client import ClientWithProject # type: ignore
49
50from google.cloud.firestore_v1 import __version__, _helpers, types
51from google.cloud.firestore_v1.base_batch import BaseWriteBatch
52
53# Types needed only for Type Hints
54from google.cloud.firestore_v1.base_collection import BaseCollectionReference
55from google.cloud.firestore_v1.base_document import (
56 BaseDocumentReference,
57 DocumentSnapshot,
58)
59from google.cloud.firestore_v1.base_query import BaseQuery
60from google.cloud.firestore_v1.base_transaction import MAX_ATTEMPTS, BaseTransaction
61from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
62from google.cloud.firestore_v1.field_path import render_field_path
63from google.cloud.firestore_v1.services.firestore import client as firestore_client
64
DEFAULT_DATABASE = "(default)"
"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""

# Project ID used against the emulator when neither the caller nor the
# GOOGLE_CLOUD_PROJECT / GCLOUD_PROJECT environment variables provide one.
_DEFAULT_EMULATOR_PROJECT = "google-cloud-firestore-emulator"

# Raised by ``BaseClient.write_option`` unless exactly one recognised
# keyword argument is supplied.
_BAD_OPTION_ERR = "Exactly one of ``last_update_time`` or ``exists`` must be provided."

# Formatted with the offending document path by ``_get_reference``.
_BAD_DOC_TEMPLATE: str = (
    "Document {!r} appeared in response but was not present among references"
)

# Transaction-state error messages; presumably consumed by transaction code
# outside this module (not referenced here) — NOTE(review): confirm.
_ACTIVE_TXN: str = "There is already an active transaction."
_INACTIVE_TXN: str = "There is no active transaction."

# Default ``client_info`` for ``BaseClient``; tags requests with this
# library's version.
_CLIENT_INFO: Any = client_info.ClientInfo(client_library_version=__version__)

# Name of the environment variable that, when set, redirects the client to
# a local emulator.
_FIRESTORE_EMULATOR_HOST: str = "FIRESTORE_EMULATOR_HOST"
78
79
class BaseClient(ClientWithProject):
    """Client for interacting with Google Cloud Firestore API.

    .. note::

        Since the Cloud Firestore API requires the gRPC transport, no
        ``_http`` argument is accepted by this class.

    Args:
        project (Optional[str]): The project which the client acts on behalf
            of. If not passed, falls back to the default inferred
            from the environment. When the ``FIRESTORE_EMULATOR_HOST``
            environment variable is set, falls back instead to
            ``GOOGLE_CLOUD_PROJECT``, then ``GCLOUD_PROJECT``, then a fixed
            placeholder emulator project.
        credentials (Optional[~google.auth.credentials.Credentials]): The
            OAuth2 Credentials to use for this client. If not passed, falls
            back to the default inferred from the environment, or to
            anonymous credentials when targeting the emulator.
        database (Optional[str]): The database name that the client targets.
            Defaults to :attr:`DEFAULT_DATABASE` when not provided.
        client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]):
            The client info used to send a user-agent string along with API
            requests. If ``None``, then default info will be used. Generally,
            you only need to set this if you're developing your own library
            or partner tool.
        client_options (Union[dict, google.api_core.client_options.ClientOptions]):
            Client options used to set user options on the client. API Endpoint
            should be set through client_options. A non-empty ``dict`` is
            converted to :class:`~google.api_core.client_options.ClientOptions`.
    """

    SCOPE = (
        "https://www.googleapis.com/auth/cloud-platform",
        "https://www.googleapis.com/auth/datastore",
    )
    """The scopes required for authenticating with the Firestore service."""

    # Lazily-populated caches; see ``_firestore_api_helper``,
    # ``_database_string`` and ``_rpc_metadata``.
    _firestore_api_internal = None
    _database_string_internal = None
    _rpc_metadata_internal = None

    def __init__(
        self,
        project=None,
        credentials=None,
        database=None,
        client_info=_CLIENT_INFO,
        client_options=None,
    ) -> None:
        database = database or DEFAULT_DATABASE
        # NOTE: This API has no use for the _http argument, but sending it
        # will have no impact since the _http() @property only lazily
        # creates a working HTTP object.
        self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST)

        if self._emulator_host is not None:
            # Targeting the emulator: default to anonymous credentials and a
            # project taken from the environment.
            if credentials is None:
                credentials = AnonymousCredentials()
            if project is None:
                # extract project from env var, or use system default
                project = (
                    os.getenv("GOOGLE_CLOUD_PROJECT")
                    or os.getenv("GCLOUD_PROJECT")
                    or _DEFAULT_EMULATOR_PROJECT
                )

        super(BaseClient, self).__init__(
            project=project,
            credentials=credentials,
            client_options=client_options,
            _http=None,
        )
        self._client_info = client_info
        if client_options:
            if isinstance(client_options, dict):
                client_options = google.api_core.client_options.from_dict(
                    client_options
                )
        # Always assign (even when None) so ``_target_helper`` and
        # ``_firestore_api_helper`` can read the attribute unconditionally.
        self._client_options = client_options

        self._database = database

    def _firestore_api_helper(self, transport, client_class, client_module) -> Any:
        """Lazily create, cache and return the GAPIC Firestore client.

        Args:
            transport: The GAPIC transport class. Its ``create_channel`` is
                used for non-emulator channels, and it is instantiated with
                ``host`` and ``channel``.
            client_class: The GAPIC client class to instantiate.
            client_module: Module whose ``_client_info`` attribute is set to
                this client's ``client_info``.

        Returns:
            The GAPIC client with the credentials of the current client.
        """
        if self._firestore_api_internal is None:
            # Use a custom channel.
            # We need this in order to set appropriate keepalive options.
            if self._emulator_host is not None:
                channel = self._emulator_channel(transport)
            else:
                channel = transport.create_channel(
                    self._target,
                    credentials=self._credentials,
                    options={"grpc.keepalive_time_ms": 30000}.items(),
                )

            self._transport = transport(host=self._target, channel=channel)

            self._firestore_api_internal = client_class(
                transport=self._transport, client_options=self._client_options
            )
            client_module._client_info = self._client_info

        return self._firestore_api_internal

    def _emulator_channel(self, transport):
        """
        Creates an insecure channel to communicate with the local emulator.
        If credentials are provided the token is extracted and added to the
        headers. This supports local testing of firestore rules if the credentials
        have been created from a signed custom token.

        :return: grpc.Channel or grpc.aio.Channel
        """
        # Insecure channels are used for the emulator as secure channels
        # cannot be used to communicate on some environments.
        # https://github.com/googleapis/python-firestore/issues/359
        # Default the token to a non-empty string, in this case "owner".
        token = "owner"
        if (
            self._credentials is not None
            and getattr(self._credentials, "id_token", None) is not None
        ):
            token = self._credentials.id_token
        options = [("Authorization", f"Bearer {token}")]

        # The async transport needs an asyncio channel; everything else gets
        # a synchronous one.
        if "GrpcAsyncIOTransport" in str(transport.__name__):
            return grpc.aio.insecure_channel(self._emulator_host, options=options)
        else:
            return grpc.insecure_channel(self._emulator_host, options=options)

    def _target_helper(self, client_class) -> str:
        """Return the target (where the API is).
        Eg. "firestore.googleapis.com"

        Precedence: the emulator host, then ``client_options.api_endpoint``,
        then ``client_class.DEFAULT_ENDPOINT``.

        Returns:
            str: The location of the API.
        """
        if self._emulator_host is not None:
            return self._emulator_host
        elif self._client_options and self._client_options.api_endpoint:
            return self._client_options.api_endpoint
        else:
            return client_class.DEFAULT_ENDPOINT

    @property
    def _target(self):
        """Return the target (where the API is).
        Eg. "firestore.googleapis.com"

        Returns:
            str: The location of the API.
        """
        return self._target_helper(firestore_client.FirestoreClient)

    @property
    def _database_string(self):
        """The database string corresponding to this client's project.

        This value is lazy-loaded and cached.

        Will be of the form

            ``projects/{project_id}/databases/{database_id}``

        where ``database_id`` is the ``database`` passed to the constructor
        (``'(default)'`` when none was given).

        Returns:
            str: The fully-qualified database string for the current
            project and database.
        """
        if self._database_string_internal is None:
            db_str = google.api_core.path_template.expand(
                "projects/{project}/databases/{database}",
                project=self.project,
                database=self._database,
            )

            self._database_string_internal = db_str

        return self._database_string_internal

    @property
    def _rpc_metadata(self):
        """The RPC metadata for this client's associated database.

        Built once and cached.

        Returns:
            Sequence[Tuple(str, str)]: RPC metadata with resource prefix
            for the database associated with this client.
        """
        if self._rpc_metadata_internal is None:
            self._rpc_metadata_internal = _helpers.metadata_with_prefix(
                self._database_string
            )

            if self._emulator_host is not None:
                # The emulator requires additional metadata to be set.
                # Appended only while first building the cached list, so
                # repeated property reads do not accumulate duplicates.
                self._rpc_metadata_internal.append(("authorization", "Bearer owner"))

        return self._rpc_metadata_internal

    def collection(self, *collection_path) -> BaseCollectionReference:
        """Return a collection reference; implemented by subclasses."""
        raise NotImplementedError

    def collection_group(self, collection_id: str) -> BaseQuery:
        """Return a collection-group query; implemented by subclasses."""
        raise NotImplementedError

    def _get_collection_reference(
        self, collection_id: str
    ) -> BaseCollectionReference[BaseQuery]:
        """Check validity of ``collection_id``, then defer to the subclass's ``collection``.

        Args:
            collection_id (str): Identifies the collections to query over.

                Every collection or subcollection with this ID as the last
                segment of its path will be included. Cannot contain a slash.

        Returns:
            The created collection.

        Raises:
            ValueError: If ``collection_id`` contains ``'/'``.
        """
        if "/" in collection_id:
            raise ValueError(
                "Invalid collection_id "
                + collection_id
                + ". Collection IDs must not contain '/'."
            )

        return self.collection(collection_id)

    def document(self, *document_path) -> BaseDocumentReference:
        """Return a document reference; implemented by subclasses."""
        raise NotImplementedError

    def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter:
        """Get a BulkWriter instance from this client.

        Args:
            options (Optional[:class:`~google.cloud.firestore_v1.bulk_writer.BulkWriterOptions`]):
                Optional control parameters for the
                :class:`~google.cloud.firestore_v1.bulk_writer.BulkWriter` returned.

        Returns:
            :class:`~google.cloud.firestore_v1.bulk_writer.BulkWriter`:
            A utility to efficiently create and save many `WriteBatch` instances
            to the server.
        """
        return BulkWriter(client=self, options=options)

    def _document_path_helper(self, *document_path) -> List[str]:
        """Standardize the format of path to a list of path segments and strip the database string from path if present.

        Args:
            document_path (Tuple[str, ...]): Can either be

                * A single ``/``-delimited path to a document
                * A tuple of document path segments

        Returns:
            List[str]: The path segments relative to
            ``<database_string>/documents/``.
        """
        path = _path_helper(document_path)
        base_path = self._database_string + "/documents/"
        joined_path = _helpers.DOCUMENT_PATH_DELIMITER.join(path)
        if joined_path.startswith(base_path):
            joined_path = joined_path[len(base_path) :]
        return joined_path.split(_helpers.DOCUMENT_PATH_DELIMITER)

    def recursive_delete(
        self,
        reference,
        *,
        bulk_writer: Optional["BulkWriter"] = None,
        chunk_size: int = 5000,
    ) -> int | Awaitable[int]:
        """Recursively delete a reference; implemented by subclasses."""
        raise NotImplementedError

    @staticmethod
    def field_path(*field_names: str) -> str:
        """Create a **field path** from a list of nested field names.

        A **field path** is a ``.``-delimited concatenation of the field
        names. It is used to represent a nested field. For example,
        in the data

        .. code-block:: python

           data = {
              'aa': {
                  'bb': {
                      'cc': 10,
                  },
              },
           }

        the field path ``'aa.bb.cc'`` represents the data stored in
        ``data['aa']['bb']['cc']``.

        Args:
            field_names: The list of field names.

        Returns:
            str: The ``.``-delimited field path.
        """
        return render_field_path(field_names)

    @staticmethod
    def write_option(
        **kwargs,
    ) -> Union[_helpers.ExistsOption, _helpers.LastUpdateOption]:
        """Create a write option for write operations.

        Write operations include :meth:`~google.cloud.DocumentReference.set`,
        :meth:`~google.cloud.DocumentReference.update` and
        :meth:`~google.cloud.DocumentReference.delete`.

        One of the following keyword arguments must be provided:

        * ``last_update_time`` (:class:`google.protobuf.timestamp_pb2.Timestamp`):
          A timestamp. When set, the target document must exist and have been
          last updated at that time. Protobuf ``update_time`` timestamps are
          typically returned from methods that perform write operations as
          part of a "write result" protobuf or directly.
        * ``exists`` (:class:`bool`): Indicates if the document being modified
          should already exist.

        Providing no argument would make the option have no effect (so
        it is not allowed). Providing multiple would be an apparent
        contradiction, since ``last_update_time`` assumes that the
        document **was** updated (it can't have been updated if it
        doesn't exist) and ``exists`` indicates that it is unknown if the
        document exists or not.

        Args:
            kwargs (Dict[str, Any]): The keyword arguments described above.

        Raises:
            TypeError: If anything other than exactly one of the recognised
                arguments is provided by the caller.

        Returns:
            Union[~google.cloud.firestore_v1._helpers.ExistsOption, ~google.cloud.firestore_v1._helpers.LastUpdateOption]:
            The option to be used to configure a write message.
        """
        if len(kwargs) != 1:
            raise TypeError(_BAD_OPTION_ERR)

        name, value = kwargs.popitem()
        if name == "last_update_time":
            return _helpers.LastUpdateOption(value)
        elif name == "exists":
            return _helpers.ExistsOption(value)
        else:
            extra = "{!r} was provided".format(name)
            raise TypeError(_BAD_OPTION_ERR, extra)

    def _prep_get_all(
        self,
        references: list,
        field_paths: Iterable[str] | None = None,
        transaction: BaseTransaction | None = None,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        read_time: datetime.datetime | None = None,
    ) -> Tuple[dict, dict, dict]:
        """Shared setup for async/sync :meth:`get_all`.

        Returns:
            Tuple[dict, dict, dict]: The ``BatchGetDocuments`` request dict,
            the path-to-reference map from :func:`_reference_info`, and the
            retry/timeout keyword arguments.
        """
        document_paths, reference_map = _reference_info(references)
        mask = _get_doc_mask(field_paths)
        request = {
            "database": self._database_string,
            "documents": document_paths,
            "mask": mask,
            "transaction": _helpers.get_transaction_id(transaction),
        }
        if read_time is not None:
            request["read_time"] = read_time
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        return request, reference_map, kwargs

    def get_all(
        self,
        references: list,
        field_paths: Iterable[str] | None = None,
        transaction=None,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ) -> Union[
        AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any]
    ]:
        """Retrieve multiple documents; implemented by sync/async subclasses."""
        raise NotImplementedError

    def _prep_collections(
        self,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        read_time: datetime.datetime | None = None,
    ) -> Tuple[dict, dict]:
        """Shared setup for async/sync :meth:`collections`.

        Returns:
            Tuple[dict, dict]: The request dict (parent is the database's
            ``documents`` root) and the retry/timeout keyword arguments.
        """
        request: dict[str, Any] = {
            "parent": "{}/documents".format(self._database_string),
        }
        if read_time is not None:
            request["read_time"] = read_time
        kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)

        return request, kwargs

    def collections(
        self,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ):
        """List root collections; implemented by sync/async subclasses."""
        raise NotImplementedError

    def batch(self) -> BaseWriteBatch:
        """Return a write batch; implemented by subclasses."""
        raise NotImplementedError

    def transaction(
        self, max_attempts: int = MAX_ATTEMPTS, read_only: bool = False
    ) -> BaseTransaction:
        """Return a transaction; implemented by subclasses."""
        raise NotImplementedError
504
505
506def _reference_info(references: list) -> Tuple[list, dict]:
507 """Get information about document references.
508
509 Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`.
510
511 Args:
512 references (List[.DocumentReference, ...]): Iterable of document
513 references.
514
515 Returns:
516 Tuple[List[str, ...], Dict[str, .DocumentReference]]: A two-tuple of
517
518 * fully-qualified documents paths for each reference in ``references``
519 * a mapping from the paths to the original reference. (If multiple
520 ``references`` contains multiple references to the same document,
521 that key will be overwritten in the result.)
522 """
523 document_paths = []
524 reference_map = {}
525 for reference in references:
526 doc_path = reference._document_path
527 document_paths.append(doc_path)
528 reference_map[doc_path] = reference
529
530 return document_paths, reference_map
531
532
533def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentReference:
534 """Get a document reference from a dictionary.
535
536 This just wraps a simple dictionary look-up with a helpful error that is
537 specific to :meth:`~google.cloud.firestore.client.Client.get_all`, the
538 **public** caller of this function.
539
540 Args:
541 document_path (str): A fully-qualified document path.
542 reference_map (Dict[str, .DocumentReference]): A mapping (produced
543 by :func:`_reference_info`) of fully-qualified document paths to
544 document references.
545
546 Returns:
547 .DocumentReference: The matching reference.
548
549 Raises:
550 ValueError: If ``document_path`` has not been encountered.
551 """
552 try:
553 return reference_map[document_path]
554 except KeyError:
555 msg = _BAD_DOC_TEMPLATE.format(document_path)
556 raise ValueError(msg)
557
558
def _parse_batch_get(
    get_doc_response: types.BatchGetDocumentsResponse,
    reference_map: dict,
    client: BaseClient,
) -> DocumentSnapshot:
    """Turn one streamed ``BatchGetDocumentsResponse`` into a snapshot.

    Args:
        get_doc_response (~google.cloud.firestore_v1.firestore.BatchGetDocumentsResponse):
            A single response (from a stream) containing the "get" response
            for a document.
        reference_map (Dict[str, .DocumentReference]): A mapping (produced
            by :func:`_reference_info`) of fully-qualified document paths to
            document references.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory; used to decode field values.

    Returns:
        .DocumentSnapshot: The retrieved snapshot. ``exists`` is ``True``
        with decoded data for a ``found`` result, and ``False`` with no data
        or timestamps for a ``missing`` result.

    Raises:
        ValueError: If the response has a ``result`` field (a oneof) other
            than ``found`` or ``missing``, or it is unset. Also propagated from
            :func:`_get_reference` for an unknown document path.
    """
    outcome = get_doc_response._pb.WhichOneof("result")

    if outcome == "found":
        found = get_doc_response.found
        # Positional args evaluate left to right: reference lookup first,
        # then field decoding.
        return DocumentSnapshot(
            _get_reference(found.name, reference_map),
            _helpers.decode_dict(found.fields, client),
            exists=True,
            read_time=get_doc_response.read_time,
            create_time=found.create_time,
            update_time=found.update_time,
        )

    if outcome == "missing":
        # ``missing`` holds the document path of the absent document.
        return DocumentSnapshot(
            _get_reference(get_doc_response.missing, reference_map),
            None,
            exists=False,
            read_time=get_doc_response.read_time,
            create_time=None,
            update_time=None,
        )

    raise ValueError(
        "`BatchGetDocumentsResponse.result` (a oneof) had a field other "
        "than `found` or `missing` set, or was unset"
    )
611
612
613def _get_doc_mask(
614 field_paths: Iterable[str] | None,
615) -> Optional[types.common.DocumentMask]:
616 """Get a document mask if field paths are provided.
617
618 Args:
619 field_paths (Optional[Iterable[str, ...]]): An iterable of field
620 paths (``.``-delimited list of field names) to use as a
621 projection of document fields in the returned results.
622
623 Returns:
624 Optional[google.cloud.firestore_v1.types.common.DocumentMask]: A mask
625 to project documents to a restricted set of field paths.
626 """
627 if field_paths is None:
628 return None
629 else:
630 return types.DocumentMask(field_paths=field_paths)
631
632
633def _path_helper(path: tuple) -> Tuple[str]:
634 """Standardize path into a tuple of path segments.
635
636 Args:
637 path (Tuple[str, ...]): Can either be
638
639 * A single ``/``-delimited path
640 * A tuple of path segments
641 """
642 if len(path) == 1:
643 return path[0].split(_helpers.DOCUMENT_PATH_DELIMITER)
644 else:
645 return path