# Copyright 2017 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common helpers shared across Google Cloud Firestore modules."""

import datetime
import json

import google
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.api_core import gapic_v1
from google.protobuf import struct_pb2
from google.type import latlng_pb2  # type: ignore
import grpc  # type: ignore

from google.cloud import exceptions  # type: ignore
from google.cloud._helpers import _datetime_to_pb_timestamp  # type: ignore
from google.cloud.firestore_v1.types.write import DocumentTransform
from google.cloud.firestore_v1 import transforms
from google.cloud.firestore_v1 import types
from google.cloud.firestore_v1.field_path import FieldPath
from google.cloud.firestore_v1.field_path import parse_field_path
from google.cloud.firestore_v1.types import common
from google.cloud.firestore_v1.types import document
from google.cloud.firestore_v1.types import write
from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    NoReturn,
    Optional,
    Tuple,
    Union,
)

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict


BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
DOCUMENT_PATH_DELIMITER = "/"
INACTIVE_TXN = "Transaction not in progress, cannot be used in API requests."
READ_AFTER_WRITE_ERROR = "Attempted read after write in a transaction."
BAD_REFERENCE_ERROR = (
    "Reference value {!r} in unexpected format, expected to be of the form "
    "``projects/{{project}}/databases/{{database}}/"
    "documents/{{document_path}}``."
)
WRONG_APP_REFERENCE = (
    "Document {!r} does not correspond to the same database ({!r}) as the client."
)
REQUEST_TIME_ENUM = DocumentTransform.FieldTransform.ServerValue.REQUEST_TIME
_GRPC_ERROR_MAPPING = {
    grpc.StatusCode.ALREADY_EXISTS: exceptions.Conflict,
    grpc.StatusCode.NOT_FOUND: exceptions.NotFound,
}


class GeoPoint(object):
    """Simple container for a geo point value.

    Args:
        latitude (float): Latitude of a point.
        longitude (float): Longitude of a point.
    """

    def __init__(self, latitude, longitude) -> None:
        self.latitude = latitude
        self.longitude = longitude

    def to_protobuf(self) -> latlng_pb2.LatLng:
        """Convert the current object to protobuf.

        Returns:
            google.type.latlng_pb2.LatLng: The current point as a protobuf.
        """
        return latlng_pb2.LatLng(latitude=self.latitude, longitude=self.longitude)

    def __eq__(self, other):
        """Compare two geo points for equality.

        Returns:
            Union[bool, NotImplemented]: :data:`True` if the points compare
            equal, else :data:`False`. (Or :data:`NotImplemented` if
            ``other`` is not a geo point.)
        """
        if not isinstance(other, GeoPoint):
            return NotImplemented

        return self.latitude == other.latitude and self.longitude == other.longitude

    def __ne__(self, other):
        """Compare two geo points for inequality.

        Returns:
            Union[bool, NotImplemented]: :data:`False` if the points compare
            equal, else :data:`True`. (Or :data:`NotImplemented` if
            ``other`` is not a geo point.)
        """
        equality_val = self.__eq__(other)
        if equality_val is NotImplemented:
            return NotImplemented
        else:
            return not equality_val
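

# Illustrative usage sketch (not part of the upstream module; the ``_demo_``
# helper name is hypothetical): ``GeoPoint`` equality is value-based, and
# comparing against a non-GeoPoint returns ``NotImplemented`` so Python can
# fall back to the reflected operation.
def _demo_geo_point() -> None:
    tokyo = GeoPoint(35.6764, 139.6500)
    assert tokyo == GeoPoint(35.6764, 139.6500)  # compares by coordinates
    assert tokyo != GeoPoint(0.0, 0.0)
    assert (tokyo == "not-a-point") is False  # NotImplemented falls back to False
    assert tokyo.to_protobuf().latitude == 35.6764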


def verify_path(path, is_collection) -> None:
    """Verifies that a ``path`` has the correct form.

    Checks that all of the elements in ``path`` are strings.

    Args:
        path (Tuple[str, ...]): The components in a collection or
            document path.
        is_collection (bool): Indicates if the ``path`` represents
            a document or a collection.

    Raises:
        ValueError: if

            * the ``path`` is empty
            * ``is_collection=True`` and there are an even number of elements
            * ``is_collection=False`` and there are an odd number of elements
            * an element is not a string
    """
    num_elements = len(path)
    if num_elements == 0:
        raise ValueError("Document or collection path cannot be empty")

    if is_collection:
        if num_elements % 2 == 0:
            raise ValueError("A collection must have an odd number of path elements")

    else:
        if num_elements % 2 == 1:
            raise ValueError("A document must have an even number of path elements")

    for element in path:
        if not isinstance(element, str):
            msg = BAD_PATH_TEMPLATE.format(element, type(element))
            raise ValueError(msg)
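

# Illustrative sketch (hypothetical helper, not upstream code): collection
# paths have an odd number of elements, document paths an even number.
def _demo_verify_path() -> None:
    verify_path(("users",), is_collection=True)  # ok: one element
    verify_path(("users", "alice"), is_collection=False)  # ok: two elements
    try:
        verify_path(("users", "alice"), is_collection=True)
    except ValueError:
        pass  # a collection path cannot have an even number of elements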


def encode_value(value) -> types.document.Value:
    """Converts a native Python value into a Firestore protobuf ``Value``.

    Args:
        value (Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint]): A native
            Python value to convert to a protobuf field.

    Returns:
        ~google.cloud.firestore_v1.types.Value: A
        value encoded as a Firestore protobuf.

    Raises:
        TypeError: If the ``value`` is not one of the accepted types.
    """
    if value is None:
        return document.Value(null_value=struct_pb2.NULL_VALUE)

    # Must come before int since ``bool`` is an integer subtype.
    if isinstance(value, bool):
        return document.Value(boolean_value=value)

    if isinstance(value, int):
        return document.Value(integer_value=value)

    if isinstance(value, float):
        return document.Value(double_value=value)

    if isinstance(value, DatetimeWithNanoseconds):
        return document.Value(timestamp_value=value.timestamp_pb())

    if isinstance(value, datetime.datetime):
        return document.Value(timestamp_value=_datetime_to_pb_timestamp(value))

    if isinstance(value, str):
        return document.Value(string_value=value)

    if isinstance(value, bytes):
        return document.Value(bytes_value=value)

    # NOTE: We avoid doing an isinstance() check for a Document
    #       here to avoid import cycles.
    document_path = getattr(value, "_document_path", None)
    if document_path is not None:
        return document.Value(reference_value=document_path)

    if isinstance(value, GeoPoint):
        return document.Value(geo_point_value=value.to_protobuf())

    if isinstance(value, (list, tuple, set, frozenset)):
        value_list = tuple(encode_value(element) for element in value)
        value_pb = document.ArrayValue(values=value_list)
        return document.Value(array_value=value_pb)

    if isinstance(value, dict):
        value_dict = encode_dict(value)
        value_pb = document.MapValue(fields=value_dict)
        return document.Value(map_value=value_pb)

    raise TypeError(
        "Cannot convert to a Firestore Value", value, "Invalid type", type(value)
    )


def encode_dict(values_dict) -> dict:
    """Encode a dictionary into protobuf ``Value``-s.

    Args:
        values_dict (dict): The dictionary to encode as protobuf fields.

    Returns:
        Dict[str, ~google.cloud.firestore_v1.types.Value]: A
        dictionary of string keys and ``Value`` protobufs as dictionary
        values.
    """
    return {key: encode_value(value) for key, value in values_dict.items()}
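

# Illustrative sketch (hypothetical helper): nested native values become
# nested protobuf ``Value`` messages, with ``bool`` dispatched before ``int``.
def _demo_encode() -> None:
    fields = encode_dict({"active": True, "count": 7, "tags": ["a", "b"]})
    assert fields["active"].boolean_value is True  # not integer_value
    assert fields["count"].integer_value == 7
    assert fields["tags"].array_value.values[0].string_value == "a"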


def document_snapshot_to_protobuf(
    snapshot: "google.cloud.firestore_v1.base_document.DocumentSnapshot",
) -> Optional["google.cloud.firestore_v1.types.Document"]:  # type: ignore
    from google.cloud.firestore_v1.types import Document

    if not snapshot.exists:
        return None

    return Document(
        name=snapshot.reference._document_path,
        fields=encode_dict(snapshot._data),
        create_time=snapshot.create_time,
        update_time=snapshot.update_time,
    )


class DocumentReferenceValue:
    """DocumentReference path container with accessors for each relevant chunk.

    Usage:
        doc_ref_val = DocumentReferenceValue(
            'projects/my-proj/databases/(default)/documents/my-col/my-doc',
        )
        assert doc_ref_val.project_name == 'my-proj'
        assert doc_ref_val.collection_name == 'my-col'
        assert doc_ref_val.document_id == 'my-doc'
        assert doc_ref_val.database_name == '(default)'

    Raises:
        ValueError: If the supplied value cannot satisfy a complete path.
    """

    def __init__(self, reference_value: str):
        self._reference_value = reference_value

        # The first 5 parts are
        # projects, {project}, databases, {database}, documents
        parts = reference_value.split(DOCUMENT_PATH_DELIMITER)
        if len(parts) < 7:
            msg = BAD_REFERENCE_ERROR.format(reference_value)
            raise ValueError(msg)

        self.project_name = parts[1]
        self.collection_name = parts[5]
        self.database_name = parts[3]
        self.document_id = "/".join(parts[6:])

    @property
    def full_key(self) -> str:
        """Computed property for a DocumentReference's collection_name and
        document Id"""
        return "/".join([self.collection_name, self.document_id])

    @property
    def full_path(self) -> str:
        return self._reference_value or "/".join(
            [
                "projects",
                self.project_name,
                "databases",
                self.database_name,
                "documents",
                self.collection_name,
                self.document_id,
            ]
        )


def reference_value_to_document(reference_value, client) -> Any:
    """Convert a reference value string to a document.

    Args:
        reference_value (str): A document reference value.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        :class:`~google.cloud.firestore_v1.document.DocumentReference`:
            The document corresponding to ``reference_value``.

    Raises:
        ValueError: If the ``reference_value`` is not of the expected
            format: ``projects/{project}/databases/{database}/documents/...``.
        ValueError: If the ``reference_value`` does not come from the same
            project / database combination as the ``client``.
    """
    from google.cloud.firestore_v1.base_document import BaseDocumentReference

    doc_ref_value = DocumentReferenceValue(reference_value)

    document: BaseDocumentReference = client.document(doc_ref_value.full_key)
    if document._document_path != reference_value:
        msg = WRONG_APP_REFERENCE.format(reference_value, client._database_string)
        raise ValueError(msg)

    return document


def decode_value(
    value, client
) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]:
    """Converts a Firestore protobuf ``Value`` to a native Python value.

    Args:
        value (google.cloud.firestore_v1.types.Value): A
            Firestore protobuf to be decoded / parsed / converted.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint]: A native
            Python value converted from the ``value``.

    Raises:
        NotImplementedError: If the ``value_type`` is ``reference_value``.
        ValueError: If the ``value_type`` is unknown.
    """
    value_pb = getattr(value, "_pb", value)
    value_type = value_pb.WhichOneof("value_type")

    if value_type == "null_value":
        return None
    elif value_type == "boolean_value":
        return value_pb.boolean_value
    elif value_type == "integer_value":
        return value_pb.integer_value
    elif value_type == "double_value":
        return value_pb.double_value
    elif value_type == "timestamp_value":
        return DatetimeWithNanoseconds.from_timestamp_pb(value_pb.timestamp_value)
    elif value_type == "string_value":
        return value_pb.string_value
    elif value_type == "bytes_value":
        return value_pb.bytes_value
    elif value_type == "reference_value":
        return reference_value_to_document(value_pb.reference_value, client)
    elif value_type == "geo_point_value":
        return GeoPoint(
            value_pb.geo_point_value.latitude, value_pb.geo_point_value.longitude
        )
    elif value_type == "array_value":
        return [
            decode_value(element, client) for element in value_pb.array_value.values
        ]
    elif value_type == "map_value":
        return decode_dict(value_pb.map_value.fields, client)
    else:
        raise ValueError("Unknown ``value_type``", value_type)


def decode_dict(value_fields, client) -> dict:
    """Converts a protobuf map of Firestore ``Value``-s.

    Args:
        value_fields (google.protobuf.pyext._message.MessageMapContainer): A
            protobuf map of Firestore ``Value``-s.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Dict[str, Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint]]: A dictionary
            of native Python values converted from the ``value_fields``.
    """
    value_fields_pb = getattr(value_fields, "_pb", value_fields)

    return {key: decode_value(value, client) for key, value in value_fields_pb.items()}
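

# Illustrative sketch (hypothetical helper): for non-reference values the
# ``client`` argument is never touched, so ``None`` suffices here to show
# that ``decode_dict`` inverts ``encode_dict``.
def _demo_decode_round_trip() -> None:
    original = {"name": "ada", "score": 3.5, "nested": {"ok": True}}
    decoded = decode_dict(encode_dict(original), client=None)
    assert decoded == original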


def get_doc_id(document_pb, expected_prefix) -> str:
    """Parse a document ID from a document protobuf.

    Args:
        document_pb (google.cloud.proto.firestore.v1.\
            document.Document): A protobuf for a document that
            was created in a ``CreateDocument`` RPC.
        expected_prefix (str): The expected collection prefix for the
            fully-qualified document name.

    Returns:
        str: The document ID from the protobuf.

    Raises:
        ValueError: If the name does not begin with the prefix.
    """
    prefix, document_id = document_pb.name.rsplit(DOCUMENT_PATH_DELIMITER, 1)
    if prefix != expected_prefix:
        raise ValueError(
            "Unexpected document name",
            document_pb.name,
            "Expected to begin with",
            expected_prefix,
        )

    return document_id
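

# Illustrative sketch (hypothetical helper): the ID is whatever follows the
# last "/" once the expected collection prefix has been matched.
def _demo_get_doc_id() -> None:
    prefix = "projects/p/databases/(default)/documents/users"
    doc_pb = document.Document(name=prefix + "/alice")
    assert get_doc_id(doc_pb, prefix) == "alice"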


_EmptyDict = transforms.Sentinel("Marker for an empty dict value")


def extract_fields(
    document_data, prefix_path: FieldPath, expand_dots=False
) -> Generator[Tuple[Any, Any], Any, None]:
    """Do depth-first walk of tree, yielding field_path, value"""
    if not document_data:
        yield prefix_path, _EmptyDict
    else:
        for key, value in sorted(document_data.items()):
            if expand_dots:
                sub_key = FieldPath.from_string(key)
            else:
                sub_key = FieldPath(key)

            field_path = FieldPath(*(prefix_path.parts + sub_key.parts))

            if isinstance(value, dict):
                for s_path, s_value in extract_fields(value, field_path):
                    yield s_path, s_value
            else:
                yield field_path, value


def set_field_value(document_data, field_path, value) -> None:
    """Set a value into a document for a field_path"""
    current = document_data
    for element in field_path.parts[:-1]:
        current = current.setdefault(element, {})
    if value is _EmptyDict:
        value = {}
    current[field_path.parts[-1]] = value


def get_field_value(document_data, field_path) -> Any:
    if not field_path.parts:
        raise ValueError("Empty path")

    current = document_data
    for element in field_path.parts[:-1]:
        current = current[element]
    return current[field_path.parts[-1]]
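

# Illustrative sketch (hypothetical helper): ``extract_fields`` walks a
# nested dict depth-first, yielding (FieldPath, value) pairs, while
# ``set_field_value`` / ``get_field_value`` invert that flattening.
def _demo_field_walk() -> None:
    data = {"a": {"b": 1}, "c": 2}
    pairs = dict(extract_fields(data, FieldPath()))
    assert pairs[FieldPath("a", "b")] == 1
    assert pairs[FieldPath("c")] == 2

    rebuilt: dict = {}
    for path, value in pairs.items():
        set_field_value(rebuilt, path, value)
    assert get_field_value(rebuilt, FieldPath("a", "b")) == 1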


class DocumentExtractor(object):
    """Break document data up into actual data and transforms.

    Handle special values such as ``DELETE_FIELD``, ``SERVER_TIMESTAMP``.

    Args:
        document_data (dict):
            Property names and values to use for sending a change to
            a document.
    """

    def __init__(self, document_data) -> None:
        self.document_data = document_data
        self.field_paths = []
        self.deleted_fields = []
        self.server_timestamps = []
        self.array_removes = {}
        self.array_unions = {}
        self.increments = {}
        self.minimums = {}
        self.maximums = {}
        self.set_fields = {}
        self.empty_document = False

        prefix_path = FieldPath()
        iterator = self._get_document_iterator(prefix_path)

        for field_path, value in iterator:
            if field_path == prefix_path and value is _EmptyDict:
                self.empty_document = True

            elif value is transforms.DELETE_FIELD:
                self.deleted_fields.append(field_path)

            elif value is transforms.SERVER_TIMESTAMP:
                self.server_timestamps.append(field_path)

            elif isinstance(value, transforms.ArrayRemove):
                self.array_removes[field_path] = value.values

            elif isinstance(value, transforms.ArrayUnion):
                self.array_unions[field_path] = value.values

            elif isinstance(value, transforms.Increment):
                self.increments[field_path] = value.value

            elif isinstance(value, transforms.Maximum):
                self.maximums[field_path] = value.value

            elif isinstance(value, transforms.Minimum):
                self.minimums[field_path] = value.value

            else:
                self.field_paths.append(field_path)
                set_field_value(self.set_fields, field_path, value)

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path)

    @property
    def has_transforms(self):
        return bool(
            self.server_timestamps
            or self.array_removes
            or self.array_unions
            or self.increments
            or self.maximums
            or self.minimums
        )

    @property
    def transform_paths(self):
        return sorted(
            self.server_timestamps
            + list(self.array_removes)
            + list(self.array_unions)
            + list(self.increments)
            + list(self.maximums)
            + list(self.minimums)
        )

    def _get_update_mask(self, allow_empty_mask=False) -> None:
        return None

    def get_update_pb(
        self, document_path, exists=None, allow_empty_mask=False
    ) -> types.write.Write:
        if exists is not None:
            current_document = common.Precondition(exists=exists)
        else:
            current_document = None

        update_pb = write.Write(
            update=document.Document(
                name=document_path, fields=encode_dict(self.set_fields)
            ),
            update_mask=self._get_update_mask(allow_empty_mask),
            current_document=current_document,
        )

        return update_pb

    def get_field_transform_pbs(
        self, document_path
    ) -> List[types.write.DocumentTransform.FieldTransform]:
        def make_array_value(values):
            value_list = [encode_value(element) for element in values]
            return document.ArrayValue(values=value_list)

        path_field_transforms = (
            [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        set_to_server_value=REQUEST_TIME_ENUM,
                    ),
                )
                for path in self.server_timestamps
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        remove_all_from_array=make_array_value(values),
                    ),
                )
                for path, values in self.array_removes.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        append_missing_elements=make_array_value(values),
                    ),
                )
                for path, values in self.array_unions.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), increment=encode_value(value)
                    ),
                )
                for path, value in self.increments.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), maximum=encode_value(value)
                    ),
                )
                for path, value in self.maximums.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), minimum=encode_value(value)
                    ),
                )
                for path, value in self.minimums.items()
            ]
        )
        return [transform for path, transform in sorted(path_field_transforms)]

    def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
        field_transforms = self.get_field_transform_pbs(document_path)
        transform_pb = write.Write(
            transform=write.DocumentTransform(
                document=document_path, field_transforms=field_transforms
            )
        )
        if exists is not None:
            transform_pb._pb.current_document.CopyFrom(
                common.Precondition(exists=exists)._pb
            )

        return transform_pb
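

# Illustrative sketch (hypothetical helper): plain values land in
# ``set_fields`` while sentinel values are routed to their transform buckets.
def _demo_document_extractor() -> None:
    extractor = DocumentExtractor(
        {
            "name": "ada",
            "updated": transforms.SERVER_TIMESTAMP,
            "tags": transforms.ArrayUnion(["new"]),
        }
    )
    assert extractor.set_fields == {"name": "ada"}
    assert extractor.has_transforms
    assert extractor.server_timestamps == [FieldPath("updated")]
    assert FieldPath("tags") in extractor.array_unions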


def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``create()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            creating a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: One or two
        ``Write`` protobuf instances for ``create()``.
    """
    extractor = DocumentExtractor(document_data)

    if extractor.deleted_fields:
        raise ValueError("Cannot apply DELETE_FIELD in a create request.")

    create_pb = extractor.get_update_pb(document_path, exists=False)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        create_pb.update_transforms.extend(field_transform_pbs)

    return [create_pb]


def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: One
        or two ``Write`` protobuf instances for ``set()``.
    """
    extractor = DocumentExtractor(document_data)

    if extractor.deleted_fields:
        raise ValueError(
            "Cannot apply DELETE_FIELD in a set request without "
            "specifying 'merge=True' or 'merge=[field_paths]'."
        )

    set_pb = extractor.get_update_pb(document_path)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        set_pb.update_transforms.extend(field_transform_pbs)

    return [set_pb]
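

# Illustrative sketch (hypothetical helper): since Firestore v1 folded
# transforms into ``Write.update_transforms``, both ``pbs_for_create`` and
# ``pbs_for_set_no_merge`` return a single ``Write`` even when transforms
# are present.
def _demo_pbs_for_create() -> None:
    doc_path = "projects/p/databases/(default)/documents/users/alice"
    (write_pb,) = pbs_for_create(
        doc_path, {"name": "ada", "updated": transforms.SERVER_TIMESTAMP}
    )
    assert write_pb.current_document.exists is False  # create precondition
    assert len(write_pb.update_transforms) == 1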


class DocumentExtractorForMerge(DocumentExtractor):
    """Break document data up into actual data and transforms."""

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForMerge, self).__init__(document_data)
        self.data_merge = []
        self.transform_merge = []
        self.merge = []

    def _apply_merge_all(self) -> None:
        self.data_merge = sorted(self.field_paths + self.deleted_fields)
        # TODO: other transforms
        self.transform_merge = self.transform_paths
        self.merge = sorted(self.data_merge + self.transform_paths)

    def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]:
        for merge_field in merge:
            if isinstance(merge_field, FieldPath):
                yield merge_field
            else:
                yield FieldPath(*parse_field_path(merge_field))

    def _normalize_merge_paths(self, merge) -> list:
        merge_paths = sorted(self._construct_merge_paths(merge))

        # Raise if any merge path is a parent of another. Leverage sorting
        # to avoid quadratic behavior.
        for index in range(len(merge_paths) - 1):
            lhs, rhs = merge_paths[index], merge_paths[index + 1]
            if lhs.eq_or_parent(rhs):
                raise ValueError("Merge paths overlap: {}, {}".format(lhs, rhs))

        for merge_path in merge_paths:
            if merge_path in self.deleted_fields:
                continue
            try:
                get_field_value(self.document_data, merge_path)
            except KeyError:
                raise ValueError("Invalid merge path: {}".format(merge_path))

        return merge_paths

    def _apply_merge_paths(self, merge) -> None:
        if self.empty_document:
            raise ValueError("Cannot merge specific fields with empty document.")

        merge_paths = self._normalize_merge_paths(merge)

        del self.data_merge[:]
        del self.transform_merge[:]
        self.merge = merge_paths

        for merge_path in merge_paths:
            if merge_path in self.transform_paths:
                self.transform_merge.append(merge_path)

            for field_path in self.field_paths:
                if merge_path.eq_or_parent(field_path):
                    self.data_merge.append(field_path)

        # Clear out data for fields not merged.
        merged_set_fields = {}
        for field_path in self.data_merge:
            value = get_field_value(self.document_data, field_path)
            set_field_value(merged_set_fields, field_path, value)
        self.set_fields = merged_set_fields

        unmerged_deleted_fields = [
            field_path
            for field_path in self.deleted_fields
            if field_path not in self.merge
        ]
        if unmerged_deleted_fields:
            raise ValueError(
                "Cannot delete unmerged fields: {}".format(unmerged_deleted_fields)
            )
        self.data_merge = sorted(self.data_merge + self.deleted_fields)

        # Keep only transforms which are within merge.
        merged_transform_paths = set()
        for merge_path in self.merge:
            transform_merge_paths = [
                transform_path
                for transform_path in self.transform_paths
                if merge_path.eq_or_parent(transform_path)
            ]
            merged_transform_paths.update(transform_merge_paths)

        self.server_timestamps = [
            path for path in self.server_timestamps if path in merged_transform_paths
        ]

        self.array_removes = {
            path: values
            for path, values in self.array_removes.items()
            if path in merged_transform_paths
        }

        self.array_unions = {
            path: values
            for path, values in self.array_unions.items()
            if path in merged_transform_paths
        }

    def apply_merge(self, merge) -> None:
        if merge is True:  # merge all fields
            self._apply_merge_all()
        else:
            self._apply_merge_paths(merge)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        # Mask uses dotted / quoted paths.
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.merge
            if field_path not in self.transform_merge
        ]

        return common.DocumentMask(field_paths=mask_paths)


def pbs_for_set_with_merge(
    document_path, document_data, merge
) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.
        merge (Optional[bool] or Optional[List<apispec>]):
            If True, merge all fields; else, merge only the named fields.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: One
        or two ``Write`` protobuf instances for ``set()``.
    """
    extractor = DocumentExtractorForMerge(document_data)
    extractor.apply_merge(merge)

    set_pb = extractor.get_update_pb(document_path)

    if extractor.transform_paths:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        set_pb.update_transforms.extend(field_transform_pbs)

    return [set_pb]
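

# Illustrative sketch (hypothetical helper): with ``merge=True`` every field
# appears in the update mask; with a field list, only the named paths are
# kept and unmerged data is dropped from the write.
def _demo_set_with_merge() -> None:
    doc_path = "projects/p/databases/(default)/documents/users/alice"
    data = {"name": "ada", "age": 36}
    (all_pb,) = pbs_for_set_with_merge(doc_path, data, merge=True)
    assert sorted(all_pb.update_mask.field_paths) == ["age", "name"]
    (one_pb,) = pbs_for_set_with_merge(doc_path, data, merge=["name"])
    assert list(one_pb.update_mask.field_paths) == ["name"]
    assert "age" not in one_pb.update.fields  # unmerged data is dropped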


class DocumentExtractorForUpdate(DocumentExtractor):
    """Break document data up into actual data and transforms."""

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForUpdate, self).__init__(document_data)
        self.top_level_paths = sorted(
            [FieldPath.from_string(key) for key in document_data]
        )
        tops = set(self.top_level_paths)
        for top_level_path in self.top_level_paths:
            for ancestor in top_level_path.lineage():
                if ancestor in tops:
                    raise ValueError(
                        "Conflicting field path: {}, {}".format(
                            top_level_path, ancestor
                        )
                    )

        for field_path in self.deleted_fields:
            if field_path not in tops:
                raise ValueError(
                    "Cannot update with nested delete: {}".format(field_path)
                )

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path, expand_dots=True)

    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
        mask_paths = []
        for field_path in self.top_level_paths:
            if field_path not in self.transform_paths:
                mask_paths.append(field_path.to_api_repr())

        return common.DocumentMask(field_paths=mask_paths)


def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``update()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        field_updates (dict): Field names or paths to update and values
            to update with.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: One
        or two ``Write`` protobuf instances for ``update()``.
    """
    extractor = DocumentExtractorForUpdate(field_updates)

    if extractor.empty_document:
        raise ValueError("Cannot update with an empty document.")

    if option is None:  # Default is to use ``exists=True``.
        option = ExistsOption(exists=True)

    update_pb = extractor.get_update_pb(document_path)
    option.modify_write(update_pb)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        update_pb.update_transforms.extend(field_transform_pbs)

    return [update_pb]
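

# Illustrative sketch (hypothetical helper): dotted keys expand into nested
# field paths, and the default precondition requires the document to exist.
def _demo_pbs_for_update() -> None:
    doc_path = "projects/p/databases/(default)/documents/users/alice"
    (update_pb,) = pbs_for_update(doc_path, {"profile.age": 37}, option=None)
    assert update_pb.current_document.exists is True
    assert list(update_pb.update_mask.field_paths) == ["profile.age"]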


def pb_for_delete(document_path, option) -> types.write.Write:
    """Make a ``Write`` protobuf for ``delete()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        google.cloud.firestore_v1.types.Write: A
        ``Write`` protobuf instance for the ``delete()``.
    """
    write_pb = write.Write(delete=document_path)
    if option is not None:
        option.modify_write(write_pb)

    return write_pb


class ReadAfterWriteError(Exception):
    """Raised when a read is attempted after a write.

    Raised by "read" methods that use transactions.
    """


def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
    """Get the transaction ID from a ``Transaction`` object.

    Args:
        transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
            Transaction`]):
            An existing transaction that this query will run in.
        read_operation (Optional[bool]): Indicates if the transaction ID
            will be used in a read operation. Defaults to :data:`True`.

    Returns:
        Optional[bytes]: The ID of the transaction, or :data:`None` if the
        ``transaction`` is :data:`None`.

    Raises:
        ValueError: If the ``transaction`` is not in progress (only if
            ``transaction`` is not :data:`None`).
        ReadAfterWriteError: If the ``transaction`` has writes stored on
            it and ``read_operation`` is :data:`True`.
    """
    if transaction is None:
        return None
    else:
        if not transaction.in_progress:
            raise ValueError(INACTIVE_TXN)
        if read_operation and len(transaction._write_pbs) > 0:
            raise ReadAfterWriteError(READ_AFTER_WRITE_ERROR)
        return transaction.id


def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
    """Create RPC metadata containing a prefix.

    Args:
        prefix (str): appropriate resource path.

    Returns:
        List[Tuple[str, str]]: RPC metadata with supplied prefix
    """
    return [("google-cloud-resource-prefix", prefix)]


class WriteOption(object):
    """Option used to assert a condition on a write operation."""

    def modify_write(self, write, no_create_msg=None) -> NoReturn:
        """Modify a ``Write`` protobuf based on the state of this write option.

        This is a virtual method intended to be implemented by subclasses.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            no_create_msg (Optional[str]): A message to use to indicate that
                a create operation is not allowed.

        Raises:
            NotImplementedError: Always, this method is virtual.
        """
        raise NotImplementedError


class LastUpdateOption(WriteOption):
    """Option used to assert a "last update" condition on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        last_update_time (google.protobuf.timestamp_pb2.Timestamp): A
            timestamp. When set, the target document must exist and have
            been last updated at that time. Protobuf ``update_time`` timestamps
            are typically returned from methods that perform write operations
            as part of a "write result" protobuf or directly.
    """

    def __init__(self, last_update_time) -> None:
        self._last_update_time = last_update_time

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._last_update_time == other._last_update_time

    def modify_write(self, write, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        The ``last_update_time`` is added to ``write`` as an "update time"
        precondition. When set, the target document must exist and have been
        last updated at that time.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(update_time=self._last_update_time)
        write._pb.current_document.CopyFrom(current_doc._pb)


class ExistsOption(WriteOption):
    """Option used to assert existence on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        exists (bool): Indicates if the document being modified
            should already exist.
    """

    def __init__(self, exists) -> None:
        self._exists = exists

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._exists == other._exists

    def modify_write(self, write, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        If:

        * ``exists=True``, adds a precondition that requires existence
        * ``exists=False``, adds a precondition that requires non-existence

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(exists=self._exists)
        write._pb.current_document.CopyFrom(current_doc._pb)
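

# Illustrative sketch (hypothetical helper): options mutate a ``Write``
# protobuf in place by attaching a ``current_document`` precondition.
def _demo_write_options() -> None:
    write_pb = write.Write(
        delete="projects/p/databases/(default)/documents/users/alice"
    )
    ExistsOption(exists=True).modify_write(write_pb)
    assert write_pb.current_document.exists is True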


def make_retry_timeout_kwargs(retry, timeout) -> dict:
    """Helper for API methods which take optional 'retry' / 'timeout' args."""
    kwargs = {}

    if retry is not gapic_v1.method.DEFAULT:
        kwargs["retry"] = retry

    if timeout is not None:
        kwargs["timeout"] = timeout

    return kwargs


def build_timestamp(
    dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None
) -> Timestamp:
    """Returns the supplied datetime (or "now") as a Timestamp"""
    return _datetime_to_pb_timestamp(
        dt or DatetimeWithNanoseconds.now(tz=datetime.timezone.utc)
    )


def compare_timestamps(
    ts1: Union[Timestamp, datetime.datetime],
    ts2: Union[Timestamp, datetime.datetime],
) -> int:
    ts1 = build_timestamp(ts1) if not isinstance(ts1, Timestamp) else ts1
    ts2 = build_timestamp(ts2) if not isinstance(ts2, Timestamp) else ts2
    ts1_nanos = ts1.nanos + ts1.seconds * 1e9
    ts2_nanos = ts2.nanos + ts2.seconds * 1e9
    if ts1_nanos == ts2_nanos:
        return 0
    return 1 if ts1_nanos > ts2_nanos else -1
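

# Illustrative sketch (hypothetical helper): comparison happens in total
# nanoseconds, returning -1 / 0 / 1 like a classic comparator.
def _demo_compare_timestamps() -> None:
    earlier = Timestamp(seconds=100, nanos=0)
    later = Timestamp(seconds=100, nanos=1)
    assert compare_timestamps(earlier, later) == -1
    assert compare_timestamps(later, earlier) == 1
    assert compare_timestamps(earlier, earlier) == 0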


def deserialize_bundle(
    serialized: Union[str, bytes],
    client: "google.cloud.firestore_v1.client.BaseClient",  # type: ignore
) -> "google.cloud.firestore_bundle.FirestoreBundle":  # type: ignore
    """Inverse operation to a `FirestoreBundle` instance's `build()` method.

    Args:
        serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`.
            Should be a list of dictionaries in string format.
        client (BaseClient): A connected Client instance.

    Returns:
        FirestoreBundle: A bundle equivalent to that which called `build()` and
            initially created the `serialized` value.

    Raises:
        ValueError: If any of the dictionaries in the list contain any more than
            one top-level key.
        ValueError: If any unexpected BundleElement types are encountered.
        ValueError: If the serialized bundle ends before expected.
    """
    from google.cloud.firestore_bundle import BundleElement, FirestoreBundle

    # Outlines the legal transitions from one BundleElement to another.
    bundle_state_machine = {
        "__initial__": ["metadata"],
        "metadata": ["namedQuery", "documentMetadata", "__end__"],
        "namedQuery": ["namedQuery", "documentMetadata", "__end__"],
        "documentMetadata": ["document"],
        "document": ["documentMetadata", "__end__"],
    }
    allowed_next_element_types: List[str] = bundle_state_machine["__initial__"]

    # This must be saved and added last, since we cache it to preserve timestamps,
    # yet must flush it whenever a new document or query is added to a bundle.
    # The process of deserializing a bundle uses these methods which flush a
    # cached metadata element, and thus, it must be the last BundleElement
    # added during deserialization.
    metadata_bundle_element: Optional[BundleElement] = None

    bundle: Optional[FirestoreBundle] = None
    data: Dict
    for data in _parse_bundle_elements_data(serialized):
        # BundleElements are serialized as JSON containing one key outlining
        # the type, with all further data nested under that key
        keys: List[str] = list(data.keys())

        if len(keys) != 1:
            raise ValueError("Expected serialized BundleElement with one top-level key")

        key: str = keys[0]

        if key not in allowed_next_element_types:
            raise ValueError(
                f"Encountered BundleElement of type {key}. "
                f"Expected one of {allowed_next_element_types}"
            )

        # Create and add our BundleElement
        bundle_element: BundleElement
        try:
            bundle_element = BundleElement.from_json(json.dumps(data))  # type: ignore
        except AttributeError as e:
            # Some bad serialization formats cannot be universally deserialized.
            if e.args[0] == "'dict' object has no attribute 'find'":  # pragma: NO COVER
                raise ValueError(
                    "Invalid serialization of datetimes. "
                    "Cannot deserialize Bundles created from the NodeJS SDK."
                )
            raise e  # pragma: NO COVER

        if bundle is None:
            # This must be the first bundle type encountered
            assert key == "metadata"
            bundle = FirestoreBundle(data[key]["id"])
            metadata_bundle_element = bundle_element

        else:
            bundle._add_bundle_element(bundle_element, client=client, type=key)

        # Update the allowed next BundleElement types
        allowed_next_element_types = bundle_state_machine[key]

    if "__end__" not in allowed_next_element_types:
        raise ValueError("Unexpected end to serialized FirestoreBundle")

    # Now, finally add the metadata element
    bundle._add_bundle_element(
        metadata_bundle_element,
        client=client,
        type="metadata",  # type: ignore
    )

    return bundle


def _parse_bundle_elements_data(
    serialized: Union[str, bytes]
) -> Generator[Dict, None, None]:  # type: ignore
    """Reads through a serialized FirestoreBundle and yields JSON chunks that
    were created via `BundleElement.to_json(bundle_element)`.

    Serialized FirestoreBundle instances are length-prefixed JSON objects, and
    so are of the form "123{...}57{...}"
    To correctly and safely read a bundle, we must first detect these length
    prefixes, read that many bytes of data, and attempt to JSON-parse that.

    Raises:
        ValueError: If a chunk of JSON ever starts without following a length
            prefix.
    """
    _serialized: Iterator[int] = iter(
        serialized if isinstance(serialized, bytes) else serialized.encode("utf-8")
    )

    length_prefix: str = ""
    while True:
        byte: Optional[int] = next(_serialized, None)

        if byte is None:
            return None

        _str: str = chr(byte)
        if _str.isnumeric():
            length_prefix += _str
        else:
            if length_prefix == "":
                raise ValueError("Expected length prefix")

            _length_prefix = int(length_prefix)
            length_prefix = ""
            _bytes = bytearray([byte])
            _counter = 1
            while _counter < _length_prefix:
                _bytes.append(next(_serialized))
                _counter += 1

            yield json.loads(_bytes.decode("utf-8"))
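

# Illustrative sketch (hypothetical helper): a serialized bundle is a run of
# length-prefixed JSON objects, e.g. '26{...}16{...}', where each decimal
# prefix is the byte length of the JSON chunk that follows it.
def _demo_parse_length_prefixed() -> None:
    chunk_a = '{"metadata": {"id": "b1"}}'
    chunk_b = '{"document": {}}'
    serialized = f"{len(chunk_a)}{chunk_a}{len(chunk_b)}{chunk_b}"
    parsed = list(_parse_bundle_elements_data(serialized))
    assert parsed == [{"metadata": {"id": "b1"}}, {"document": {}}]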


def _get_documents_from_bundle(
    bundle, *, query_name: Optional[str] = None
) -> Generator["google.cloud.firestore.DocumentSnapshot", None, None]:  # type: ignore
    from google.cloud.firestore_bundle.bundle import _BundledDocument

    bundled_doc: _BundledDocument
    for bundled_doc in bundle.documents.values():
        if query_name and query_name not in bundled_doc.metadata.queries:
            continue
        yield bundled_doc.snapshot


def _get_document_from_bundle(
    bundle,
    *,
    document_id: str,
) -> Optional["google.cloud.firestore.DocumentSnapshot"]:  # type: ignore
    bundled_doc = bundle.documents.get(document_id)
    if bundled_doc:
        return bundled_doc.snapshot