# Copyright 2017 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common helpers shared across Google Cloud Firestore modules."""
from __future__ import annotations
import datetime
import json
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    TYPE_CHECKING,
)

import grpc  # type: ignore
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.cloud._helpers import _datetime_to_pb_timestamp  # type: ignore
from google.protobuf import struct_pb2
from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
from google.type import latlng_pb2  # type: ignore

import google
from google.cloud import exceptions  # type: ignore
from google.cloud.firestore_v1 import transforms, types
from google.cloud.firestore_v1.field_path import FieldPath, parse_field_path
from google.cloud.firestore_v1.types import common, document, write
from google.cloud.firestore_v1.types.write import DocumentTransform
from google.cloud.firestore_v1.vector import Vector

if TYPE_CHECKING:  # pragma: NO COVER
    from google.cloud.firestore_v1 import DocumentSnapshot

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict


BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
DOCUMENT_PATH_DELIMITER = "/"
INACTIVE_TXN = "Transaction not in progress, cannot be used in API requests."
READ_AFTER_WRITE_ERROR = "Attempted read after write in a transaction."
BAD_REFERENCE_ERROR = (
    "Reference value {!r} in unexpected format, expected to be of the form "
    "``projects/{{project}}/databases/{{database}}/"
    "documents/{{document_path}}``."
)
WRONG_APP_REFERENCE = (
    "Document {!r} does not correspond to the same database ({!r}) as the client."
)
REQUEST_TIME_ENUM = DocumentTransform.FieldTransform.ServerValue.REQUEST_TIME
_GRPC_ERROR_MAPPING = {
    grpc.StatusCode.ALREADY_EXISTS: exceptions.Conflict,
    grpc.StatusCode.NOT_FOUND: exceptions.NotFound,
}


class GeoPoint(object):
    """Simple container for a geo point value.

    Args:
        latitude (float): Latitude of a point.
        longitude (float): Longitude of a point.
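
    Usage (a minimal sketch; the coordinates are illustrative):

        point = GeoPoint(37.7749, -122.4194)
        point_pb = point.to_protobuf()  # google.type.latlng_pb2.LatLng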
82 """
83
84 def __init__(self, latitude, longitude) -> None:
85 self.latitude = latitude
86 self.longitude = longitude
87
88 def to_protobuf(self) -> latlng_pb2.LatLng:
89 """Convert the current object to protobuf.
90
91 Returns:
92 google.type.latlng_pb2.LatLng: The current point as a protobuf.
93 """
94 return latlng_pb2.LatLng(latitude=self.latitude, longitude=self.longitude)
95
96 def __eq__(self, other):
97 """Compare two geo points for equality.
98
99 Returns:
100 Union[bool, NotImplemented]: :data:`True` if the points compare
101 equal, else :data:`False`. (Or :data:`NotImplemented` if
102 ``other`` is not a geo point.)
103 """
104 if not isinstance(other, GeoPoint):
105 return NotImplemented
106
107 return self.latitude == other.latitude and self.longitude == other.longitude
108
109 def __ne__(self, other):
110 """Compare two geo points for inequality.
111
112 Returns:
113 Union[bool, NotImplemented]: :data:`False` if the points compare
114 equal, else :data:`True`. (Or :data:`NotImplemented` if
115 ``other`` is not a geo point.)
116 """
117 equality_val = self.__eq__(other)
118 if equality_val is NotImplemented:
119 return NotImplemented
120 else:
121 return not equality_val
122
123 def __repr__(self):
124 return f"{type(self).__name__}(latitude={self.latitude}, longitude={self.longitude})"
125
126
127def verify_path(path, is_collection) -> None:
128 """Verifies that a ``path`` has the correct form.
129
130 Checks that all of the elements in ``path`` are strings.
131
132 Args:
133 path (Tuple[str, ...]): The components in a collection or
134 document path.
        is_collection (bool): Indicates if the ``path`` represents
            a collection (:data:`True`) or a document (:data:`False`).

    Raises:
        ValueError: if

            * the ``path`` is empty
            * ``is_collection=True`` and there is an even number of elements
            * ``is_collection=False`` and there is an odd number of elements
            * an element is not a string
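
    For example (a minimal sketch; the paths are illustrative):

        verify_path(("users", "alice"), is_collection=False)  # a document path
        verify_path(("users",), is_collection=True)  # a collection path
        verify_path(("users",), is_collection=False)  # raises ValueError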
145 """
146 num_elements = len(path)
147 if num_elements == 0:
148 raise ValueError("Document or collection path cannot be empty")
149
150 if is_collection:
151 if num_elements % 2 == 0:
152 raise ValueError("A collection must have an odd number of path elements")
153
154 else:
155 if num_elements % 2 == 1:
156 raise ValueError("A document must have an even number of path elements")
157
158 for element in path:
159 if not isinstance(element, str):
160 msg = BAD_PATH_TEMPLATE.format(element, type(element))
161 raise ValueError(msg)
162
163
164def encode_value(value) -> types.document.Value:
165 """Converts a native Python value into a Firestore protobuf ``Value``.
166
167 Args:
        value (Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.firestore_v1._helpers.GeoPoint, \
            ~google.cloud.firestore_v1.vector.Vector]): A native
            Python value to convert to a protobuf field.

    Returns:
        ~google.cloud.firestore_v1.types.Value: A
        value encoded as a Firestore protobuf.

    Raises:
        TypeError: If the ``value`` is not one of the accepted types.
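
    For example (a minimal sketch; the values are illustrative):

        encode_value(3)  # -> Value(integer_value=3)
        encode_value("abc")  # -> Value(string_value="abc")
        encode_value([1, 2])  # -> Value(array_value=ArrayValue(values=[...]))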
179 """
180 if value is None:
181 return document.Value(null_value=struct_pb2.NULL_VALUE)
182
183 # Must come before int since ``bool`` is an integer subtype.
184 if isinstance(value, bool):
185 return document.Value(boolean_value=value)
186
187 if isinstance(value, int):
188 return document.Value(integer_value=value)
189
190 if isinstance(value, float):
191 return document.Value(double_value=value)
192
193 if isinstance(value, DatetimeWithNanoseconds):
194 return document.Value(timestamp_value=value.timestamp_pb())
195
196 if isinstance(value, datetime.datetime):
197 return document.Value(timestamp_value=_datetime_to_pb_timestamp(value))
198
199 if isinstance(value, str):
200 return document.Value(string_value=value)
201
202 if isinstance(value, bytes):
203 return document.Value(bytes_value=value)
204
205 # NOTE: We avoid doing an isinstance() check for a Document
206 # here to avoid import cycles.
207 document_path = getattr(value, "_document_path", None)
208 if document_path is not None:
209 return document.Value(reference_value=document_path)
210
211 if isinstance(value, GeoPoint):
212 return document.Value(geo_point_value=value.to_protobuf())
213
214 if isinstance(value, (list, tuple, set, frozenset)):
215 value_list = tuple(encode_value(element) for element in value)
216 value_pb = document.ArrayValue(values=value_list)
        return document.Value(array_value=value_pb)

    # A ``Vector`` is encoded via its canonical map representation.
    if isinstance(value, Vector):
        return encode_value(value.to_map_value())

    if isinstance(value, dict):
        value_dict = encode_dict(value)
        value_pb = document.MapValue(fields=value_dict)
        return document.Value(map_value=value_pb)

    raise TypeError(
        "Cannot convert to a Firestore Value", value, "Invalid type", type(value)
    )


def encode_dict(values_dict) -> dict:
    """Encode a dictionary into protobuf ``Value``-s.

    Args:
        values_dict (dict): The dictionary to encode as protobuf fields.

    Returns:
        Dict[str, ~google.cloud.firestore_v1.types.Value]: A
        dictionary of string keys and ``Value`` protobufs as dictionary
        values.
    """
    return {key: encode_value(value) for key, value in values_dict.items()}


def document_snapshot_to_protobuf(
    snapshot: "DocumentSnapshot",
) -> Optional["google.cloud.firestore_v1.types.Document"]:
    """Convert a ``DocumentSnapshot`` to a ``Document`` protobuf.

    Returns :data:`None` if the ``snapshot`` does not exist.
    """
    from google.cloud.firestore_v1.types import Document

    if not snapshot.exists:
        return None

    return Document(
        name=snapshot.reference._document_path,
        fields=encode_dict(snapshot._data),
        create_time=snapshot.create_time,
        update_time=snapshot.update_time,
    )


class DocumentReferenceValue:
    """DocumentReference path container with accessors for each relevant chunk.

    Usage:
        doc_ref_val = DocumentReferenceValue(
            'projects/my-proj/databases/(default)/documents/my-col/my-doc',
        )
        assert doc_ref_val.project_name == 'my-proj'
        assert doc_ref_val.collection_name == 'my-col'
        assert doc_ref_val.document_id == 'my-doc'
        assert doc_ref_val.database_name == '(default)'

    Raises:
        ValueError: If the supplied value cannot satisfy a complete path.
    """

    def __init__(self, reference_value: str):
        self._reference_value = reference_value

        # The first 5 parts are
        # projects, {project}, databases, {database}, documents
        parts = reference_value.split(DOCUMENT_PATH_DELIMITER)
        if len(parts) < 7:
            msg = BAD_REFERENCE_ERROR.format(reference_value)
            raise ValueError(msg)

        self.project_name = parts[1]
        self.collection_name = parts[5]
        self.database_name = parts[3]
        self.document_id = "/".join(parts[6:])

    @property
    def full_key(self) -> str:
295 """Computed property for a DocumentReference's collection_name and
296 document Id"""
297 return "/".join([self.collection_name, self.document_id])
298
299 @property
300 def full_path(self) -> str:
301 return self._reference_value or "/".join(
302 [
303 "projects",
304 self.project_name,
305 "databases",
306 self.database_name,
307 "documents",
308 self.collection_name,
309 self.document_id,
310 ]
311 )
312
313
314def reference_value_to_document(reference_value, client) -> Any:
315 """Convert a reference value string to a document.
316
317 Args:
318 reference_value (str): A document reference value.
319 client (:class:`~google.cloud.firestore_v1.client.Client`):
320 A client that has a document factory.
321
322 Returns:
323 :class:`~google.cloud.firestore_v1.document.DocumentReference`:
324 The document corresponding to ``reference_value``.
325
326 Raises:
327 ValueError: If the ``reference_value`` is not of the expected
328 format: ``projects/{project}/databases/{database}/documents/...``.
329 ValueError: If the ``reference_value`` does not come from the same
330 project / database combination as the ``client``.
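
    For example (a minimal sketch; the names are illustrative):

        ref = reference_value_to_document(
            "projects/my-proj/databases/(default)/documents/users/alice",
            client,
        )
        assert ref.id == "alice"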
331 """
332 from google.cloud.firestore_v1.base_document import BaseDocumentReference
333
334 doc_ref_value = DocumentReferenceValue(reference_value)
335
336 document: BaseDocumentReference = client.document(doc_ref_value.full_key)
337 if document._document_path != reference_value:
338 msg = WRONG_APP_REFERENCE.format(reference_value, client._database_string)
339 raise ValueError(msg)
340
341 return document
342
343
344def decode_value(
345 value, client
346) -> Union[
347 None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint, Vector
348]:
349 """Converts a Firestore protobuf ``Value`` to a native Python value.
350
351 Args:
352 value (google.cloud.firestore_v1.types.Value): A
353 Firestore protobuf to be decoded / parsed / converted.
354 client (:class:`~google.cloud.firestore_v1.client.Client`):
355 A client that has a document factory.
356
357 Returns:
        Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.firestore_v1._helpers.GeoPoint]: A native
            Python value converted from the ``value``.

    Raises:
        ValueError: If the ``value_type`` is unknown.
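
    For example (a minimal sketch):

        value_pb = encode_value("abc")
        assert decode_value(value_pb, client) == "abc"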
365 """
366 value_pb = getattr(value, "_pb", value)
367 value_type = value_pb.WhichOneof("value_type")
368
369 if value_type == "null_value":
370 return None
371 elif value_type == "boolean_value":
372 return value_pb.boolean_value
373 elif value_type == "integer_value":
374 return value_pb.integer_value
375 elif value_type == "double_value":
376 return value_pb.double_value
377 elif value_type == "timestamp_value":
378 return DatetimeWithNanoseconds.from_timestamp_pb(value_pb.timestamp_value)
379 elif value_type == "string_value":
380 return value_pb.string_value
381 elif value_type == "bytes_value":
382 return value_pb.bytes_value
383 elif value_type == "reference_value":
384 return reference_value_to_document(value_pb.reference_value, client)
385 elif value_type == "geo_point_value":
386 return GeoPoint(
387 value_pb.geo_point_value.latitude, value_pb.geo_point_value.longitude
388 )
389 elif value_type == "array_value":
390 return [
391 decode_value(element, client) for element in value_pb.array_value.values
392 ]
393 elif value_type == "map_value":
394 return decode_dict(value_pb.map_value.fields, client)
395 else:
396 raise ValueError("Unknown ``value_type``", value_type)
397
398
399def decode_dict(value_fields, client) -> Union[dict, Vector]:
400 """Converts a protobuf map of Firestore ``Value``-s.
401
402 Args:
403 value_fields (google.protobuf.pyext._message.MessageMapContainer): A
404 protobuf map of Firestore ``Value``-s.
405 client (:class:`~google.cloud.firestore_v1.client.Client`):
406 A client that has a document factory.
407
408 Returns:
        Dict[str, Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.firestore_v1._helpers.GeoPoint]]: A
            dictionary of native Python values converted from the ``value_fields``.
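
    A map of the form ``{"__type__": "__vector__", "value": [...]}`` decodes
    to a :class:`~google.cloud.firestore_v1.vector.Vector`, for example
    (a minimal sketch):

        fields = encode_value(Vector([1.0, 2.0])).map_value.fields
        assert decode_dict(fields, client) == Vector([1.0, 2.0])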
412 """
413 value_fields_pb = getattr(value_fields, "_pb", value_fields)
414 res = {key: decode_value(value, client) for key, value in value_fields_pb.items()}
415
416 if res.get("__type__", None) == "__vector__":
417 # Vector data type is represented as mapping.
418 # {"__type__":"__vector__", "value": [1.0, 2.0, 3.0]}.
419 values = cast(Sequence[float], res["value"])
420 return Vector(values)
421
422 return res
423
424
425def get_doc_id(document_pb, expected_prefix) -> str:
426 """Parse a document ID from a document protobuf.
427
428 Args:
429 document_pb (google.cloud.firestore_v1.\
430 document.Document): A protobuf for a document that
431 was created in a ``CreateDocument`` RPC.
432 expected_prefix (str): The expected collection prefix for the
433 fully-qualified document name.
434
435 Returns:
436 str: The document ID from the protobuf.
437
438 Raises:
439 ValueError: If the name does not begin with the prefix.
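
    For example (a minimal sketch; the path is illustrative):

        prefix = "projects/p/databases/(default)/documents/users"
        # a ``document_pb`` named f"{prefix}/alice" yields "alice"
        get_doc_id(document_pb, prefix)  # -> "alice"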
440 """
441 prefix, document_id = document_pb.name.rsplit(DOCUMENT_PATH_DELIMITER, 1)
442 if prefix != expected_prefix:
443 raise ValueError(
444 "Unexpected document name",
445 document_pb.name,
446 "Expected to begin with",
447 expected_prefix,
448 )
449
450 return document_id
451
452
453_EmptyDict = transforms.Sentinel("Marker for an empty dict value")
454
455
456def extract_fields(
457 document_data, prefix_path: FieldPath, expand_dots=False
458) -> Generator[Tuple[Any, Any], Any, None]:
459 """Do depth-first walk of tree, yielding field_path, value"""
    if not document_data:
        yield prefix_path, _EmptyDict
    else:
        for key, value in sorted(document_data.items()):
            if expand_dots:
                sub_key = FieldPath.from_string(key)
            else:
                sub_key = FieldPath(key)

            field_path = FieldPath(*(prefix_path.parts + sub_key.parts))

            if isinstance(value, dict):
                for s_path, s_value in extract_fields(value, field_path):
                    yield s_path, s_value
            else:
                yield field_path, value


def set_field_value(document_data, field_path, value) -> None:
    """Set a value into a document for a field_path"""
    current = document_data
    for element in field_path.parts[:-1]:
        current = current.setdefault(element, {})
    if value is _EmptyDict:
        value = {}
    current[field_path.parts[-1]] = value


def get_field_value(document_data, field_path) -> Any:
    if not field_path.parts:
        raise ValueError("Empty path")

    current = document_data
    for element in field_path.parts[:-1]:
        current = current[element]
    return current[field_path.parts[-1]]


class DocumentExtractor(object):
    """Break document data up into actual data and transforms.

    Handle special values such as ``DELETE_FIELD``, ``SERVER_TIMESTAMP``.

    Args:
        document_data (dict):
            Property names and values to use for sending a change to
            a document.
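
    For example (a minimal sketch):

        extractor = DocumentExtractor(
            {"a": 1, "b": transforms.SERVER_TIMESTAMP}
        )
        # extractor.set_fields == {"a": 1}
        # extractor.server_timestamps == [FieldPath("b")]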
507 """
508
509 def __init__(self, document_data) -> None:
510 self.document_data = document_data
511 self.field_paths = []
512 self.deleted_fields = []
513 self.server_timestamps = []
514 self.array_removes = {}
515 self.array_unions = {}
516 self.increments = {}
517 self.minimums = {}
518 self.maximums = {}
519 self.set_fields: dict = {}
520 self.empty_document = False
521
522 prefix_path = FieldPath()
523 iterator = self._get_document_iterator(prefix_path)
524
525 for field_path, value in iterator:
526 if field_path == prefix_path and value is _EmptyDict:
527 self.empty_document = True
528
529 elif value is transforms.DELETE_FIELD:
530 self.deleted_fields.append(field_path)
531
532 elif value is transforms.SERVER_TIMESTAMP:
533 self.server_timestamps.append(field_path)
534
535 elif isinstance(value, transforms.ArrayRemove):
536 self.array_removes[field_path] = value.values
537
538 elif isinstance(value, transforms.ArrayUnion):
539 self.array_unions[field_path] = value.values
540
541 elif isinstance(value, transforms.Increment):
542 self.increments[field_path] = value.value
543
544 elif isinstance(value, transforms.Maximum):
545 self.maximums[field_path] = value.value
546
547 elif isinstance(value, transforms.Minimum):
548 self.minimums[field_path] = value.value
549
550 else:
551 self.field_paths.append(field_path)
552 set_field_value(self.set_fields, field_path, value)
553
554 def _get_document_iterator(
555 self, prefix_path: FieldPath
556 ) -> Generator[Tuple[Any, Any], Any, None]:
557 return extract_fields(self.document_data, prefix_path)
558
559 @property
560 def has_transforms(self):
561 return bool(
562 self.server_timestamps
563 or self.array_removes
564 or self.array_unions
565 or self.increments
566 or self.maximums
567 or self.minimums
568 )
569
570 @property
571 def transform_paths(self):
572 return sorted(
573 self.server_timestamps
574 + list(self.array_removes)
575 + list(self.array_unions)
576 + list(self.increments)
577 + list(self.maximums)
578 + list(self.minimums)
579 )
580
581 def _get_update_mask(
582 self, allow_empty_mask=False
583 ) -> Optional[types.common.DocumentMask]:
584 return None
585
586 def get_update_pb(
587 self, document_path, exists=None, allow_empty_mask=False
588 ) -> types.write.Write:
589 if exists is not None:
590 current_document = common.Precondition(exists=exists)
591 else:
592 current_document = None
593
594 update_pb = write.Write(
595 update=document.Document(
596 name=document_path, fields=encode_dict(self.set_fields)
597 ),
598 update_mask=self._get_update_mask(allow_empty_mask),
599 current_document=current_document,
600 )
601
602 return update_pb
603
604 def get_field_transform_pbs(
605 self, document_path
606 ) -> List[types.write.DocumentTransform.FieldTransform]:
607 def make_array_value(values):
608 value_list = [encode_value(element) for element in values]
609 return document.ArrayValue(values=value_list)
610
611 path_field_transforms = (
612 [
613 (
614 path,
615 write.DocumentTransform.FieldTransform(
616 field_path=path.to_api_repr(),
617 set_to_server_value=REQUEST_TIME_ENUM,
618 ),
619 )
620 for path in self.server_timestamps
621 ]
622 + [
623 (
624 path,
625 write.DocumentTransform.FieldTransform(
626 field_path=path.to_api_repr(),
627 remove_all_from_array=make_array_value(values),
628 ),
629 )
630 for path, values in self.array_removes.items()
631 ]
632 + [
633 (
634 path,
635 write.DocumentTransform.FieldTransform(
636 field_path=path.to_api_repr(),
637 append_missing_elements=make_array_value(values),
638 ),
639 )
640 for path, values in self.array_unions.items()
641 ]
642 + [
643 (
644 path,
645 write.DocumentTransform.FieldTransform(
646 field_path=path.to_api_repr(), increment=encode_value(value)
647 ),
648 )
649 for path, value in self.increments.items()
650 ]
651 + [
652 (
653 path,
654 write.DocumentTransform.FieldTransform(
655 field_path=path.to_api_repr(), maximum=encode_value(value)
656 ),
657 )
658 for path, value in self.maximums.items()
659 ]
660 + [
661 (
662 path,
663 write.DocumentTransform.FieldTransform(
664 field_path=path.to_api_repr(), minimum=encode_value(value)
665 ),
666 )
667 for path, value in self.minimums.items()
668 ]
669 )
670 return [transform for path, transform in sorted(path_field_transforms)]
671
672 def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
673 field_transforms = self.get_field_transform_pbs(document_path)
674 transform_pb = write.Write(
675 transform=write.DocumentTransform(
676 document=document_path, field_transforms=field_transforms
677 )
678 )
679 if exists is not None:
680 transform_pb._pb.current_document.CopyFrom(
681 common.Precondition(exists=exists)._pb
682 )
683
684 return transform_pb
685
686
687def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
688 """Make ``Write`` protobufs for ``create()`` methods.
689
690 Args:
691 document_path (str): A fully-qualified document path.
692 document_data (dict): Property names and values to use for
693 creating a document.
694
695 Returns:
696 List[google.cloud.firestore_v1.types.Write]: One or two
697 ``Write`` protobuf instances for ``create()``.
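
    For example (a minimal sketch; the path is illustrative):

        pbs = pbs_for_create(
            "projects/p/databases/(default)/documents/users/alice",
            {"name": "Alice", "created": transforms.SERVER_TIMESTAMP},
        )
        # one Write with ``update`` set, an ``exists=False`` precondition,
        # and a server-timestamp entry in ``update_transforms``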
698 """
699 extractor = DocumentExtractor(document_data)
700
701 if extractor.deleted_fields:
702 raise ValueError("Cannot apply DELETE_FIELD in a create request.")
703
704 create_pb = extractor.get_update_pb(document_path, exists=False)
705
706 if extractor.has_transforms:
707 field_transform_pbs = extractor.get_field_transform_pbs(document_path)
708 create_pb.update_transforms.extend(field_transform_pbs)
709
710 return [create_pb]
711
712
713def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
714 """Make ``Write`` protobufs for ``set()`` methods.
715
716 Args:
717 document_path (str): A fully-qualified document path.
718 document_data (dict): Property names and values to use for
719 replacing a document.
720
721 Returns:
722 List[google.cloud.firestore_v1.types.Write]: One
723 or two ``Write`` protobuf instances for ``set()``.
724 """
725 extractor = DocumentExtractor(document_data)
726
727 if extractor.deleted_fields:
728 raise ValueError(
729 "Cannot apply DELETE_FIELD in a set request without "
730 "specifying 'merge=True' or 'merge=[field_paths]'."
731 )
732
733 set_pb = extractor.get_update_pb(document_path)
734
735 if extractor.has_transforms:
736 field_transform_pbs = extractor.get_field_transform_pbs(document_path)
737 set_pb.update_transforms.extend(field_transform_pbs)
738
739 return [set_pb]
740
741
742class DocumentExtractorForMerge(DocumentExtractor):
743 """Break document data up into actual data and transforms."""
744
745 def __init__(self, document_data) -> None:
746 super(DocumentExtractorForMerge, self).__init__(document_data)
747 self.data_merge: list = []
748 self.transform_merge: list = []
749 self.merge: list = []
750
751 def _apply_merge_all(self) -> None:
752 self.data_merge = sorted(self.field_paths + self.deleted_fields)
753 # TODO: other transforms
754 self.transform_merge = self.transform_paths
755 self.merge = sorted(self.data_merge + self.transform_paths)
756
757 def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]:
758 for merge_field in merge:
759 if isinstance(merge_field, FieldPath):
760 yield merge_field
761 else:
762 yield FieldPath(*parse_field_path(merge_field))
763
764 def _normalize_merge_paths(self, merge) -> list:
765 merge_paths = sorted(self._construct_merge_paths(merge))
766
767 # Raise if any merge path is a parent of another. Leverage sorting
768 # to avoid quadratic behavior.
769 for index in range(len(merge_paths) - 1):
770 lhs, rhs = merge_paths[index], merge_paths[index + 1]
771 if lhs.eq_or_parent(rhs):
772 raise ValueError("Merge paths overlap: {}, {}".format(lhs, rhs))
773
774 for merge_path in merge_paths:
775 if merge_path in self.deleted_fields:
776 continue
777 try:
778 get_field_value(self.document_data, merge_path)
779 except KeyError:
780 raise ValueError("Invalid merge path: {}".format(merge_path))
781
782 return merge_paths
783
784 def _apply_merge_paths(self, merge) -> None:
785 if self.empty_document:
786 raise ValueError("Cannot merge specific fields with empty document.")
787
788 merge_paths = self._normalize_merge_paths(merge)
789
790 del self.data_merge[:]
791 del self.transform_merge[:]
792 self.merge = merge_paths
793
794 for merge_path in merge_paths:
795 if merge_path in self.transform_paths:
796 self.transform_merge.append(merge_path)
797
798 for field_path in self.field_paths:
799 if merge_path.eq_or_parent(field_path):
800 self.data_merge.append(field_path)
801
802 # Clear out data for fields not merged.
803 merged_set_fields: dict = {}
804 for field_path in self.data_merge:
805 value = get_field_value(self.document_data, field_path)
806 set_field_value(merged_set_fields, field_path, value)
807 self.set_fields = merged_set_fields
808
809 unmerged_deleted_fields = [
810 field_path
811 for field_path in self.deleted_fields
812 if field_path not in self.merge
813 ]
814 if unmerged_deleted_fields:
815 raise ValueError(
816 "Cannot delete unmerged fields: {}".format(unmerged_deleted_fields)
817 )
818 self.data_merge = sorted(self.data_merge + self.deleted_fields)
819
820 # Keep only transforms which are within merge.
821 merged_transform_paths = set()
822 for merge_path in self.merge:
            transform_merge_paths = [
                transform_path
                for transform_path in self.transform_paths
                if merge_path.eq_or_parent(transform_path)
            ]
            merged_transform_paths.update(transform_merge_paths)

        self.server_timestamps = [
            path for path in self.server_timestamps if path in merged_transform_paths
        ]

        self.array_removes = {
            path: values
            for path, values in self.array_removes.items()
            if path in merged_transform_paths
        }

        self.array_unions = {
            path: values
            for path, values in self.array_unions.items()
            if path in merged_transform_paths
        }

    def apply_merge(self, merge) -> None:
        if merge is True:  # merge all fields
            self._apply_merge_all()
        else:
            self._apply_merge_paths(merge)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        # Mask uses dotted / quoted paths.
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.merge
            if field_path not in self.transform_merge
        ]

        return common.DocumentMask(field_paths=mask_paths)


def pbs_for_set_with_merge(
    document_path, document_data, merge
) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.
        merge (Union[bool, List[Union[str, FieldPath]]]):
            If :data:`True`, merge all fields; else, merge only the named fields.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: One
        or two ``Write`` protobuf instances for ``set()``.
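
    For example (a minimal sketch; ``doc_path`` is a fully-qualified
    document path):

        pbs = pbs_for_set_with_merge(doc_path, {"a": 1, "b": 2}, merge=["a"])
        # one Write whose update mask contains only "a"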
880 """
881 extractor = DocumentExtractorForMerge(document_data)
882 extractor.apply_merge(merge)
883
884 set_pb = extractor.get_update_pb(document_path)
885
886 if extractor.transform_paths:
887 field_transform_pbs = extractor.get_field_transform_pbs(document_path)
888 set_pb.update_transforms.extend(field_transform_pbs)
889
890 return [set_pb]
891
892
893class DocumentExtractorForUpdate(DocumentExtractor):
894 """Break document data up into actual data and transforms."""
895
896 def __init__(self, document_data) -> None:
897 super(DocumentExtractorForUpdate, self).__init__(document_data)
898 self.top_level_paths = sorted(
899 [FieldPath.from_string(key) for key in document_data]
900 )
901 tops = set(self.top_level_paths)
902 for top_level_path in self.top_level_paths:
903 for ancestor in top_level_path.lineage():
904 if ancestor in tops:
905 raise ValueError(
906 "Conflicting field path: {}, {}".format(
907 top_level_path, ancestor
908 )
909 )
910
911 for field_path in self.deleted_fields:
912 if field_path not in tops:
                raise ValueError(
                    "Cannot update with nested delete: {}".format(field_path)
                )

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path, expand_dots=True)

    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
        mask_paths = []
        for field_path in self.top_level_paths:
            if field_path not in self.transform_paths:
                mask_paths.append(field_path.to_api_repr())

        return common.DocumentMask(field_paths=mask_paths)


def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``update()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        field_updates (dict): Field names or paths to update and values
            to update with.
        option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: One
        or two ``Write`` protobuf instances for ``update()``.
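
    For example (a minimal sketch; ``doc_path`` is a fully-qualified path,
    and dotted keys expand to nested fields):

        pbs = pbs_for_update(
            doc_path, {"stats.likes": transforms.Increment(1)}, None
        )
        # one Write with an ``exists=True`` precondition and an increment
        # transform for the field path "stats.likes"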
945 """
946 extractor = DocumentExtractorForUpdate(field_updates)
947
948 if extractor.empty_document:
949 raise ValueError("Cannot update with an empty document.")
950
951 if option is None: # Default is to use ``exists=True``.
952 option = ExistsOption(exists=True)
953
954 update_pb = extractor.get_update_pb(document_path)
955 option.modify_write(update_pb)
956
957 if extractor.has_transforms:
958 field_transform_pbs = extractor.get_field_transform_pbs(document_path)
959 update_pb.update_transforms.extend(field_transform_pbs)
960
961 return [update_pb]
962
963
964def pb_for_delete(document_path, option) -> types.write.Write:
965 """Make a ``Write`` protobuf for ``delete()`` methods.
966
967 Args:
968 document_path (str): A fully-qualified document path.
969 option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
970 A write option to make assertions / preconditions on the server
971 state of the document before applying changes.
972
973 Returns:
974 google.cloud.firestore_v1.types.Write: A
975 ``Write`` protobuf instance for the ``delete()``.
976 """
977 write_pb = write.Write(delete=document_path)
978 if option is not None:
979 option.modify_write(write_pb)
980
981 return write_pb
982
983
984class ReadAfterWriteError(Exception):
985 """Raised when a read is attempted after a write.
986
987 Raised by "read" methods that use transactions.
988 """
989
990
991def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
992 """Get the transaction ID from a ``Transaction`` object.
993
994 Args:
995 transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
996 Transaction`]):
997 An existing transaction that this query will run in.
998 read_operation (Optional[bool]): Indicates if the transaction ID
999 will be used in a read operation. Defaults to :data:`True`.
1000
1001 Returns:
1002 Optional[bytes]: The ID of the transaction, or :data:`None` if the
1003 ``transaction`` is :data:`None`.
1004
1005 Raises:
1006 ValueError: If the ``transaction`` is not in progress (only if
1007 ``transaction`` is not :data:`None`).
1008 ReadAfterWriteError: If the ``transaction`` has writes stored on
1009 it and ``read_operation`` is :data:`True`.
1010 """
1011 if transaction is None:
1012 return None
1013 else:
1014 if not transaction.in_progress:
1015 raise ValueError(INACTIVE_TXN)
1016 if read_operation and len(transaction._write_pbs) > 0:
1017 raise ReadAfterWriteError(READ_AFTER_WRITE_ERROR)
1018 return transaction.id
1019
1020
1021def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
1022 """Create RPC metadata containing a prefix.
1023
1024 Args:
1025 prefix (str): appropriate resource path.
1026
1027 Returns:
1028 List[Tuple[str, str]]: RPC metadata with supplied prefix
1029 """
1030 return [("google-cloud-resource-prefix", prefix)]
1031
1032
1033class WriteOption(object):
1034 """Option used to assert a condition on a write operation."""
1035
1036 def modify_write(self, write, no_create_msg=None) -> None:
1037 """Modify a ``Write`` protobuf based on the state of this write option.
1038
1039 This is a virtual method intended to be implemented by subclasses.
1040
1041 Args:
1042 write (google.cloud.firestore_v1.types.Write): A
1043 ``Write`` protobuf instance to be modified with a precondition
1044 determined by the state of this option.
1045 no_create_msg (Optional[str]): A message to use to indicate that
1046 a create operation is not allowed.
1047
1048 Raises:
1049 NotImplementedError: Always, this method is virtual.
1050 """
1051 raise NotImplementedError
1052
1053
1054class LastUpdateOption(WriteOption):
1055 """Option used to assert a "last update" condition on a write operation.
1056
1057 This will typically be created by
1058 :meth:`~google.cloud.firestore_v1.client.Client.write_option`.
1059
1060 Args:
1061 last_update_time (google.protobuf.timestamp_pb2.Timestamp): A
1062 timestamp. When set, the target document must exist and have
1063 been last updated at that time. Protobuf ``update_time`` timestamps
1064 are typically returned from methods that perform write operations
1065 as part of a "write result" protobuf or directly.
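
    Usage (a minimal sketch; the names are illustrative):

        write_result = document_ref.set({"a": 1})
        option = client.write_option(last_update_time=write_result.update_time)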
1066 """
1067
1068 def __init__(self, last_update_time) -> None:
1069 self._last_update_time = last_update_time
1070
1071 def __eq__(self, other):
1072 if not isinstance(other, self.__class__):
1073 return NotImplemented
1074 return self._last_update_time == other._last_update_time
1075
1076 def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
1077 """Modify a ``Write`` protobuf based on the state of this write option.
1078
        The ``last_update_time`` is added to ``write`` as an "update time"
        precondition. When set, the target document must exist and have been
        last updated at that time.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
1089 """
1090 current_doc = types.Precondition(update_time=self._last_update_time)
1091 write._pb.current_document.CopyFrom(current_doc._pb)
1092
1093
1094class ExistsOption(WriteOption):
1095 """Option used to assert existence on a write operation.
1096
1097 This will typically be created by
1098 :meth:`~google.cloud.firestore_v1.client.Client.write_option`.
1099
1100 Args:
1101 exists (bool): Indicates if the document being modified
1102 should already exist.
1103 """
1104
1105 def __init__(self, exists) -> None:
1106 self._exists = exists
1107
1108 def __eq__(self, other):
1109 if not isinstance(other, self.__class__):
1110 return NotImplemented
1111 return self._exists == other._exists
1112
1113 def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
1114 """Modify a ``Write`` protobuf based on the state of this write option.
1115
1116 If:
1117
1118 * ``exists=True``, adds a precondition that requires existence
1119 * ``exists=False``, adds a precondition that requires non-existence
1120
1121 Args:
1122 write (google.cloud.firestore_v1.types.Write): A
1123 ``Write`` protobuf instance to be modified with a precondition
1124 determined by the state of this option.
1125 unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
1126 other subclasses that are unused here.
1127 """
1128 current_doc = types.Precondition(exists=self._exists)
1129 write._pb.current_document.CopyFrom(current_doc._pb)
1130
1131
1132def make_retry_timeout_kwargs(
1133 retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None
1134) -> dict:
1135 """Helper fo API methods which take optional 'retry' / 'timeout' args."""
    kwargs = {}

    if retry is not gapic_v1.method.DEFAULT:
        kwargs["retry"] = retry

    if timeout is not None:
        kwargs["timeout"] = timeout

    return kwargs


def build_timestamp(
    dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None
) -> Timestamp:
    """Returns the supplied datetime (or "now") as a Timestamp"""
    return _datetime_to_pb_timestamp(
        dt or DatetimeWithNanoseconds.now(tz=datetime.timezone.utc)
    )


def compare_timestamps(
    ts1: Union[Timestamp, datetime.datetime],
    ts2: Union[Timestamp, datetime.datetime],
) -> int:
    """Compare two timestamps at nanosecond granularity.

    Returns -1 if ``ts1`` is earlier, 1 if ``ts1`` is later, and 0 if equal.
    """
    ts1 = build_timestamp(ts1) if not isinstance(ts1, Timestamp) else ts1
    ts2 = build_timestamp(ts2) if not isinstance(ts2, Timestamp) else ts2
    # Use integer arithmetic: float math (``seconds * 1e9``) would lose the
    # nanos component for present-day timestamps.
    ts1_nanos = ts1.nanos + ts1.seconds * 10**9
    ts2_nanos = ts2.nanos + ts2.seconds * 10**9
    if ts1_nanos == ts2_nanos:
        return 0
    return 1 if ts1_nanos > ts2_nanos else -1


def deserialize_bundle(
    serialized: Union[str, bytes],
    client: "google.cloud.firestore_v1.client.BaseClient",
) -> "google.cloud.firestore_bundle.FirestoreBundle":
    """Inverse operation to a `FirestoreBundle` instance's `build()` method.

    Args:
        serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`.
            Should be a list of dictionaries in string format.
        client (BaseClient): A connected Client instance.

    Returns:
        FirestoreBundle: A bundle equivalent to that which called `build()` and
        initially created the `serialized` value.

    Raises:
        ValueError: If any of the dictionaries in the list contain any more than
            one top-level key.
        ValueError: If any unexpected BundleElement types are encountered.
        ValueError: If the serialized bundle ends before expected.
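
    Usage (a minimal sketch):

        serialized = bundle.build()
        bundle_copy = deserialize_bundle(serialized, client)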
1189 """
1190 from google.cloud.firestore_bundle import BundleElement, FirestoreBundle
1191
1192 # Outlines the legal transitions from one BundleElement to another.
1193 bundle_state_machine = {
1194 "__initial__": ["metadata"],
1195 "metadata": ["namedQuery", "documentMetadata", "__end__"],
1196 "namedQuery": ["namedQuery", "documentMetadata", "__end__"],
1197 "documentMetadata": ["document"],
1198 "document": ["documentMetadata", "__end__"],
1199 }
1200 allowed_next_element_types: List[str] = bundle_state_machine["__initial__"]
1201
1202 # This must be saved and added last, since we cache it to preserve timestamps,
1203 # yet must flush it whenever a new document or query is added to a bundle.
1204 # The process of deserializing a bundle uses these methods which flush a
1205 # cached metadata element, and thus, it must be the last BundleElement
1206 # added during deserialization.
1207 metadata_bundle_element: Optional[BundleElement] = None
1208
1209 bundle: Optional[FirestoreBundle] = None
1210 data: Dict
1211 for data in _parse_bundle_elements_data(serialized):
1212 # BundleElements are serialized as JSON containing one key outlining
1213 # the type, with all further data nested under that key
1214 keys: List[str] = list(data.keys())
1215
1216 if len(keys) != 1:
1217 raise ValueError("Expected serialized BundleElement with one top-level key")
1218
1219 key: str = keys[0]
1220
1221 if key not in allowed_next_element_types:
1222 raise ValueError(
1223 f"Encountered BundleElement of type {key}. "
1224 f"Expected one of {allowed_next_element_types}"
1225 )
1226
1227 # Create and add our BundleElement
1228 bundle_element: BundleElement
1229 try:
1230 bundle_element = BundleElement.from_json(json.dumps(data))
1231 except AttributeError as e:
1232 # Some bad serialization formats cannot be universally deserialized.
1233 if e.args[0] == "'dict' object has no attribute 'find'": # pragma: NO COVER
1234 raise ValueError(
1235 "Invalid serialization of datetimes. "
1236 "Cannot deserialize Bundles created from the NodeJS SDK."
1237 )
1238 raise e # pragma: NO COVER
1239
1240 if bundle is None:
1241 # This must be the first bundle type encountered
1242 assert key == "metadata"
1243 bundle = FirestoreBundle(data[key]["id"])
1244 metadata_bundle_element = bundle_element
1245
1246 else:
1247 bundle._add_bundle_element(bundle_element, client=client, type=key)
1248
1249 # Update the allowed next BundleElement types
1250 allowed_next_element_types = bundle_state_machine[key]
1251
1252 if "__end__" not in allowed_next_element_types:
1253 raise ValueError("Unexpected end to serialized FirestoreBundle")
1254 # state machine guarantees bundle and metadata have been populated
1255 bundle = cast(FirestoreBundle, bundle)
1256 metadata_bundle_element = cast(BundleElement, metadata_bundle_element)
1257 # Now, finally add the metadata element
1258 bundle._add_bundle_element(
1259 metadata_bundle_element,
1260 client=client,
1261 type="metadata",
1262 )
1263
1264 return bundle
1265
1266
1267def _parse_bundle_elements_data(
1268 serialized: Union[str, bytes]
1269) -> Generator[Dict, None, None]:
1270 """Reads through a serialized FirestoreBundle and yields JSON chunks that
1271 were created via `BundleElement.to_json(bundle_element)`.
1272
1273 Serialized FirestoreBundle instances are length-prefixed JSON objects, and
1274 so are of the form "123{...}57{...}"
1275 To correctly and safely read a bundle, we must first detect these length
1276 prefixes, read that many bytes of data, and attempt to JSON-parse that.
1277
1278 Raises:
1279 ValueError: If a chunk of JSON ever starts without following a length
1280 prefix.
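
    For example (a minimal sketch):

        list(_parse_bundle_elements_data('16{"metadata": {}}'))
        # -> [{"metadata": {}}]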
1281 """
1282 _serialized: Iterator[int] = iter(
1283 serialized if isinstance(serialized, bytes) else serialized.encode("utf-8")
1284 )
1285
1286 length_prefix: str = ""
1287 while True:
1288 byte: Optional[int] = next(_serialized, None)
1289
1290 if byte is None:
1291 return None
1292
1293 _str: str = chr(byte)
1294 if _str.isnumeric():
1295 length_prefix += _str
1296 else:
1297 if length_prefix == "":
1298 raise ValueError("Expected length prefix")
1299
1300 _length_prefix = int(length_prefix)
1301 length_prefix = ""
1302 _bytes = bytearray([byte])
1303 _counter = 1
1304 while _counter < _length_prefix:
1305 _bytes.append(next(_serialized))
1306 _counter += 1
1307
1308 yield json.loads(_bytes.decode("utf-8"))
1309
1310
1311def _get_documents_from_bundle(
1312 bundle, *, query_name: Optional[str] = None
1313) -> Generator["DocumentSnapshot", None, None]:
1314 from google.cloud.firestore_bundle.bundle import _BundledDocument
1315
1316 bundled_doc: _BundledDocument
1317 for bundled_doc in bundle.documents.values():
1318 if query_name and query_name not in bundled_doc.metadata.queries:
1319 continue
1320 yield bundled_doc.snapshot
1321
1322
1323def _get_document_from_bundle(
1324 bundle,
1325 *,
1326 document_id: str,
1327) -> Optional["DocumentSnapshot"]:
1328 bundled_doc = bundle.documents.get(document_id)
1329 if bundled_doc:
1330 return bundled_doc.snapshot
1331 else:
1332 return None