# Copyright 2017 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common helpers shared across Google Cloud Firestore modules."""
from __future__ import annotations

import datetime
import json
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    TYPE_CHECKING,
)

import grpc  # type: ignore
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.cloud._helpers import _datetime_to_pb_timestamp  # type: ignore
from google.protobuf import struct_pb2
from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
from google.type import latlng_pb2  # type: ignore

import google
from google.cloud import exceptions  # type: ignore
from google.cloud.firestore_v1 import transforms, types
from google.cloud.firestore_v1.field_path import FieldPath, parse_field_path
from google.cloud.firestore_v1.types import common, document, write
from google.cloud.firestore_v1.types.write import DocumentTransform
from google.cloud.firestore_v1.vector import Vector

if TYPE_CHECKING:  # pragma: NO COVER
    from google.cloud.firestore_v1 import DocumentSnapshot

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict


BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
DOCUMENT_PATH_DELIMITER = "/"
INACTIVE_TXN = "Transaction not in progress, cannot be used in API requests."
READ_AFTER_WRITE_ERROR = "Attempted read after write in a transaction."
BAD_REFERENCE_ERROR = (
    "Reference value {!r} in unexpected format, expected to be of the form "
    "``projects/{{project}}/databases/{{database}}/"
    "documents/{{document_path}}``."
)
WRONG_APP_REFERENCE = (
    "Document {!r} does not correspond to the same database ({!r}) as the client."
)
REQUEST_TIME_ENUM = DocumentTransform.FieldTransform.ServerValue.REQUEST_TIME
_GRPC_ERROR_MAPPING = {
    grpc.StatusCode.ALREADY_EXISTS: exceptions.Conflict,
    grpc.StatusCode.NOT_FOUND: exceptions.NotFound,
}


class GeoPoint(object):
    """Simple container for a geo point value.

    Args:
        latitude (float): Latitude of a point.
        longitude (float): Longitude of a point.
    """

    def __init__(self, latitude, longitude) -> None:
        self.latitude = latitude
        self.longitude = longitude

    def to_protobuf(self) -> latlng_pb2.LatLng:
        """Convert the current object to protobuf.

        Returns:
            google.type.latlng_pb2.LatLng: The current point as a protobuf.
        """
        return latlng_pb2.LatLng(latitude=self.latitude, longitude=self.longitude)

    def __eq__(self, other):
        """Compare two geo points for equality.

        Returns:
            Union[bool, NotImplemented]: :data:`True` if the points compare
            equal, else :data:`False`. (Or :data:`NotImplemented` if
            ``other`` is not a geo point.)
        """
        if not isinstance(other, GeoPoint):
            return NotImplemented

        return self.latitude == other.latitude and self.longitude == other.longitude

    def __ne__(self, other):
        """Compare two geo points for inequality.

        Returns:
            Union[bool, NotImplemented]: :data:`False` if the points compare
            equal, else :data:`True`. (Or :data:`NotImplemented` if
            ``other`` is not a geo point.)
        """
        equality_val = self.__eq__(other)
        if equality_val is NotImplemented:
            return NotImplemented
        else:
            return not equality_val
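# Example (illustrative comment, not executed): ``GeoPoint`` equality is
# value-based, and ``to_protobuf()`` produces a ``google.type.LatLng``:
#
#   point = GeoPoint(37.4220, -122.0841)
#   point == GeoPoint(37.4220, -122.0841)  # -> True
#   point.to_protobuf().latitude           # -> 37.422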


def verify_path(path, is_collection) -> None:
    """Verifies that a ``path`` has the correct form.

    Checks that all of the elements in ``path`` are strings.

    Args:
        path (Tuple[str, ...]): The components in a collection or
            document path.
        is_collection (bool): Indicates if the ``path`` represents
            a document or a collection.

    Raises:
        ValueError: if

            * the ``path`` is empty
            * ``is_collection=True`` and there are an even number of elements
            * ``is_collection=False`` and there are an odd number of elements
            * an element is not a string
    """
    num_elements = len(path)
    if num_elements == 0:
        raise ValueError("Document or collection path cannot be empty")

    if is_collection:
        if num_elements % 2 == 0:
            raise ValueError("A collection must have an odd number of path elements")

    else:
        if num_elements % 2 == 1:
            raise ValueError("A document must have an even number of path elements")

    for element in path:
        if not isinstance(element, str):
            msg = BAD_PATH_TEMPLATE.format(element, type(element))
            raise ValueError(msg)
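# Example (illustrative comment): collection paths have an odd number of
# elements; document paths have an even number:
#
#   verify_path(("users",), is_collection=True)           # OK
#   verify_path(("users", "alice"), is_collection=False)  # OK
#   verify_path(("users", "alice"), is_collection=True)   # raises ValueError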


def encode_value(value) -> types.document.Value:
    """Converts a native Python value into a Firestore protobuf ``Value``.

    Args:
        value (Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.firestore_v1._helpers.GeoPoint, \
            ~google.cloud.firestore_v1.vector.Vector]): A native
            Python value to convert to a protobuf field.

    Returns:
        ~google.cloud.firestore_v1.types.Value: A
        value encoded as a Firestore protobuf.

    Raises:
        TypeError: If the ``value`` is not one of the accepted types.
    """
    if value is None:
        return document.Value(null_value=struct_pb2.NULL_VALUE)

    # Must come before int since ``bool`` is an integer subtype.
    if isinstance(value, bool):
        return document.Value(boolean_value=value)

    if isinstance(value, int):
        return document.Value(integer_value=value)

    if isinstance(value, float):
        return document.Value(double_value=value)

    if isinstance(value, DatetimeWithNanoseconds):
        return document.Value(timestamp_value=value.timestamp_pb())

    if isinstance(value, datetime.datetime):
        return document.Value(timestamp_value=_datetime_to_pb_timestamp(value))

    if isinstance(value, str):
        return document.Value(string_value=value)

    if isinstance(value, bytes):
        return document.Value(bytes_value=value)

    # NOTE: We avoid doing an isinstance() check for a Document
    #       here to avoid import cycles.
    document_path = getattr(value, "_document_path", None)
    if document_path is not None:
        return document.Value(reference_value=document_path)

    if isinstance(value, GeoPoint):
        return document.Value(geo_point_value=value.to_protobuf())

    if isinstance(value, (list, tuple, set, frozenset)):
        value_list = tuple(encode_value(element) for element in value)
        value_pb = document.ArrayValue(values=value_list)
        return document.Value(array_value=value_pb)

    if isinstance(value, Vector):
        return encode_value(value.to_map_value())

    if isinstance(value, dict):
        value_dict = encode_dict(value)
        value_pb = document.MapValue(fields=value_dict)
        return document.Value(map_value=value_pb)

    raise TypeError(
        "Cannot convert to a Firestore Value", value, "Invalid type", type(value)
    )
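# Example (illustrative comment): a sketch of how a few native values map
# onto the ``Value`` oneof fields (protobuf reprs abbreviated):
#
#   encode_value(None)      # -> Value(null_value=NULL_VALUE)
#   encode_value(True)      # -> Value(boolean_value=True), checked before int
#   encode_value(1.5)       # -> Value(double_value=1.5)
#   encode_value({"a": 1})  # -> Value(map_value=MapValue(fields={"a": ...}))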


def encode_dict(values_dict) -> dict:
    """Encode a dictionary into protobuf ``Value``-s.

    Args:
        values_dict (dict): The dictionary to encode as protobuf fields.

    Returns:
        Dict[str, ~google.cloud.firestore_v1.types.Value]: A
        dictionary of string keys and ``Value`` protobufs as dictionary
        values.
    """
    return {key: encode_value(value) for key, value in values_dict.items()}


def document_snapshot_to_protobuf(
    snapshot: "DocumentSnapshot",
) -> Optional["google.cloud.firestore_v1.types.Document"]:
    """Convert a ``DocumentSnapshot`` to a ``Document`` protobuf.

    Returns :data:`None` if the snapshot does not reference an existing
    document.
    """
    from google.cloud.firestore_v1.types import Document

    if not snapshot.exists:
        return None

    return Document(
        name=snapshot.reference._document_path,
        fields=encode_dict(snapshot._data),
        create_time=snapshot.create_time,
        update_time=snapshot.update_time,
    )


class DocumentReferenceValue:
    """DocumentReference path container with accessors for each relevant chunk.

    Usage:
        doc_ref_val = DocumentReferenceValue(
            'projects/my-proj/databases/(default)/documents/my-col/my-doc',
        )
        assert doc_ref_val.project_name == 'my-proj'
        assert doc_ref_val.collection_name == 'my-col'
        assert doc_ref_val.document_id == 'my-doc'
        assert doc_ref_val.database_name == '(default)'

    Raises:
        ValueError: If the supplied value cannot satisfy a complete path.
    """

    def __init__(self, reference_value: str):
        self._reference_value = reference_value

        # The first 5 parts are
        # projects, {project}, databases, {database}, documents
        parts = reference_value.split(DOCUMENT_PATH_DELIMITER)
        if len(parts) < 7:
            msg = BAD_REFERENCE_ERROR.format(reference_value)
            raise ValueError(msg)

        self.project_name = parts[1]
        self.collection_name = parts[5]
        self.database_name = parts[3]
        self.document_id = "/".join(parts[6:])

    @property
    def full_key(self) -> str:
        """Computed property for a DocumentReference's collection_name and
        document ID."""
        return "/".join([self.collection_name, self.document_id])

    @property
    def full_path(self) -> str:
        return self._reference_value or "/".join(
            [
                "projects",
                self.project_name,
                "databases",
                self.database_name,
                "documents",
                self.collection_name,
                self.document_id,
            ]
        )


def reference_value_to_document(reference_value, client) -> Any:
    """Convert a reference value string to a document.

    Args:
        reference_value (str): A document reference value.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        :class:`~google.cloud.firestore_v1.document.DocumentReference`:
            The document corresponding to ``reference_value``.

    Raises:
        ValueError: If the ``reference_value`` is not of the expected
            format: ``projects/{project}/databases/{database}/documents/...``.
        ValueError: If the ``reference_value`` does not come from the same
            project / database combination as the ``client``.
    """
    from google.cloud.firestore_v1.base_document import BaseDocumentReference

    doc_ref_value = DocumentReferenceValue(reference_value)

    document: BaseDocumentReference = client.document(doc_ref_value.full_key)
    if document._document_path != reference_value:
        msg = WRONG_APP_REFERENCE.format(reference_value, client._database_string)
        raise ValueError(msg)

    return document


def decode_value(
    value, client
) -> Union[
    None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint, Vector
]:
    """Converts a Firestore protobuf ``Value`` to a native Python value.

    Args:
        value (google.cloud.firestore_v1.types.Value): A
            Firestore protobuf to be decoded / parsed / converted.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.firestore_v1._helpers.GeoPoint]: A
        native Python value converted from the ``value``.

    Raises:
        ValueError: If the ``value_type`` is unknown.
    """
    value_pb = getattr(value, "_pb", value)
    value_type = value_pb.WhichOneof("value_type")

    if value_type == "null_value":
        return None
    elif value_type == "boolean_value":
        return value_pb.boolean_value
    elif value_type == "integer_value":
        return value_pb.integer_value
    elif value_type == "double_value":
        return value_pb.double_value
    elif value_type == "timestamp_value":
        return DatetimeWithNanoseconds.from_timestamp_pb(value_pb.timestamp_value)
    elif value_type == "string_value":
        return value_pb.string_value
    elif value_type == "bytes_value":
        return value_pb.bytes_value
    elif value_type == "reference_value":
        return reference_value_to_document(value_pb.reference_value, client)
    elif value_type == "geo_point_value":
        return GeoPoint(
            value_pb.geo_point_value.latitude, value_pb.geo_point_value.longitude
        )
    elif value_type == "array_value":
        return [
            decode_value(element, client) for element in value_pb.array_value.values
        ]
    elif value_type == "map_value":
        return decode_dict(value_pb.map_value.fields, client)
    else:
        raise ValueError("Unknown ``value_type``", value_type)
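# Example (illustrative comment, assuming ``client`` is a connected
# ``Client``): ``decode_value`` inverts ``encode_value`` for plain values;
# the ``client`` is only consulted for ``reference_value``-s:
#
#   pb = encode_value("abc")
#   decode_value(pb, client)  # -> "abc"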


def decode_dict(value_fields, client) -> Union[dict, Vector]:
    """Converts a protobuf map of Firestore ``Value``-s.

    Args:
        value_fields (google.protobuf.pyext._message.MessageMapContainer): A
            protobuf map of Firestore ``Value``-s.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Dict[str, Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.firestore_v1._helpers.GeoPoint]]: A
        dictionary of native Python values converted from the ``value_fields``.
    """
    value_fields_pb = getattr(value_fields, "_pb", value_fields)
    res = {key: decode_value(value, client) for key, value in value_fields_pb.items()}

    if res.get("__type__", None) == "__vector__":
        # The Vector data type is represented as a mapping:
        # {"__type__": "__vector__", "value": [1.0, 2.0, 3.0]}.
        values = cast(Sequence[float], res["value"])
        return Vector(values)

    return res
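# Example (illustrative comment, assuming ``client`` as above): a map value
# carrying the reserved ``{"__type__": "__vector__"}`` marker decodes to a
# ``Vector`` instead of a plain dict:
#
#   fields = encode_value(Vector([1.0, 2.0, 3.0])).map_value.fields
#   decode_dict(fields, client)  # -> Vector([1.0, 2.0, 3.0])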


def get_doc_id(document_pb, expected_prefix) -> str:
    """Parse a document ID from a document protobuf.

    Args:
        document_pb (google.cloud.firestore_v1.\
            document.Document): A protobuf for a document that
            was created in a ``CreateDocument`` RPC.
        expected_prefix (str): The expected collection prefix for the
            fully-qualified document name.

    Returns:
        str: The document ID from the protobuf.

    Raises:
        ValueError: If the name does not begin with the prefix.
    """
    prefix, document_id = document_pb.name.rsplit(DOCUMENT_PATH_DELIMITER, 1)
    if prefix != expected_prefix:
        raise ValueError(
            "Unexpected document name",
            document_pb.name,
            "Expected to begin with",
            expected_prefix,
        )

    return document_id
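# Example (illustrative comment, with a hypothetical document name):
#
#   document_pb.name == "projects/p/databases/(default)/documents/users/alice"
#   get_doc_id(document_pb, "projects/p/databases/(default)/documents/users")
#   # -> "alice"; any other ``expected_prefix`` raises ValueError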


_EmptyDict = transforms.Sentinel("Marker for an empty dict value")


def extract_fields(
    document_data, prefix_path: FieldPath, expand_dots=False
) -> Generator[Tuple[Any, Any], Any, None]:
    """Do a depth-first walk of the tree, yielding ``(field_path, value)``."""
    if not document_data:
        yield prefix_path, _EmptyDict
    else:
        for key, value in sorted(document_data.items()):
            if expand_dots:
                sub_key = FieldPath.from_string(key)
            else:
                sub_key = FieldPath(key)

            field_path = FieldPath(*(prefix_path.parts + sub_key.parts))

            if isinstance(value, dict):
                for s_path, s_value in extract_fields(value, field_path):
                    yield s_path, s_value
            else:
                yield field_path, value
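# Example (illustrative comment): the walk yields leaf paths in sorted key
# order, with ``_EmptyDict`` marking explicitly-empty maps:
#
#   data = {"a": {"b": 1}, "c": {}}
#   list(extract_fields(data, FieldPath()))
#   # -> [(FieldPath("a", "b"), 1), (FieldPath("c"), _EmptyDict)]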


def set_field_value(document_data, field_path, value) -> None:
    """Set a value into a document for a ``field_path``."""
    current = document_data
    for element in field_path.parts[:-1]:
        current = current.setdefault(element, {})
    if value is _EmptyDict:
        value = {}
    current[field_path.parts[-1]] = value


def get_field_value(document_data, field_path) -> Any:
    if not field_path.parts:
        raise ValueError("Empty path")

    current = document_data
    for element in field_path.parts[:-1]:
        current = current[element]
    return current[field_path.parts[-1]]
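# Example (illustrative comment): the two helpers are inverses over nested
# dicts addressed by ``FieldPath`` parts:
#
#   data: dict = {}
#   set_field_value(data, FieldPath("a", "b"), 1)  # data == {"a": {"b": 1}}
#   get_field_value(data, FieldPath("a", "b"))     # -> 1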


class DocumentExtractor(object):
    """Break document data up into actual data and transforms.

    Handle special values such as ``DELETE_FIELD``, ``SERVER_TIMESTAMP``.

    Args:
        document_data (dict):
            Property names and values to use for sending a change to
            a document.
    """

    def __init__(self, document_data) -> None:
        self.document_data = document_data
        self.field_paths = []
        self.deleted_fields = []
        self.server_timestamps = []
        self.array_removes = {}
        self.array_unions = {}
        self.increments = {}
        self.minimums = {}
        self.maximums = {}
        self.set_fields: dict = {}
        self.empty_document = False

        prefix_path = FieldPath()
        iterator = self._get_document_iterator(prefix_path)

        for field_path, value in iterator:
            if field_path == prefix_path and value is _EmptyDict:
                self.empty_document = True

            elif value is transforms.DELETE_FIELD:
                self.deleted_fields.append(field_path)

            elif value is transforms.SERVER_TIMESTAMP:
                self.server_timestamps.append(field_path)

            elif isinstance(value, transforms.ArrayRemove):
                self.array_removes[field_path] = value.values

            elif isinstance(value, transforms.ArrayUnion):
                self.array_unions[field_path] = value.values

            elif isinstance(value, transforms.Increment):
                self.increments[field_path] = value.value

            elif isinstance(value, transforms.Maximum):
                self.maximums[field_path] = value.value

            elif isinstance(value, transforms.Minimum):
                self.minimums[field_path] = value.value

            else:
                self.field_paths.append(field_path)
                set_field_value(self.set_fields, field_path, value)

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path)

    @property
    def has_transforms(self):
        return bool(
            self.server_timestamps
            or self.array_removes
            or self.array_unions
            or self.increments
            or self.maximums
            or self.minimums
        )

    @property
    def transform_paths(self):
        return sorted(
            self.server_timestamps
            + list(self.array_removes)
            + list(self.array_unions)
            + list(self.increments)
            + list(self.maximums)
            + list(self.minimums)
        )

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        return None

    def get_update_pb(
        self, document_path, exists=None, allow_empty_mask=False
    ) -> types.write.Write:
        if exists is not None:
            current_document = common.Precondition(exists=exists)
        else:
            current_document = None

        update_pb = write.Write(
            update=document.Document(
                name=document_path, fields=encode_dict(self.set_fields)
            ),
            update_mask=self._get_update_mask(allow_empty_mask),
            current_document=current_document,
        )

        return update_pb

    def get_field_transform_pbs(
        self, document_path
    ) -> List[types.write.DocumentTransform.FieldTransform]:
        def make_array_value(values):
            value_list = [encode_value(element) for element in values]
            return document.ArrayValue(values=value_list)

        path_field_transforms = (
            [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        set_to_server_value=REQUEST_TIME_ENUM,
                    ),
                )
                for path in self.server_timestamps
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        remove_all_from_array=make_array_value(values),
                    ),
                )
                for path, values in self.array_removes.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        append_missing_elements=make_array_value(values),
                    ),
                )
                for path, values in self.array_unions.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), increment=encode_value(value)
                    ),
                )
                for path, value in self.increments.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), maximum=encode_value(value)
                    ),
                )
                for path, value in self.maximums.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), minimum=encode_value(value)
                    ),
                )
                for path, value in self.minimums.items()
            ]
        )
        return [transform for path, transform in sorted(path_field_transforms)]

    def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
        field_transforms = self.get_field_transform_pbs(document_path)
        transform_pb = write.Write(
            transform=write.DocumentTransform(
                document=document_path, field_transforms=field_transforms
            )
        )
        if exists is not None:
            transform_pb._pb.current_document.CopyFrom(
                common.Precondition(exists=exists)._pb
            )

        return transform_pb
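# Example (illustrative comment): the extractor separates plain data from
# transform sentinels so each can be sent in the appropriate part of a
# ``Write``:
#
#   extractor = DocumentExtractor(
#       {"title": "hi", "views": transforms.Increment(1)}
#   )
#   extractor.set_fields      # -> {"title": "hi"}
#   extractor.increments      # -> {FieldPath("views"): 1}
#   extractor.has_transforms  # -> True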


def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``create()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            creating a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing
        a single ``Write`` protobuf instance for ``create()``.
    """
    extractor = DocumentExtractor(document_data)

    if extractor.deleted_fields:
        raise ValueError("Cannot apply DELETE_FIELD in a create request.")

    create_pb = extractor.get_update_pb(document_path, exists=False)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        create_pb.update_transforms.extend(field_transform_pbs)

    return [create_pb]


def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing
        a single ``Write`` protobuf instance for ``set()``.
    """
    extractor = DocumentExtractor(document_data)

    if extractor.deleted_fields:
        raise ValueError(
            "Cannot apply DELETE_FIELD in a set request without "
            "specifying 'merge=True' or 'merge=[field_paths]'."
        )

    set_pb = extractor.get_update_pb(document_path)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        set_pb.update_transforms.extend(field_transform_pbs)

    return [set_pb]


class DocumentExtractorForMerge(DocumentExtractor):
    """Break document data up into actual data and transforms."""

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForMerge, self).__init__(document_data)
        self.data_merge: list = []
        self.transform_merge: list = []
        self.merge: list = []

    def _apply_merge_all(self) -> None:
        self.data_merge = sorted(self.field_paths + self.deleted_fields)
        # TODO: other transforms
        self.transform_merge = self.transform_paths
        self.merge = sorted(self.data_merge + self.transform_paths)

    def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]:
        for merge_field in merge:
            if isinstance(merge_field, FieldPath):
                yield merge_field
            else:
                yield FieldPath(*parse_field_path(merge_field))

    def _normalize_merge_paths(self, merge) -> list:
        merge_paths = sorted(self._construct_merge_paths(merge))

        # Raise if any merge path is a parent of another. Leverage sorting
        # to avoid quadratic behavior.
        for index in range(len(merge_paths) - 1):
            lhs, rhs = merge_paths[index], merge_paths[index + 1]
            if lhs.eq_or_parent(rhs):
                raise ValueError("Merge paths overlap: {}, {}".format(lhs, rhs))

        for merge_path in merge_paths:
            if merge_path in self.deleted_fields:
                continue
            try:
                get_field_value(self.document_data, merge_path)
            except KeyError:
                raise ValueError("Invalid merge path: {}".format(merge_path))

        return merge_paths

    def _apply_merge_paths(self, merge) -> None:
        if self.empty_document:
            raise ValueError("Cannot merge specific fields with empty document.")

        merge_paths = self._normalize_merge_paths(merge)

        del self.data_merge[:]
        del self.transform_merge[:]
        self.merge = merge_paths

        for merge_path in merge_paths:
            if merge_path in self.transform_paths:
                self.transform_merge.append(merge_path)

            for field_path in self.field_paths:
                if merge_path.eq_or_parent(field_path):
                    self.data_merge.append(field_path)

        # Clear out data for fields not merged.
        merged_set_fields: dict = {}
        for field_path in self.data_merge:
            value = get_field_value(self.document_data, field_path)
            set_field_value(merged_set_fields, field_path, value)
        self.set_fields = merged_set_fields

        unmerged_deleted_fields = [
            field_path
            for field_path in self.deleted_fields
            if field_path not in self.merge
        ]
        if unmerged_deleted_fields:
            raise ValueError(
                "Cannot delete unmerged fields: {}".format(unmerged_deleted_fields)
            )
        self.data_merge = sorted(self.data_merge + self.deleted_fields)

        # Keep only transforms which are within the merge.
        merged_transform_paths = set()
        for merge_path in self.merge:
            transform_merge_paths = [
                transform_path
                for transform_path in self.transform_paths
                if merge_path.eq_or_parent(transform_path)
            ]
            merged_transform_paths.update(transform_merge_paths)

        self.server_timestamps = [
            path for path in self.server_timestamps if path in merged_transform_paths
        ]

        self.array_removes = {
            path: values
            for path, values in self.array_removes.items()
            if path in merged_transform_paths
        }

        self.array_unions = {
            path: values
            for path, values in self.array_unions.items()
            if path in merged_transform_paths
        }

    def apply_merge(self, merge) -> None:
        if merge is True:  # merge all fields
            self._apply_merge_all()
        else:
            self._apply_merge_paths(merge)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        # Mask uses dotted / quoted paths.
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.merge
            if field_path not in self.transform_merge
        ]

        return common.DocumentMask(field_paths=mask_paths)


def pbs_for_set_with_merge(
    document_path, document_data, merge
) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.
        merge (Union[bool, List[str], List[FieldPath]]):
            If True, merge all fields; else, merge only the named fields.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing
        a single ``Write`` protobuf instance for ``set()``.
    """
    extractor = DocumentExtractorForMerge(document_data)
    extractor.apply_merge(merge)

    set_pb = extractor.get_update_pb(document_path)

    if extractor.transform_paths:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        set_pb.update_transforms.extend(field_transform_pbs)

    return [set_pb]
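# Example (illustrative comment, with a hypothetical document path): merging
# only "a" restricts both the data sent and the update mask, leaving other
# server-side fields untouched:
#
#   pbs = pbs_for_set_with_merge(
#       "projects/p/databases/(default)/documents/col/doc",
#       {"a": 1, "b": 2},
#       merge=["a"],
#   )
#   pbs[0].update_mask.field_paths  # -> ["a"]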


class DocumentExtractorForUpdate(DocumentExtractor):
    """Break document data up into actual data and transforms."""

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForUpdate, self).__init__(document_data)
        self.top_level_paths = sorted(
            [FieldPath.from_string(key) for key in document_data]
        )
        tops = set(self.top_level_paths)
        for top_level_path in self.top_level_paths:
            for ancestor in top_level_path.lineage():
                if ancestor in tops:
                    raise ValueError(
                        "Conflicting field path: {}, {}".format(
                            top_level_path, ancestor
                        )
                    )

        for field_path in self.deleted_fields:
            if field_path not in tops:
                raise ValueError(
                    "Cannot update with nested delete: {}".format(field_path)
                )

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path, expand_dots=True)

    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
        mask_paths = []
        for field_path in self.top_level_paths:
            if field_path not in self.transform_paths:
                mask_paths.append(field_path.to_api_repr())

        return common.DocumentMask(field_paths=mask_paths)


def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``update()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        field_updates (dict): Field names or paths to update and values
            to update with.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing
        a single ``Write`` protobuf instance for ``update()``.
    """
    extractor = DocumentExtractorForUpdate(field_updates)

    if extractor.empty_document:
        raise ValueError("Cannot update with an empty document.")

    if option is None:  # Default is to use ``exists=True``.
        option = ExistsOption(exists=True)

    update_pb = extractor.get_update_pb(document_path)
    option.modify_write(update_pb)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        update_pb.update_transforms.extend(field_transform_pbs)

    return [update_pb]
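# Example (illustrative comment, with a hypothetical document path): dotted
# keys expand into nested fields, and the default precondition asserts that
# the document already exists:
#
#   pbs = pbs_for_update(
#       "projects/p/databases/(default)/documents/col/doc",
#       {"a.b": 10},
#       option=None,
#   )
#   pbs[0].update_mask.field_paths  # -> ["a.b"]
#   pbs[0].current_document.exists  # -> True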


def pb_for_delete(document_path, option) -> types.write.Write:
    """Make a ``Write`` protobuf for ``delete()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        google.cloud.firestore_v1.types.Write: A
        ``Write`` protobuf instance for the ``delete()``.
    """
    write_pb = write.Write(delete=document_path)
    if option is not None:
        option.modify_write(write_pb)

    return write_pb


class ReadAfterWriteError(Exception):
    """Raised when a read is attempted after a write.

    Raised by "read" methods that use transactions.
    """


def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
    """Get the transaction ID from a ``Transaction`` object.

    Args:
        transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
            Transaction`]):
            An existing transaction that this query will run in.
        read_operation (Optional[bool]): Indicates if the transaction ID
            will be used in a read operation. Defaults to :data:`True`.

    Returns:
        Optional[bytes]: The ID of the transaction, or :data:`None` if the
        ``transaction`` is :data:`None`.

    Raises:
        ValueError: If the ``transaction`` is not in progress (only if
            ``transaction`` is not :data:`None`).
        ReadAfterWriteError: If the ``transaction`` has writes stored on
            it and ``read_operation`` is :data:`True`.
    """
    if transaction is None:
        return None
    else:
        if not transaction.in_progress:
            raise ValueError(INACTIVE_TXN)
        if read_operation and len(transaction._write_pbs) > 0:
            raise ReadAfterWriteError(READ_AFTER_WRITE_ERROR)
        return transaction.id


def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
    """Create RPC metadata containing a prefix.

    Args:
        prefix (str): appropriate resource path.

    Returns:
        List[Tuple[str, str]]: RPC metadata with supplied prefix
    """
    return [("google-cloud-resource-prefix", prefix)]


class WriteOption(object):
    """Option used to assert a condition on a write operation."""

    def modify_write(self, write, no_create_msg=None) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        This is a virtual method intended to be implemented by subclasses.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            no_create_msg (Optional[str]): A message to use to indicate that
                a create operation is not allowed.

        Raises:
            NotImplementedError: Always, this method is virtual.
        """
        raise NotImplementedError


class LastUpdateOption(WriteOption):
    """Option used to assert a "last update" condition on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        last_update_time (google.protobuf.timestamp_pb2.Timestamp): A
            timestamp. When set, the target document must exist and have
            been last updated at that time. Protobuf ``update_time`` timestamps
            are typically returned from methods that perform write operations
            as part of a "write result" protobuf or directly.
    """

    def __init__(self, last_update_time) -> None:
        self._last_update_time = last_update_time

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._last_update_time == other._last_update_time

    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        The ``last_update_time`` is added to ``write`` as an "update time"
        precondition. When set, the target document must exist and have been
        last updated at that time.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(update_time=self._last_update_time)
        write._pb.current_document.CopyFrom(current_doc._pb)


class ExistsOption(WriteOption):
    """Option used to assert existence on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        exists (bool): Indicates if the document being modified
            should already exist.
    """

    def __init__(self, exists) -> None:
        self._exists = exists

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._exists == other._exists

    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        If:

        * ``exists=True``, adds a precondition that requires existence
        * ``exists=False``, adds a precondition that requires non-existence

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(exists=self._exists)
        write._pb.current_document.CopyFrom(current_doc._pb)
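# Example (illustrative comment, with a hypothetical document path): both
# subclasses rewrite the ``Write``'s ``current_document`` precondition in
# place:
#
#   write_pb = write.Write(delete="projects/p/databases/(default)/documents/col/doc")
#   ExistsOption(exists=True).modify_write(write_pb)
#   write_pb.current_document.exists  # -> True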


def make_retry_timeout_kwargs(
    retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None
) -> dict:
    """Helper for API methods which take optional 'retry' / 'timeout' args."""
    kwargs = {}

    if retry is not gapic_v1.method.DEFAULT:
        kwargs["retry"] = retry

    if timeout is not None:
        kwargs["timeout"] = timeout

    return kwargs


def build_timestamp(
    dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None
) -> Timestamp:
    """Returns the supplied datetime (or "now") as a Timestamp."""
    return _datetime_to_pb_timestamp(
        dt or DatetimeWithNanoseconds.now(tz=datetime.timezone.utc)
    )


def compare_timestamps(
    ts1: Union[Timestamp, datetime.datetime],
    ts2: Union[Timestamp, datetime.datetime],
) -> int:
    ts1 = build_timestamp(ts1) if not isinstance(ts1, Timestamp) else ts1
    ts2 = build_timestamp(ts2) if not isinstance(ts2, Timestamp) else ts2
    ts1_nanos = ts1.nanos + ts1.seconds * 1e9
    ts2_nanos = ts2.nanos + ts2.seconds * 1e9
    if ts1_nanos == ts2_nanos:
        return 0
    return 1 if ts1_nanos > ts2_nanos else -1
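# Example (illustrative comment): comparison happens on total nanoseconds,
# returning -1 / 0 / 1 like a classic comparator:
#
#   earlier = datetime.datetime(2020, 1, 1, tzinfo=datetime.timezone.utc)
#   later = datetime.datetime(2021, 1, 1, tzinfo=datetime.timezone.utc)
#   compare_timestamps(earlier, later)  # -> -1
#   compare_timestamps(later, later)    # -> 0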


def deserialize_bundle(
    serialized: Union[str, bytes],
    client: "google.cloud.firestore_v1.client.BaseClient",
) -> "google.cloud.firestore_bundle.FirestoreBundle":
    """Inverse operation to a `FirestoreBundle` instance's `build()` method.

    Args:
        serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`.
            Should be a list of dictionaries in string format.
        client (BaseClient): A connected Client instance.

    Returns:
        FirestoreBundle: A bundle equivalent to that which called `build()` and
        initially created the `serialized` value.

    Raises:
        ValueError: If any of the dictionaries in the list contain more than
            one top-level key.
        ValueError: If any unexpected BundleElement types are encountered.
        ValueError: If the serialized bundle ends before expected.
    """
    from google.cloud.firestore_bundle import BundleElement, FirestoreBundle

    # Outlines the legal transitions from one BundleElement to another.
    bundle_state_machine = {
        "__initial__": ["metadata"],
        "metadata": ["namedQuery", "documentMetadata", "__end__"],
        "namedQuery": ["namedQuery", "documentMetadata", "__end__"],
        "documentMetadata": ["document"],
        "document": ["documentMetadata", "__end__"],
    }
    allowed_next_element_types: List[str] = bundle_state_machine["__initial__"]

    # This must be saved and added last, since we cache it to preserve
    # timestamps, yet must flush it whenever a new document or query is added
    # to a bundle. The process of deserializing a bundle uses methods which
    # flush a cached metadata element, and thus, it must be the last
    # BundleElement added during deserialization.
    metadata_bundle_element: Optional[BundleElement] = None

    bundle: Optional[FirestoreBundle] = None
    data: Dict
    for data in _parse_bundle_elements_data(serialized):
        # BundleElements are serialized as JSON containing one key outlining
        # the type, with all further data nested under that key.
        keys: List[str] = list(data.keys())

        if len(keys) != 1:
            raise ValueError("Expected serialized BundleElement with one top-level key")

        key: str = keys[0]

        if key not in allowed_next_element_types:
            raise ValueError(
                f"Encountered BundleElement of type {key}. "
                f"Expected one of {allowed_next_element_types}"
            )

        # Create and add our BundleElement.
        bundle_element: BundleElement
        try:
            bundle_element = BundleElement.from_json(json.dumps(data))
        except AttributeError as e:
            # Some bad serialization formats cannot be universally deserialized.
            if e.args[0] == "'dict' object has no attribute 'find'":  # pragma: NO COVER
                raise ValueError(
                    "Invalid serialization of datetimes. "
                    "Cannot deserialize Bundles created from the NodeJS SDK."
                )
            raise e  # pragma: NO COVER

        if bundle is None:
            # This must be the first bundle type encountered.
            assert key == "metadata"
            bundle = FirestoreBundle(data[key]["id"])
            metadata_bundle_element = bundle_element

        else:
            bundle._add_bundle_element(bundle_element, client=client, type=key)

        # Update the allowed next BundleElement types.
        allowed_next_element_types = bundle_state_machine[key]

    if "__end__" not in allowed_next_element_types:
        raise ValueError("Unexpected end to serialized FirestoreBundle")
    # The state machine guarantees bundle and metadata have been populated.
    bundle = cast(FirestoreBundle, bundle)
    metadata_bundle_element = cast(BundleElement, metadata_bundle_element)
    # Now, finally, add the metadata element.
    bundle._add_bundle_element(
        metadata_bundle_element,
        client=client,
        type="metadata",
    )

    return bundle


def _parse_bundle_elements_data(
    serialized: Union[str, bytes]
) -> Generator[Dict, None, None]:
    """Reads through a serialized FirestoreBundle and yields JSON chunks that
    were created via `BundleElement.to_json(bundle_element)`.

    Serialized FirestoreBundle instances are length-prefixed JSON objects, and
    so are of the form "123{...}57{...}".
    To correctly and safely read a bundle, we must first detect these length
    prefixes, read that many bytes of data, and attempt to JSON-parse that.

    Raises:
        ValueError: If a chunk of JSON ever starts without following a length
            prefix.
    """
    _serialized: Iterator[int] = iter(
        serialized if isinstance(serialized, bytes) else serialized.encode("utf-8")
    )

    length_prefix: str = ""
    while True:
        byte: Optional[int] = next(_serialized, None)

        if byte is None:
            return None

        _str: str = chr(byte)
        if _str.isnumeric():
            length_prefix += _str
        else:
            if length_prefix == "":
                raise ValueError("Expected length prefix")

            _length_prefix = int(length_prefix)
            length_prefix = ""
            _bytes = bytearray([byte])
            _counter = 1
            while _counter < _length_prefix:
                _bytes.append(next(_serialized))
                _counter += 1

            yield json.loads(_bytes.decode("utf-8"))
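# Example (illustrative comment): two length-prefixed JSON chunks, where each
# prefix counts the UTF-8 bytes of the chunk that follows:
#
#   serialized = '14{"metadata":1}11{"named":2}'
#   list(_parse_bundle_elements_data(serialized))
#   # -> [{"metadata": 1}, {"named": 2}]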


def _get_documents_from_bundle(
    bundle, *, query_name: Optional[str] = None
) -> Generator["DocumentSnapshot", None, None]:
    from google.cloud.firestore_bundle.bundle import _BundledDocument

    bundled_doc: _BundledDocument
    for bundled_doc in bundle.documents.values():
        if query_name and query_name not in bundled_doc.metadata.queries:
            continue
        yield bundled_doc.snapshot


def _get_document_from_bundle(
    bundle,
    *,
    document_id: str,
) -> Optional["DocumentSnapshot"]:
    bundled_doc = bundle.documents.get(document_id)
    if bundled_doc:
        return bundled_doc.snapshot
    else:
        return None