1# Copyright 2017 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Client for interacting with the Google Cloud Firestore API.
16
17This is the base from which all interactions with the API occur.
18
19In the hierarchy of API concepts
20
21* a :class:`~google.cloud.firestore_v1.client.Client` owns a
22 :class:`~google.cloud.firestore_v1.collection.CollectionReference`
23* a :class:`~google.cloud.firestore_v1.client.Client` owns a
24 :class:`~google.cloud.firestore_v1.document.DocumentReference`
25"""
26from __future__ import annotations
27
28import datetime
29import os
30from typing import (
31 Any,
32 AsyncGenerator,
33 Awaitable,
34 Generator,
35 Iterable,
36 List,
37 Optional,
38 Tuple,
39 Union,
40)
41
42import google.api_core.client_options
43import google.api_core.path_template
44import grpc # type: ignore
45from google.api_core import retry as retries
46from google.api_core.gapic_v1 import client_info
47from google.auth.credentials import AnonymousCredentials
48from google.cloud.client import ClientWithProject # type: ignore
49
50from google.cloud.firestore_v1 import __version__, _helpers, types
51from google.cloud.firestore_v1.base_batch import BaseWriteBatch
52
53# Types needed only for Type Hints
54from google.cloud.firestore_v1.base_collection import BaseCollectionReference
55from google.cloud.firestore_v1.base_document import (
56 BaseDocumentReference,
57 DocumentSnapshot,
58)
59from google.cloud.firestore_v1.base_query import BaseQuery
60from google.cloud.firestore_v1.base_transaction import BaseTransaction
61from google.cloud.firestore_v1.bulk_writer import BulkWriter, BulkWriterOptions
62from google.cloud.firestore_v1.field_path import render_field_path
63from google.cloud.firestore_v1.services.firestore import client as firestore_client
64
DEFAULT_DATABASE = "(default)"
"""str: The default database used in a :class:`~google.cloud.firestore_v1.client.Client`."""

# --- Emulator support -------------------------------------------------------
# Name of the environment variable that is read to find the emulator address.
_FIRESTORE_EMULATOR_HOST: str = "FIRESTORE_EMULATOR_HOST"
# Project ID used as the last fallback when an emulator host is configured.
_DEFAULT_EMULATOR_PROJECT = "google-cloud-firestore-emulator"

# --- Error messages ---------------------------------------------------------
_BAD_OPTION_ERR = "Exactly one of ``last_update_time`` or ``exists`` must be provided."
# Formatted with the offending document path via ``str.format``.
_BAD_DOC_TEMPLATE: str = (
    "Document {!r} appeared in response but was not present among references"
)
_ACTIVE_TXN: str = "There is already an active transaction."
_INACTIVE_TXN: str = "There is no active transaction."

# --- Client identification --------------------------------------------------
# Default ``client_info`` argument; carries this library's version.
_CLIENT_INFO: Any = client_info.ClientInfo(client_library_version=__version__)
78
79
class BaseClient(ClientWithProject):
    """Base client for interacting with the Google Cloud Firestore API.

    This class holds the configuration and request-preparation logic; the
    operations that actually talk to the service (``collection``,
    ``document``, ``get_all``, ...) raise :class:`NotImplementedError` here.

    .. note::

        Since the Cloud Firestore API requires the gRPC transport, no
        ``_http`` argument is accepted by this class.

    Args:
        project (Optional[str]): The project which the client acts on behalf
            of. If not passed, falls back to the default inferred
            from the environment.
        credentials (Optional[~google.auth.credentials.Credentials]): The
            OAuth2 Credentials to use for this client. If not passed, falls
            back to the default inferred from the environment.
        database (Optional[str]): The database name that the client targets.
            A falsy value (``None`` or ``""``) selects
            :attr:`DEFAULT_DATABASE`.
        client_info (Optional[google.api_core.gapic_v1.client_info.ClientInfo]):
            The client info used to send a user-agent string along with API
            requests. Defaults to an info object carrying this library's
            version. Generally, you only need to set this if you're
            developing your own library or partner tool.
        client_options (Union[dict, google.api_core.client_options.ClientOptions]):
            Client options used to set user options on the client. API Endpoint
            should be set through client_options.
    """

    SCOPE = (
        "https://www.googleapis.com/auth/cloud-platform",
        "https://www.googleapis.com/auth/datastore",
    )
    """The scopes required for authenticating with the Firestore service."""

    # Class-level defaults for the lazily populated per-instance caches.
    _firestore_api_internal = None
    _database_string_internal = None
    _rpc_metadata_internal = None

    def __init__(
        self,
        project=None,
        credentials=None,
        database=None,
        client_info=_CLIENT_INFO,
        client_options=None,
    ) -> None:
        if not database:
            database = DEFAULT_DATABASE

        self._emulator_host = os.getenv(_FIRESTORE_EMULATOR_HOST)
        if self._emulator_host is not None:
            # Emulator mode: fill in whatever the caller left unspecified.
            if credentials is None:
                credentials = AnonymousCredentials()
            if project is None:
                # extract project from env var, or use system default
                project = (
                    os.getenv("GOOGLE_CLOUD_PROJECT")
                    or os.getenv("GCLOUD_PROJECT")
                    or _DEFAULT_EMULATOR_PROJECT
                )

        # NOTE: This API has no use for the _http argument, but sending it
        #       will have no impact since the _http() @property only lazily
        #       creates a working HTTP object.
        super().__init__(
            project=project,
            credentials=credentials,
            client_options=client_options,
            _http=None,
        )

        self._client_info = client_info
        # A non-empty dict is normalized to a ClientOptions object; any other
        # value (including ``None`` and ``{}``) is stored untouched.
        if client_options and isinstance(client_options, dict):
            client_options = google.api_core.client_options.from_dict(client_options)
        self._client_options = client_options
        self._database = database

    def _firestore_api_helper(self, transport, client_class, client_module) -> Any:
        """Return the GAPIC Firestore API client, building it on first use.

        Args:
            transport: Transport class; provides ``create_channel`` and is
                instantiated with ``host`` and ``channel``.
            client_class: GAPIC client class, instantiated with the transport
                and this client's options.
            client_module: Module whose ``_client_info`` attribute is set to
                this client's info when the API client is first created.

        Returns:
            The GAPIC client with the credentials of the current client.
        """
        if self._firestore_api_internal is not None:
            return self._firestore_api_internal

        # Use a custom channel.
        # We need this in order to set appropriate keepalive options.
        if self._emulator_host is None:
            channel = transport.create_channel(
                self._target,
                credentials=self._credentials,
                options={"grpc.keepalive_time_ms": 30000}.items(),
            )
        else:
            channel = self._emulator_channel(transport)

        self._transport = transport(host=self._target, channel=channel)
        self._firestore_api_internal = client_class(
            transport=self._transport, client_options=self._client_options
        )
        client_module._client_info = self._client_info

        return self._firestore_api_internal

    def _emulator_channel(self, transport):
        """Create an insecure channel for talking to the local emulator.

        A ``Bearer`` token entry is passed in the channel ``options``. The
        token is the credentials' ``id_token`` when one is set, which supports
        local testing of firestore rules with credentials created from a
        signed custom token; otherwise the literal ``"owner"`` is used.

        Args:
            transport: Transport class; its ``__name__`` decides whether an
                asyncio or a regular channel is created.

        Returns:
            grpc.Channel or grpc.aio.Channel
        """
        # Insecure channels are used for the emulator as secure channels
        # cannot be used to communicate on some environments.
        # https://github.com/googleapis/python-firestore/issues/359
        # Default the token to a non-empty string, in this case "owner".
        token = "owner"
        creds = self._credentials
        if creds is not None and getattr(creds, "id_token", None) is not None:
            token = creds.id_token
        options = [("Authorization", f"Bearer {token}")]

        if "GrpcAsyncIOTransport" in str(transport.__name__):
            make_channel = grpc.aio.insecure_channel
        else:
            make_channel = grpc.insecure_channel
        return make_channel(self._emulator_host, options=options)

    def _target_helper(self, client_class) -> str:
        """Work out the address of the API, e.g. "firestore.googleapis.com".

        Precedence: emulator host, then ``client_options.api_endpoint``, then
        ``client_class.DEFAULT_ENDPOINT``.

        Args:
            client_class: Class providing ``DEFAULT_ENDPOINT``.

        Returns:
            str: The location of the API.
        """
        if self._emulator_host is not None:
            return self._emulator_host

        opts = self._client_options
        if opts and opts.api_endpoint:
            return opts.api_endpoint

        return client_class.DEFAULT_ENDPOINT

    @property
    def _target(self):
        """The address of the API, resolved against the Firestore GAPIC client.

        Returns:
            str: The location of the API.
        """
        return self._target_helper(firestore_client.FirestoreClient)

    @property
    def _database_string(self):
        """The database string corresponding to this client's project.

        Computed on first access and cached afterwards. It has the form

            ``projects/{project_id}/databases/{database_id}``

        where the database ID is the one this client was created with.

        Returns:
            str: The fully-qualified database string for the current
            project.
        """
        if self._database_string_internal is None:
            self._database_string_internal = google.api_core.path_template.expand(
                "projects/{project}/databases/{database}",
                project=self.project,
                database=self._database,
            )
        return self._database_string_internal

    @property
    def _rpc_metadata(self):
        """The RPC metadata for this client's associated database.

        Built on first access and cached afterwards.

        Returns:
            Sequence[Tuple(str, str)]: RPC metadata with resource prefix
            for the database associated with this client.
        """
        if self._rpc_metadata_internal is not None:
            return self._rpc_metadata_internal

        metadata = _helpers.metadata_with_prefix(self._database_string)
        if self._emulator_host is not None:
            # The emulator requires additional metadata to be set.
            metadata.append(("authorization", "Bearer owner"))

        self._rpc_metadata_internal = metadata
        return metadata

    def collection(self, *collection_path) -> BaseCollectionReference:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    def collection_group(self, collection_id: str) -> BaseQuery:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    def _get_collection_reference(
        self, collection_id: str
    ) -> BaseCollectionReference[BaseQuery]:
        """Validate ``collection_id`` and delegate to :meth:`collection`.

        Args:
            collection_id (str): Identifies the collections to query over.

                Every collection or subcollection with this ID as the last
                segment of its path will be included. Cannot contain a slash.

        Returns:
            The created collection.

        Raises:
            ValueError: If ``collection_id`` contains ``/``.
        """
        if "/" not in collection_id:
            return self.collection(collection_id)

        raise ValueError(
            "Invalid collection_id "
            + collection_id
            + ". Collection IDs must not contain '/'."
        )

    def document(self, *document_path) -> BaseDocumentReference:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    def bulk_writer(self, options: Optional[BulkWriterOptions] = None) -> BulkWriter:
        """Get a BulkWriter instance from this client.

        Args:
            options (Optional[BulkWriterOptions]): Optional control
                parameters for the returned
                :class:`~google.cloud.firestore_v1.bulk_writer.BulkWriter`.

        Returns:
            :class:`~google.cloud.firestore_v1.bulk_writer.BulkWriter`:
            A utility to efficiently create and save many `WriteBatch`
            instances to the server.
        """
        return BulkWriter(client=self, options=options)

    def _document_path_helper(self, *document_path) -> List[str]:
        """Normalize a document path into a list of path segments.

        If the joined path starts with this client's
        ``<database string>/documents/`` prefix, the prefix is removed.

        Args:
            document_path (Tuple[str, ...]): Can either be

                * A single ``/``-delimited path to a document
                * A tuple of document path segments

        Returns:
            List[str]: The path segments.
        """
        segments = _path_helper(document_path)
        prefix = self._database_string + "/documents/"
        relative = _helpers.DOCUMENT_PATH_DELIMITER.join(segments)
        if relative.startswith(prefix):
            relative = relative[len(prefix) :]
        return relative.split(_helpers.DOCUMENT_PATH_DELIMITER)

    def recursive_delete(
        self,
        reference,
        *,
        bulk_writer: Optional["BulkWriter"] = None,
        chunk_size: int = 5000,
    ) -> int | Awaitable[int]:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def field_path(*field_names: str) -> str:
        """Create a **field path** from a list of nested field names.

        A **field path** is a ``.``-delimited concatenation of the field
        names. It is used to represent a nested field. For example,
        in the data

        .. code-block:: python

           data = {
              'aa': {
                  'bb': {
                      'cc': 10,
                  },
              },
           }

        the field path ``'aa.bb.cc'`` represents the data stored in
        ``data['aa']['bb']['cc']``.

        Args:
            field_names: The list of field names.

        Returns:
            str: The ``.``-delimited field path.
        """
        return render_field_path(field_names)

    @staticmethod
    def write_option(
        **kwargs,
    ) -> Union[_helpers.ExistsOption, _helpers.LastUpdateOption]:
        """Create a write option for write operations.

        Write operations include :meth:`~google.cloud.DocumentReference.set`,
        :meth:`~google.cloud.DocumentReference.update` and
        :meth:`~google.cloud.DocumentReference.delete`.

        Exactly one of the following keyword arguments must be provided:

        * ``last_update_time``
          (:class:`google.protobuf.timestamp_pb2.Timestamp`): A timestamp.
          When set, the target document must exist and have been last
          updated at that time. Protobuf ``update_time`` timestamps are
          typically returned from methods that perform write operations as
          part of a "write result" protobuf or directly.
        * ``exists`` (:class:`bool`): Indicates if the document being modified
          should already exist.

        Providing no argument would make the option have no effect, and
        providing both would be contradictory, so neither is allowed.

        Args:
            kwargs (Dict[str, Any]): The keyword arguments described above.

        Raises:
            TypeError: If anything other than exactly one argument is
                provided by the caller, or the argument name is not one of
                the two listed above.

        Returns:
            The option to be used to configure a write message: a
            ``LastUpdateOption`` for ``last_update_time`` or an
            ``ExistsOption`` for ``exists``.
        """
        if len(kwargs) != 1:
            raise TypeError(_BAD_OPTION_ERR)

        ((name, value),) = kwargs.items()
        if name == "last_update_time":
            return _helpers.LastUpdateOption(value)
        if name == "exists":
            return _helpers.ExistsOption(value)

        # Unknown keyword: report which name was given as a second argument.
        raise TypeError(_BAD_OPTION_ERR, "{!r} was provided".format(name))

    def _prep_get_all(
        self,
        references: list,
        field_paths: Iterable[str] | None = None,
        transaction: BaseTransaction | None = None,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        read_time: datetime.datetime | None = None,
    ) -> Tuple[dict, dict, dict]:
        """Shared setup for async/sync :meth:`get_all`.

        Returns:
            Tuple[dict, dict, dict]: The request dictionary, the mapping from
            document path to reference, and the retry/timeout keyword
            arguments.
        """
        doc_paths, ref_by_path = _reference_info(references)
        doc_mask = _get_doc_mask(field_paths)

        request = dict(
            database=self._database_string,
            documents=doc_paths,
            mask=doc_mask,
            transaction=_helpers.get_transaction_id(transaction),
        )
        # ``read_time`` is only included when the caller supplied one.
        if read_time is not None:
            request["read_time"] = read_time

        call_kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
        return request, ref_by_path, call_kwargs

    def get_all(
        self,
        references: list,
        field_paths: Iterable[str] | None = None,
        transaction=None,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ) -> Union[
        AsyncGenerator[DocumentSnapshot, Any], Generator[DocumentSnapshot, Any, Any]
    ]:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    def _prep_collections(
        self,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        read_time: datetime.datetime | None = None,
    ) -> Tuple[dict, dict]:
        """Shared setup for async/sync :meth:`collections`.

        Returns:
            Tuple[dict, dict]: The request dictionary and the retry/timeout
            keyword arguments.
        """
        parent = "{}/documents".format(self._database_string)
        request: dict[str, Any] = {"parent": parent}
        # ``read_time`` is only included when the caller supplied one.
        if read_time is not None:
            request["read_time"] = read_time

        call_kwargs = _helpers.make_retry_timeout_kwargs(retry, timeout)
        return request, call_kwargs

    def collections(
        self,
        retry: retries.Retry | retries.AsyncRetry | object | None = None,
        timeout: float | None = None,
        *,
        read_time: datetime.datetime | None = None,
    ):
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    def batch(self) -> BaseWriteBatch:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError

    def transaction(self, **kwargs) -> BaseTransaction:
        """Not implemented on the base client; raises ``NotImplementedError``."""
        raise NotImplementedError
502
503
def _reference_info(references: list) -> Tuple[list, dict]:
    """Collect the document paths of ``references`` and index them by path.

    Helper for :meth:`~google.cloud.firestore_v1.client.Client.get_all`.

    Args:
        references (List[.DocumentReference, ...]): Iterable of document
            references; each must expose ``_document_path``.

    Returns:
        Tuple[List[str, ...], Dict[str, .DocumentReference]]: A two-tuple of

        * the document path of every reference, in input order (duplicates
          are kept)
        * a mapping from each path to its reference. When several
          references share a path, the last one wins.
    """
    # Read ``_document_path`` exactly once per reference, in a single pass.
    path_ref_pairs = [(ref._document_path, ref) for ref in references]

    document_paths = [doc_path for doc_path, _ in path_ref_pairs]
    reference_map = dict(path_ref_pairs)
    return document_paths, reference_map
529
530
def _get_reference(document_path: str, reference_map: dict) -> BaseDocumentReference:
    """Look up the document reference registered for ``document_path``.

    A plain dictionary look-up, except that a missing key is reported with
    an error message specific to
    :meth:`~google.cloud.firestore.client.Client.get_all`, the **public**
    caller of this function.

    Args:
        document_path (str): A fully-qualified document path.
        reference_map (Dict[str, .DocumentReference]): A mapping (produced
            by :func:`_reference_info`) of fully-qualified document paths to
            document references.

    Returns:
        .DocumentReference: The matching reference.

    Raises:
        ValueError: If ``document_path`` has not been encountered.
    """
    try:
        reference = reference_map[document_path]
    except KeyError:
        raise ValueError(_BAD_DOC_TEMPLATE.format(document_path))
    return reference
555
556
def _parse_batch_get(
    get_doc_response: types.BatchGetDocumentsResponse,
    reference_map: dict,
    client: BaseClient,
) -> DocumentSnapshot:
    """Convert one `BatchGetDocumentsResponse` protobuf into a snapshot.

    Args:
        get_doc_response (~google.cloud.firestore_v1.firestore.BatchGetDocumentsResponse):
            A single response (from a stream) containing the "get" response
            for a document.
        reference_map (Dict[str, .DocumentReference]): A mapping (produced
            by :func:`_reference_info`) of fully-qualified document paths to
            document references.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory; passed through to the
            field decoder.

    Returns:
        .DocumentSnapshot: The retrieved snapshot. For a ``missing`` result
        the snapshot has ``exists=False`` and no data or timestamps other
        than ``read_time``.

    Raises:
        ValueError: If the response has a ``result`` field (a oneof) other
            than ``found`` or ``missing``, or the oneof is unset. Also
            raised (by the reference look-up) when the document named in
            the response is not in ``reference_map``.
    """
    which = get_doc_response._pb.WhichOneof("result")

    if which == "found":
        found = get_doc_response.found
        reference = _get_reference(found.name, reference_map)
        data = _helpers.decode_dict(found.fields, client)
        return DocumentSnapshot(
            reference,
            data,
            exists=True,
            read_time=get_doc_response.read_time,
            create_time=found.create_time,
            update_time=found.update_time,
        )

    if which == "missing":
        # For a missing document the oneof holds the document path itself.
        reference = _get_reference(get_doc_response.missing, reference_map)
        return DocumentSnapshot(
            reference,
            None,
            exists=False,
            read_time=get_doc_response.read_time,
            create_time=None,
            update_time=None,
        )

    raise ValueError(
        "`BatchGetDocumentsResponse.result` (a oneof) had a field other "
        "than `found` or `missing` set, or was unset"
    )
609
610
def _get_doc_mask(
    field_paths: Iterable[str] | None,
) -> Optional[types.common.DocumentMask]:
    """Build a document mask from field paths, if any were given.

    Args:
        field_paths (Optional[Iterable[str, ...]]): An iterable of field
            paths (``.``-delimited list of field names) to use as a
            projection of document fields in the returned results.

    Returns:
        Optional[google.cloud.firestore_v1.types.common.DocumentMask]: A mask
        to project documents to a restricted set of field paths, or ``None``
        when ``field_paths`` is ``None``.
    """
    # Only ``None`` means "no mask"; an empty iterable still yields a mask.
    if field_paths is not None:
        return types.DocumentMask(field_paths=field_paths)
    return None
629
630
def _path_helper(path: tuple) -> Tuple[str]:
    """Standardize path into a sequence of path segments.

    Args:
        path (Tuple[str, ...]): Can either be

            * A single ``/``-delimited path
            * A tuple of path segments

    Returns:
        The path segments. A one-element ``path`` is split on the document
        path delimiter, which produces a list; any other ``path`` (including
        an empty one) is returned unchanged.
    """
    if len(path) != 1:
        return path

    only_segment = path[0]
    return only_segment.split(_helpers.DOCUMENT_PATH_DELIMITER)