Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/google/cloud/firestore_v1/_helpers.py: 22%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

491 statements  

1# Copyright 2017 Google LLC All rights reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Common helpers shared across Google Cloud Firestore modules.""" 

16from __future__ import annotations 

17import datetime 

18import json 

19from typing import ( 

20 Any, 

21 Dict, 

22 Generator, 

23 Iterator, 

24 List, 

25 Optional, 

26 Sequence, 

27 Tuple, 

28 Union, 

29 cast, 

30 TYPE_CHECKING, 

31) 

32 

33import grpc # type: ignore 

34from google.api_core import gapic_v1 

35from google.api_core import retry as retries 

36from google.api_core.datetime_helpers import DatetimeWithNanoseconds 

37from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore 

38from google.protobuf import struct_pb2 

39from google.protobuf.timestamp_pb2 import Timestamp # type: ignore 

40from google.type import latlng_pb2 # type: ignore 

41 

42import google 

43from google.cloud import exceptions # type: ignore 

44from google.cloud.firestore_v1 import transforms, types 

45from google.cloud.firestore_v1.field_path import FieldPath, parse_field_path 

46from google.cloud.firestore_v1.types import common, document, write 

47from google.cloud.firestore_v1.types.write import DocumentTransform 

48from google.cloud.firestore_v1.vector import Vector 

49 

50if TYPE_CHECKING: # pragma: NO COVER 

51 from google.cloud.firestore_v1 import DocumentSnapshot 

52 

# Module-level annotations for names assigned later in this module:
# ``_EmptyDict`` is created further down, ``_GRPC_ERROR_MAPPING`` just below.
_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict


# Error-message templates and fixed strings shared by the helpers below.
BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
# Separator between the components of a document / collection path.
DOCUMENT_PATH_DELIMITER = "/"
INACTIVE_TXN = "Transaction not in progress, cannot be used in API requests."
READ_AFTER_WRITE_ERROR = "Attempted read after write in a transaction."
# ``{{`` / ``}}`` are escaped braces for ``str.format``; only the ``{!r}``
# slot is substituted, the rest is rendered literally.
BAD_REFERENCE_ERROR = (
    "Reference value {!r} in unexpected format, expected to be of the form "
    "``projects/{{project}}/databases/{{database}}/"
    "documents/{{document_path}}``."
)
# Formatted with the offending reference value and the client's database string.
WRONG_APP_REFERENCE = (
    "Document {!r} does not correspond to the same database " "({!r}) as the client."
)
# Server value used for the ``SERVER_TIMESTAMP`` field transforms built below.
REQUEST_TIME_ENUM = DocumentTransform.FieldTransform.ServerValue.REQUEST_TIME
# Maps selected gRPC status codes to ``google.cloud.exceptions`` classes.
_GRPC_ERROR_MAPPING = {
    grpc.StatusCode.ALREADY_EXISTS: exceptions.Conflict,
    grpc.StatusCode.NOT_FOUND: exceptions.NotFound,
}

74 

75 

class GeoPoint(object):
    """A latitude / longitude pair stored as a Firestore geo point value.

    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.

    Args:
        latitude (float): Latitude of a point.
        longitude (float): Longitude of a point.
    """

    def __init__(self, latitude, longitude) -> None:
        self.latitude, self.longitude = latitude, longitude

    def to_protobuf(self) -> latlng_pb2.LatLng:
        """Build the protobuf message for this point.

        Returns:
            google.type.latlng_pb2.LatLng: The current point as a protobuf.
        """
        return latlng_pb2.LatLng(latitude=self.latitude, longitude=self.longitude)

    def __eq__(self, other):
        """Check whether ``other`` is a geo point at the same coordinates.

        Returns:
            Union[bool, NotImplemented]: :data:`True` when both latitude and
            longitude compare equal, :data:`False` otherwise, or
            :data:`NotImplemented` when ``other`` is not a geo point.
        """
        if isinstance(other, GeoPoint):
            return self.latitude == other.latitude and self.longitude == other.longitude
        return NotImplemented

    def __ne__(self, other):
        """Negation of :meth:`__eq__`, passing ``NotImplemented`` through.

        Returns:
            Union[bool, NotImplemented]: :data:`False` when the points compare
            equal, :data:`True` otherwise, or :data:`NotImplemented` when
            ``other`` is not a geo point.
        """
        outcome = self.__eq__(other)
        if outcome is NotImplemented:
            return outcome
        return not outcome

    def __repr__(self):
        return "{}(latitude={}, longitude={})".format(
            type(self).__name__, self.latitude, self.longitude
        )

125 

126 

def verify_path(path, is_collection) -> None:
    """Check that ``path`` is a well-formed collection or document path.

    A collection path has an odd number of components and a document path
    has an even number. Every component must be a string.

    Args:
        path (Tuple[str, ...]): The components in a collection or
            document path.
        is_collection (bool): Truthy if ``path`` names a collection,
            falsy if it names a document.

    Raises:
        ValueError: if

            * the ``path`` is empty
            * ``is_collection=True`` and there are an even number of elements
            * ``is_collection=False`` and there are an odd number of elements
            * an element is not a string
    """
    count = len(path)
    if not count:
        raise ValueError("Document or collection path cannot be empty")

    odd_length = count % 2 == 1
    if is_collection and not odd_length:
        raise ValueError("A collection must have an odd number of path elements")
    if not is_collection and odd_length:
        raise ValueError("A document must have an even number of path elements")

    for component in path:
        if isinstance(component, str):
            continue
        raise ValueError(BAD_PATH_TEMPLATE.format(component, type(component)))

162 

163 

def encode_value(value) -> types.document.Value:
    """Convert a native Python value into a Firestore protobuf ``Value``.

    Args:
        value (Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint, \
            ~google.cloud.firestore_v1.vector.Vector]): A native
            Python value to convert to a protobuf field. ``list``,
            ``tuple``, ``set`` and ``frozenset`` become array values, and
            objects exposing a non-``None`` ``_document_path`` become
            reference values.

    Returns:
        ~google.cloud.firestore_v1.types.Value: A
        value encoded as a Firestore protobuf.

    Raises:
        TypeError: If the ``value`` is not one of the accepted types.
    """
    if value is None:
        return document.Value(null_value=struct_pb2.NULL_VALUE)

    # ``bool`` is a subclass of ``int``, so it has to be tested first.
    if isinstance(value, bool):
        return document.Value(boolean_value=value)
    elif isinstance(value, int):
        return document.Value(integer_value=value)
    elif isinstance(value, float):
        return document.Value(double_value=value)
    # ``DatetimeWithNanoseconds`` is a ``datetime.datetime`` subclass with
    # its own ``timestamp_pb()`` conversion, so it is checked first.
    elif isinstance(value, DatetimeWithNanoseconds):
        return document.Value(timestamp_value=value.timestamp_pb())
    elif isinstance(value, datetime.datetime):
        return document.Value(timestamp_value=_datetime_to_pb_timestamp(value))
    elif isinstance(value, str):
        return document.Value(string_value=value)
    elif isinstance(value, bytes):
        return document.Value(bytes_value=value)

    # Document references are recognised by duck typing on
    # ``_document_path`` instead of ``isinstance`` to avoid import cycles.
    reference_path = getattr(value, "_document_path", None)
    if reference_path is not None:
        return document.Value(reference_value=reference_path)

    if isinstance(value, GeoPoint):
        return document.Value(geo_point_value=value.to_protobuf())
    elif isinstance(value, (list, tuple, set, frozenset)):
        encoded_items = tuple(encode_value(item) for item in value)
        return document.Value(array_value=document.ArrayValue(values=encoded_items))
    elif isinstance(value, Vector):
        # Vectors travel as a specially tagged map value.
        return encode_value(value.to_map_value())
    elif isinstance(value, dict):
        return document.Value(map_value=document.MapValue(fields=encode_dict(value)))

    raise TypeError(
        "Cannot convert to a Firestore Value", value, "Invalid type", type(value)
    )

230 

231 

def encode_dict(values_dict) -> dict:
    """Encode every value of a mapping as a protobuf ``Value``.

    Args:
        values_dict (dict): The dictionary to encode as protobuf fields.

    Returns:
        Dict[str, ~google.cloud.firestore_v1.types.Value]: A new dictionary
        with the same keys, each value replaced by the result of
        :func:`encode_value`.
    """
    encoded = {}
    for field_name, field_value in values_dict.items():
        encoded[field_name] = encode_value(field_value)
    return encoded

244 

245 

def document_snapshot_to_protobuf(
    snapshot: "DocumentSnapshot",
) -> Optional["google.cloud.firestore_v1.types.Document"]:
    """Convert a document snapshot into a ``Document`` protobuf.

    Args:
        snapshot (DocumentSnapshot): The snapshot to convert.

    Returns:
        Optional[google.cloud.firestore_v1.types.Document]: ``None`` when the
        snapshot does not exist. Otherwise, a ``Document`` carrying the
        reference path, the encoded snapshot data and its create/update
        times.
    """
    from google.cloud.firestore_v1.types import Document

    if snapshot.exists:
        return Document(
            name=snapshot.reference._document_path,
            fields=encode_dict(snapshot._data),
            create_time=snapshot.create_time,
            update_time=snapshot.update_time,
        )
    return None

260 

261 

class DocumentReferenceValue:
    """DocumentReference path container with accessors for each relevant chunk.

    Usage:
        doc_ref_val = DocumentReferenceValue(
            'projects/my-proj/databases/(default)/documents/my-col/my-doc',
        )
        assert doc_ref_val.project_name == 'my-proj'
        assert doc_ref_val.collection_name == 'my-col'
        assert doc_ref_val.document_id == 'my-doc'
        assert doc_ref_val.database_name == '(default)'

    Raises:
        ValueError: If the supplied value cannot satisfy a complete path. It
            must have at least seven ``/``-separated segments and start with
            ``projects/{project}/databases/{database}/documents/``.
    """

    def __init__(self, reference_value: str):
        self._reference_value = reference_value

        # The first 5 parts are
        # projects, {project}, databases, {database}, documents
        # followed by {collection}/{document...}.
        parts = reference_value.split(DOCUMENT_PATH_DELIMITER)
        # Check the fixed literal segments as well as the count, so that a
        # value such as ``a/b/c/d/e/f/g`` is rejected instead of mis-parsed.
        if (
            len(parts) < 7
            or parts[0] != "projects"
            or parts[2] != "databases"
            or parts[4] != "documents"
        ):
            msg = BAD_REFERENCE_ERROR.format(reference_value)
            raise ValueError(msg)

        self.project_name = parts[1]
        self.collection_name = parts[5]
        self.database_name = parts[3]
        # Everything after the top-level collection id, so nested paths such
        # as ``doc/subcol/subdoc`` are kept intact.
        self.document_id = DOCUMENT_PATH_DELIMITER.join(parts[6:])

    @property
    def full_key(self) -> str:
        """Computed property for a DocumentReference's collection_name and
        document Id"""
        return DOCUMENT_PATH_DELIMITER.join([self.collection_name, self.document_id])

    @property
    def full_path(self) -> str:
        """The original reference value, or one rebuilt from its components
        if the stored value is empty."""
        return self._reference_value or DOCUMENT_PATH_DELIMITER.join(
            [
                "projects",
                self.project_name,
                "databases",
                self.database_name,
                "documents",
                self.collection_name,
                self.document_id,
            ]
        )

312 

313 

def reference_value_to_document(reference_value, client) -> Any:
    """Turn a reference value string into a document reference.

    Args:
        reference_value (str): A document reference value.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        :class:`~google.cloud.firestore_v1.document.DocumentReference`:
            The document corresponding to ``reference_value``.

    Raises:
        ValueError: If the ``reference_value`` is not of the expected
            format: ``projects/{project}/databases/{database}/documents/...``.
        ValueError: If the ``reference_value`` does not come from the same
            project / database combination as the ``client``.
    """
    from google.cloud.firestore_v1.base_document import BaseDocumentReference

    parsed = DocumentReferenceValue(reference_value)

    # Named ``doc_ref`` so it does not shadow the module-level ``document``.
    doc_ref: BaseDocumentReference = client.document(parsed.full_key)
    # The client rebuilds the full path from its own project / database, so a
    # mismatch means the reference points at a different database.
    if doc_ref._document_path == reference_value:
        return doc_ref

    raise ValueError(
        WRONG_APP_REFERENCE.format(reference_value, client._database_string)
    )

342 

343 

def decode_value(
    value, client
) -> Union[
    None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint, Vector
]:
    """Convert a Firestore protobuf ``Value`` into a native Python value.

    Args:
        value (google.cloud.firestore_v1.types.Value): A
            Firestore protobuf to be decoded / parsed / converted. Either a
            proto-plus wrapper (its ``_pb`` is used) or a raw protobuf.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Union[NoneType, bool, int, float, datetime.datetime, str, bytes, \
            list, dict, ~google.cloud.Firestore.GeoPoint, \
            ~google.cloud.firestore_v1.vector.Vector, DocumentReference]:
            The native value. Timestamps decode to ``DatetimeWithNanoseconds``;
            references decode through :func:`reference_value_to_document`;
            maps decode through :func:`decode_dict`, which may return a
            ``Vector``.

    Raises:
        ValueError: If the ``value_type`` is unknown, or if a reference value
            is malformed or belongs to a different database (raised by
            :func:`reference_value_to_document`).
    """
    value_pb = getattr(value, "_pb", value)
    value_type = value_pb.WhichOneof("value_type")

    if value_type == "null_value":
        return None

    # Scalars are read straight from the protobuf field of the same name.
    if value_type in (
        "boolean_value",
        "integer_value",
        "double_value",
        "string_value",
        "bytes_value",
    ):
        return getattr(value_pb, value_type)

    if value_type == "timestamp_value":
        return DatetimeWithNanoseconds.from_timestamp_pb(value_pb.timestamp_value)
    if value_type == "reference_value":
        return reference_value_to_document(value_pb.reference_value, client)
    if value_type == "geo_point_value":
        point = value_pb.geo_point_value
        return GeoPoint(point.latitude, point.longitude)
    if value_type == "array_value":
        return [decode_value(item, client) for item in value_pb.array_value.values]
    if value_type == "map_value":
        return decode_dict(value_pb.map_value.fields, client)

    raise ValueError("Unknown ``value_type``", value_type)

397 

398 

def decode_dict(value_fields, client) -> Union[dict, Vector]:
    """Decode a protobuf map of Firestore ``Value``-s into Python values.

    Args:
        value_fields (google.protobuf.pyext._message.MessageMapContainer): A
            protobuf map of Firestore ``Value``-s (a proto-plus wrapper's
            ``_pb`` is used when present).
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Union[Dict[str, Any], ~google.cloud.firestore_v1.vector.Vector]: The
        decoded mapping. A mapping tagged as a vector
        (``{"__type__": "__vector__", "value": [...]}``) is returned as a
        ``Vector`` built from its ``"value"`` entry.
    """
    fields_pb = getattr(value_fields, "_pb", value_fields)

    decoded = {}
    for field_name, field_value in fields_pb.items():
        decoded[field_name] = decode_value(field_value, client)

    if decoded.get("__type__") == "__vector__":
        return Vector(cast(Sequence[float], decoded["value"]))
    return decoded

423 

424 

def get_doc_id(document_pb, expected_prefix) -> str:
    """Return the trailing document ID from a document protobuf's name.

    Args:
        document_pb (google.cloud.firestore_v1.\
            document.Document): A protobuf for a document that
            was created in a ``CreateDocument`` RPC.
        expected_prefix (str): The expected collection prefix for the
            fully-qualified document name.

    Returns:
        str: The document ID from the protobuf.

    Raises:
        ValueError: If the name does not begin with the prefix. A name with
            no ``/`` at all also raises ``ValueError``, because the split
            result cannot be unpacked.
    """
    full_name = document_pb.name
    prefix, document_id = full_name.rsplit(DOCUMENT_PATH_DELIMITER, 1)

    if prefix == expected_prefix:
        return document_id

    raise ValueError(
        "Unexpected document name",
        full_name,
        "Expected to begin with",
        expected_prefix,
    )

451 

452 

# Sentinel that ``extract_fields`` yields in place of an empty mapping, so an
# empty dict shows up as a leaf value. ``set_field_value`` turns it back into
# ``{}``, and ``DocumentExtractor`` uses it at the root path to detect an
# empty document.
_EmptyDict = transforms.Sentinel("Marker for an empty dict value")

454 

455 

def extract_fields(
    document_data, prefix_path: FieldPath, expand_dots=False
) -> Generator[Tuple[Any, Any], Any, None]:
    """Walk ``document_data`` depth-first, yielding ``(field_path, value)``.

    Keys are visited in sorted order. Nested dicts are descended into, and
    every other value is yielded as a leaf. An empty (falsy) mapping yields a
    single ``(prefix_path, _EmptyDict)`` pair.

    Args:
        document_data (dict): The (sub)document to walk.
        prefix_path (FieldPath): Path of ``document_data`` within the
            whole document.
        expand_dots (bool): If true, the keys of ``document_data`` are parsed
            with ``FieldPath.from_string`` (dotted paths). Otherwise each key
            is one literal path component.
    """
    if not document_data:
        yield prefix_path, _EmptyDict
        return

    for key, value in sorted(document_data.items()):
        sub_key = FieldPath.from_string(key) if expand_dots else FieldPath(key)
        field_path = FieldPath(*(prefix_path.parts + sub_key.parts))

        if isinstance(value, dict):
            # ``expand_dots`` is not forwarded. Only the keys of the mapping
            # passed in by the caller are parsed as dotted paths; keys of
            # nested dicts are used verbatim.
            yield from extract_fields(value, field_path)
        else:
            yield field_path, value

476 

477 

def set_field_value(document_data, field_path, value) -> None:
    """Store ``value`` at ``field_path`` inside ``document_data``.

    Intermediate dicts are created as needed via ``setdefault``. The
    ``_EmptyDict`` sentinel is stored as a fresh ``{}``. An empty
    ``field_path`` raises ``IndexError``.
    """
    parts = field_path.parts
    target = document_data
    for component in parts[:-1]:
        target = target.setdefault(component, {})
    target[parts[-1]] = {} if value is _EmptyDict else value

487 

def get_field_value(document_data, field_path) -> Any:
    """Look up the value stored at ``field_path`` inside ``document_data``.

    Args:
        document_data (dict): Nested mapping to read from.
        field_path (FieldPath): Path whose ``parts`` are followed key by key.

    Returns:
        Any: The value found at the end of the path.

    Raises:
        ValueError: If ``field_path`` has no parts.
        KeyError: If a component along the path is missing.
    """
    parts = field_path.parts
    if not parts:
        raise ValueError("Empty path")

    *parents, leaf = parts
    node = document_data
    for component in parents:
        node = node[component]
    return node[leaf]

496 

497 

class DocumentExtractor(object):
    """Break document data up into actual data and transforms.

    Handle special values such as ``DELETE_FIELD``, ``SERVER_TIMESTAMP``.
    Every leaf of ``document_data`` is placed in exactly one bucket:

    * ``deleted_fields`` for ``DELETE_FIELD``,
    * ``server_timestamps`` for ``SERVER_TIMESTAMP``,
    * ``array_removes`` / ``array_unions`` / ``increments`` / ``maximums`` /
      ``minimums`` for the matching transform wrappers,
    * otherwise ``field_paths`` plus the nested ``set_fields`` dict.

    Args:
        document_data (dict):
            Property names and values to use for sending a change to
            a document.
    """

    def __init__(self, document_data) -> None:
        self.document_data = document_data
        self.field_paths = []
        self.deleted_fields = []
        self.server_timestamps = []
        # Keyed by field path; values are the transform payloads.
        self.array_removes = {}
        self.array_unions = {}
        self.increments = {}
        self.minimums = {}
        self.maximums = {}
        # Nested dict holding only the plain (non-transform) values.
        self.set_fields: dict = {}
        self.empty_document = False

        # (wrapper type, destination dict, payload attribute), checked in order.
        payload_transforms = (
            (transforms.ArrayRemove, self.array_removes, "values"),
            (transforms.ArrayUnion, self.array_unions, "values"),
            (transforms.Increment, self.increments, "value"),
            (transforms.Maximum, self.maximums, "value"),
            (transforms.Minimum, self.minimums, "value"),
        )

        root = FieldPath()
        for path, leaf in self._get_document_iterator(root):
            if path == root and leaf is _EmptyDict:
                # The whole document is an empty mapping.
                self.empty_document = True
            elif leaf is transforms.DELETE_FIELD:
                self.deleted_fields.append(path)
            elif leaf is transforms.SERVER_TIMESTAMP:
                self.server_timestamps.append(path)
            else:
                for wrapper_type, bucket, payload_attr in payload_transforms:
                    if isinstance(leaf, wrapper_type):
                        bucket[path] = getattr(leaf, payload_attr)
                        break
                else:
                    # No transform wrapper matched: this is plain data.
                    self.field_paths.append(path)
                    set_field_value(self.set_fields, path, leaf)

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        """Iterate ``(field_path, value)`` leaves; subclasses may override."""
        return extract_fields(self.document_data, prefix_path)

    @property
    def has_transforms(self):
        """Whether any server-side field transform was extracted."""
        return any(
            (
                self.server_timestamps,
                self.array_removes,
                self.array_unions,
                self.increments,
                self.maximums,
                self.minimums,
            )
        )

    @property
    def transform_paths(self):
        """Sorted list of every field path that carries a transform."""
        collected = list(self.server_timestamps)
        for by_path in (
            self.array_removes,
            self.array_unions,
            self.increments,
            self.maximums,
            self.minimums,
        ):
            collected.extend(by_path)
        return sorted(collected)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        """Update mask for the write; the base extractor uses none."""
        return None

    def get_update_pb(
        self, document_path, exists=None, allow_empty_mask=False
    ) -> types.write.Write:
        """Build a ``Write`` that stores ``set_fields`` at ``document_path``.

        Args:
            document_path (str): Fully-qualified document path.
            exists (Optional[bool]): When not ``None``, added as an
                ``exists`` precondition.
            allow_empty_mask (bool): Forwarded to ``_get_update_mask``.
        """
        precondition = None if exists is None else common.Precondition(exists=exists)

        return write.Write(
            update=document.Document(
                name=document_path, fields=encode_dict(self.set_fields)
            ),
            update_mask=self._get_update_mask(allow_empty_mask),
            current_document=precondition,
        )

    def get_field_transform_pbs(
        self, document_path
    ) -> List[types.write.DocumentTransform.FieldTransform]:
        """Build one ``FieldTransform`` per transformed path, sorted by path.

        ``document_path`` is not used by this method.
        """
        field_transform = write.DocumentTransform.FieldTransform

        def as_array(values):
            return document.ArrayValue(values=[encode_value(item) for item in values])

        pairs = [
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(),
                    set_to_server_value=REQUEST_TIME_ENUM,
                ),
            )
            for path in self.server_timestamps
        ]

        # (payloads by path, FieldTransform field name, payload encoder)
        keyed_transforms = (
            (self.array_removes, "remove_all_from_array", as_array),
            (self.array_unions, "append_missing_elements", as_array),
            (self.increments, "increment", encode_value),
            (self.maximums, "maximum", encode_value),
            (self.minimums, "minimum", encode_value),
        )
        for by_path, proto_field, encode in keyed_transforms:
            for path, payload in by_path.items():
                pairs.append(
                    (
                        path,
                        field_transform(
                            field_path=path.to_api_repr(),
                            **{proto_field: encode(payload)},
                        ),
                    )
                )

        pairs.sort()
        return [transform for _, transform in pairs]

    def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
        """Build a standalone transform ``Write`` for ``document_path``.

        When ``exists`` is not ``None`` it is copied in as the write's
        ``current_document`` precondition.
        """
        transform_pb = write.Write(
            transform=write.DocumentTransform(
                document=document_path,
                field_transforms=self.get_field_transform_pbs(document_path),
            )
        )
        if exists is None:
            return transform_pb

        transform_pb._pb.current_document.CopyFrom(
            common.Precondition(exists=exists)._pb
        )
        return transform_pb

685 

686 

def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``create()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            creating a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single-element list.
        The ``Write`` carries an ``exists=False`` precondition, and any field
        transforms are attached as ``update_transforms``.

    Raises:
        ValueError: If ``document_data`` contains ``DELETE_FIELD``.
    """
    extractor = DocumentExtractor(document_data)
    if extractor.deleted_fields:
        raise ValueError("Cannot apply DELETE_FIELD in a create request.")

    create_pb = extractor.get_update_pb(document_path, exists=False)
    if not extractor.has_transforms:
        return [create_pb]

    create_pb.update_transforms.extend(
        extractor.get_field_transform_pbs(document_path)
    )
    return [create_pb]

711 

712 

def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods without merging.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single-element list;
        any field transforms are attached as ``update_transforms``.

    Raises:
        ValueError: If ``document_data`` contains ``DELETE_FIELD``, which
            is only allowed together with ``merge``.
    """
    extractor = DocumentExtractor(document_data)
    if extractor.deleted_fields:
        raise ValueError(
            "Cannot apply DELETE_FIELD in a set request without "
            "specifying 'merge=True' or 'merge=[field_paths]'."
        )

    set_pb = extractor.get_update_pb(document_path)
    if not extractor.has_transforms:
        return [set_pb]

    set_pb.update_transforms.extend(extractor.get_field_transform_pbs(document_path))
    return [set_pb]

740 

741 

class DocumentExtractorForMerge(DocumentExtractor):
    """Break document data up into actual data and transforms.

    Used for ``set(..., merge=...)``. After construction, :meth:`apply_merge`
    computes ``data_merge``, ``transform_merge`` and ``merge``. For explicit
    merge paths it also drops data and transforms that fall outside them.
    """

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForMerge, self).__init__(document_data)
        # Data field paths (plus deleted fields) covered by the merge.
        self.data_merge: list = []
        # Merge paths that name a transformed field exactly.
        self.transform_merge: list = []
        # Every merge path; drives the update mask.
        self.merge: list = []

    def _apply_merge_all(self) -> None:
        """Merge every data, deleted and transformed field."""
        self.data_merge = sorted(self.field_paths + self.deleted_fields)
        # ``transform_paths`` already covers every transform type.
        self.transform_merge = self.transform_paths
        self.merge = sorted(self.data_merge + self.transform_paths)

    def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]:
        """Yield each merge entry as a ``FieldPath``, parsing strings."""
        for merge_field in merge:
            if isinstance(merge_field, FieldPath):
                yield merge_field
            else:
                yield FieldPath(*parse_field_path(merge_field))

    def _normalize_merge_paths(self, merge) -> list:
        """Sort and validate merge paths.

        Raises:
            ValueError: If one merge path equals or contains another, or if a
                (non-deleted) merge path is absent from the document data.
        """
        merge_paths = sorted(self._construct_merge_paths(merge))

        # Raise if any merge path is a parent of another. Leverage sorting
        # to avoid quadratic behavior.
        for lhs, rhs in zip(merge_paths, merge_paths[1:]):
            if lhs.eq_or_parent(rhs):
                raise ValueError("Merge paths overlap: {}, {}".format(lhs, rhs))

        for merge_path in merge_paths:
            if merge_path in self.deleted_fields:
                continue
            try:
                get_field_value(self.document_data, merge_path)
            except KeyError:
                raise ValueError("Invalid merge path: {}".format(merge_path))

        return merge_paths

    def _apply_merge_paths(self, merge) -> None:
        """Restrict data and transforms to the explicit ``merge`` paths."""
        if self.empty_document:
            raise ValueError("Cannot merge specific fields with empty document.")

        merge_paths = self._normalize_merge_paths(merge)

        del self.data_merge[:]
        del self.transform_merge[:]
        self.merge = merge_paths

        # ``transform_paths`` is a computed (sorted) property. Evaluate it
        # once; the transform collections are unchanged until the filtering
        # at the end of this method.
        transform_paths = self.transform_paths

        for merge_path in merge_paths:
            if merge_path in transform_paths:
                self.transform_merge.append(merge_path)

            for field_path in self.field_paths:
                if merge_path.eq_or_parent(field_path):
                    self.data_merge.append(field_path)

        # Clear out data for fields not merged.
        merged_set_fields: dict = {}
        for field_path in self.data_merge:
            value = get_field_value(self.document_data, field_path)
            set_field_value(merged_set_fields, field_path, value)
        self.set_fields = merged_set_fields

        unmerged_deleted_fields = [
            field_path
            for field_path in self.deleted_fields
            if field_path not in self.merge
        ]
        if unmerged_deleted_fields:
            raise ValueError(
                "Cannot delete unmerged fields: {}".format(unmerged_deleted_fields)
            )
        self.data_merge = sorted(self.data_merge + self.deleted_fields)

        # Keep only transforms which are within merge.
        merged_transform_paths = {
            transform_path
            for merge_path in self.merge
            for transform_path in transform_paths
            if merge_path.eq_or_parent(transform_path)
        }

        def _within_merge(payloads_by_path):
            # Drop entries whose path is not covered by a merge path.
            return {
                path: payload
                for path, payload in payloads_by_path.items()
                if path in merged_transform_paths
            }

        self.server_timestamps = [
            path for path in self.server_timestamps if path in merged_transform_paths
        ]
        # Every keyed transform type must be filtered. Otherwise an
        # ``Increment`` / ``Maximum`` / ``Minimum`` on a field outside the
        # merge paths would still be sent with the write.
        self.array_removes = _within_merge(self.array_removes)
        self.array_unions = _within_merge(self.array_unions)
        self.increments = _within_merge(self.increments)
        self.maximums = _within_merge(self.maximums)
        self.minimums = _within_merge(self.minimums)

    def apply_merge(self, merge) -> None:
        """Apply ``merge``: ``True`` merges all fields, else only the given paths."""
        if merge is True:  # merge all fields
            self._apply_merge_all()
        else:
            self._apply_merge_paths(merge)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        """Mask of all merge paths except those that name a transform exactly."""
        # Mask uses dotted / quoted paths.
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.merge
            if field_path not in self.transform_merge
        ]

        return common.DocumentMask(field_paths=mask_paths)

863 

864 

def pbs_for_set_with_merge(
    document_path, document_data, merge
) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods with merging.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.
        merge (Union[bool, Iterable[Union[str, FieldPath]]]): ``True`` to
            merge all fields; otherwise the field paths (strings or
            ``FieldPath`` objects) to merge.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single-element list;
        any remaining field transforms are attached as ``update_transforms``.
    """
    extractor = DocumentExtractorForMerge(document_data)
    extractor.apply_merge(merge)

    set_pb = extractor.get_update_pb(document_path)
    if not extractor.transform_paths:
        return [set_pb]

    set_pb.update_transforms.extend(extractor.get_field_transform_pbs(document_path))
    return [set_pb]

891 

892 

class DocumentExtractorForUpdate(DocumentExtractor):
    """Break document data up into actual data and transforms for ``update()``.

    Top-level keys of ``document_data`` are parsed as dotted field paths.

    Raises:
        ValueError: If a top-level path has another top-level path in its
            ``lineage()``, or if a ``DELETE_FIELD`` is not at a top-level path.
    """

    def __init__(self, document_data) -> None:
        super().__init__(document_data)
        self.top_level_paths = sorted(
            FieldPath.from_string(key) for key in document_data
        )
        known_tops = set(self.top_level_paths)

        for top_path in self.top_level_paths:
            for ancestor in top_path.lineage():
                if ancestor not in known_tops:
                    continue
                raise ValueError(
                    "Conflicting field path: {}, {}".format(top_path, ancestor)
                )

        for deleted_path in self.deleted_fields:
            if deleted_path in known_tops:
                continue
            raise ValueError(
                "Cannot update with nest delete: {}".format(deleted_path)
            )

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        """Iterate leaves, parsing top-level keys as dotted field paths."""
        return extract_fields(self.document_data, prefix_path, expand_dots=True)

    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
        """Mask of the top-level paths that are not themselves transforms."""
        # Computed property: fetch the sorted list once instead of per path.
        transform_paths = self.transform_paths
        mask_paths = [
            top_path.to_api_repr()
            for top_path in self.top_level_paths
            if top_path not in transform_paths
        ]
        return common.DocumentMask(field_paths=mask_paths)

929 

930 

def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``update()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        field_updates (dict): Field names or paths to update and values
            to update with.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes. ``None`` means
            ``ExistsOption(exists=True)``.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single-element list;
        any field transforms are attached as ``update_transforms``.

    Raises:
        ValueError: If ``field_updates`` is empty (or from the path checks
            done by ``DocumentExtractorForUpdate``).
    """
    extractor = DocumentExtractorForUpdate(field_updates)
    if extractor.empty_document:
        raise ValueError("Cannot update with an empty document.")

    # Without an explicit option, the document is required to exist.
    write_option = ExistsOption(exists=True) if option is None else option

    update_pb = extractor.get_update_pb(document_path)
    write_option.modify_write(update_pb)

    if extractor.has_transforms:
        update_pb.update_transforms.extend(
            extractor.get_field_transform_pbs(document_path)
        )
    return [update_pb]

962 

963 

def pb_for_delete(document_path, option) -> types.write.Write:
    """Build the ``Write`` protobuf used by ``delete()`` methods.

    Args:
        document_path (str): Fully-qualified path of the document to delete.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            Optional precondition on the server-side document state; when
            given, its ``modify_write`` is applied to the new write.

    Returns:
        google.cloud.firestore_v1.types.Write: The delete ``Write`` protobuf.
    """
    delete_pb = write.Write(delete=document_path)
    if option is None:
        return delete_pb

    option.modify_write(delete_pb)
    return delete_pb

982 

983 

class ReadAfterWriteError(Exception):
    """Signals that a transactional read was attempted after a write.

    Transaction-aware "read" methods raise this once the transaction
    already holds pending writes.
    """
988 """ 

989 

990 

def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
    """Return the ID of ``transaction``, validating its state first.

    Args:
        transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
            The transaction the caller intends to run in, or :data:`None`.
        read_operation (Optional[bool]): Whether the ID is wanted for a read.
            Defaults to :data:`True`.

    Returns:
        Optional[bytes]: ``transaction.id``, or :data:`None` when
        ``transaction`` is :data:`None`.

    Raises:
        ValueError: If ``transaction`` is given but not in progress.
        ReadAfterWriteError: If ``read_operation`` is true and the
            transaction already has pending writes.
    """
    if transaction is None:
        return None

    if not transaction.in_progress:
        raise ValueError(INACTIVE_TXN)

    # Only inspect pending writes for reads, mirroring a short-circuit ``and``.
    if read_operation:
        pending_write_count = len(transaction._write_pbs)
        if pending_write_count > 0:
            raise ReadAfterWriteError(READ_AFTER_WRITE_ERROR)

    return transaction.id

1019 

1020 

def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
    """Build RPC metadata carrying a resource-prefix header.

    Args:
        prefix (str): The resource path to send as the prefix.
        kw: Extra keyword arguments; accepted and ignored.

    Returns:
        List[Tuple[str, str]]: A one-item metadata list pairing
        ``google-cloud-resource-prefix`` with ``prefix``.
    """
    header_name = "google-cloud-resource-prefix"
    return [(header_name, prefix)]

1031 

1032 

class WriteOption:
    """Base class for conditions asserted on a write operation."""

    def modify_write(self, write, no_create_msg=None) -> None:
        """Attach this option's precondition to a ``Write`` protobuf.

        Subclasses must override this; the base implementation only raises.

        Args:
            write (google.cloud.firestore_v1.types.Write): The ``Write``
                protobuf to receive the precondition.
            no_create_msg (Optional[str]): Message indicating that a create
                operation is not allowed.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError

1052 

1053 

class LastUpdateOption(WriteOption):
    """Option used to assert a "last update" condition on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        last_update_time (google.protobuf.timestamp_pb2.Timestamp): A
            timestamp. When set, the target document must exist and have
            been last updated at that time. Protobuf ``update_time`` timestamps
            are typically returned from methods that perform write operations
            as part of a "write result" protobuf or directly.
    """

    def __init__(self, last_update_time) -> None:
        self._last_update_time = last_update_time

    # NOTE: defining ``__eq__`` without ``__hash__`` makes instances
    # unhashable (Python sets ``__hash__`` to ``None``).
    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._last_update_time == other._last_update_time

    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        The ``last_update_time`` is added to ``write`` as an "update time"
        precondition. When set, the target document must exist and have been
        last updated at that time.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_args (Tuple[Any, ...]): Positional arguments accepted by
                other subclasses that are unused here.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(update_time=self._last_update_time)
        # ``CopyFrom`` on the raw protobuf replaces any existing precondition.
        write._pb.current_document.CopyFrom(current_doc._pb)

1092 

1093 

class ExistsOption(WriteOption):
    """Write option asserting whether the target document exists.

    Usually created through
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        exists (bool): ``True`` if the document must already exist,
            ``False`` if it must not.
    """

    def __init__(self, exists) -> None:
        self._exists = exists

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._exists == other._exists
        return NotImplemented

    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        """Stamp an existence precondition onto a ``Write`` protobuf.

        With ``exists=True`` the write requires the document to exist; with
        ``exists=False`` it requires the document to be absent.

        Args:
            write (google.cloud.firestore_v1.types.Write): The ``Write``
                protobuf to receive the precondition.
            unused_kwargs (Dict[str, Any]): Extra arguments other subclasses
                accept; ignored here.
        """
        precondition = types.Precondition(exists=self._exists)
        write._pb.current_document.CopyFrom(precondition._pb)

1130 

1131 

def make_retry_timeout_kwargs(
    retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None
) -> dict:
    """Helper for API methods which take optional 'retry' / 'timeout' args.

    ``retry`` is forwarded unless it is the ``gapic_v1.method.DEFAULT``
    sentinel (an explicit ``None`` is forwarded). ``timeout`` is forwarded
    only when it is not ``None``.
    """
    candidates = (
        ("retry", retry, retry is not gapic_v1.method.DEFAULT),
        ("timeout", timeout, timeout is not None),
    )
    return {name: value for name, value, wanted in candidates if wanted}

1145 

1146 

def build_timestamp(
    dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None
) -> Timestamp:
    """Convert ``dt`` to a protobuf ``Timestamp``.

    A falsy ``dt`` (including the default ``None``) means "now", taken as a
    timezone-aware UTC ``DatetimeWithNanoseconds``.
    """
    if not dt:
        dt = DatetimeWithNanoseconds.now(tz=datetime.timezone.utc)
    return _datetime_to_pb_timestamp(dt)

1154 

1155 

def compare_timestamps(
    ts1: Union[Timestamp, datetime.datetime],
    ts2: Union[Timestamp, datetime.datetime],
) -> int:
    """Three-way compare two timestamps.

    Args:
        ts1: First timestamp. Anything that is not a protobuf ``Timestamp``
            is converted with ``build_timestamp`` first.
        ts2: Second timestamp, converted the same way.

    Returns:
        int: ``0`` if equal, ``1`` if ``ts1`` is later, ``-1`` if earlier.
    """
    ts1 = ts1 if isinstance(ts1, Timestamp) else build_timestamp(ts1)
    ts2 = ts2 if isinstance(ts2, Timestamp) else build_timestamp(ts2)
    # Exact integer nanoseconds. ``seconds * 1e9`` is float arithmetic and
    # current epoch values (~1.7e18 ns) exceed a float's 53-bit mantissa,
    # so timestamps a few hundred nanoseconds apart would compare equal.
    ts1_nanos = ts1.seconds * 1_000_000_000 + ts1.nanos
    ts2_nanos = ts2.seconds * 1_000_000_000 + ts2.nanos
    if ts1_nanos == ts2_nanos:
        return 0
    return 1 if ts1_nanos > ts2_nanos else -1

1167 

1168 

def deserialize_bundle(
    serialized: Union[str, bytes],
    client: "google.cloud.firestore_v1.client.BaseClient",
) -> "google.cloud.firestore_bundle.FirestoreBundle":
    """Inverse operation to a `FirestoreBundle` instance's `build()` method.

    Args:
        serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`.
            Should be a list of dictionaries in string format.
        client (BaseClient): A connected Client instance.

    Returns:
        FirestoreBundle: A bundle equivalent to that which called `build()` and
            initially created the `serialized` value.

    Raises:
        ValueError: If any element is not a JSON object with exactly one
            top-level key.
        ValueError: If any unexpected BundleElement types are encountered
            (order must be: metadata, then named queries, then
            documentMetadata/document pairs).
        ValueError: If the serialized bundle ends before expected.
        ValueError: If datetimes were serialized in a format that cannot be
            deserialized (e.g. bundles created by the NodeJS SDK).
    """
    from google.cloud.firestore_bundle import BundleElement, FirestoreBundle

    # Outlines the legal transitions from one BundleElement to another.
    bundle_state_machine = {
        "__initial__": ["metadata"],
        "metadata": ["namedQuery", "documentMetadata", "__end__"],
        "namedQuery": ["namedQuery", "documentMetadata", "__end__"],
        "documentMetadata": ["document"],
        "document": ["documentMetadata", "__end__"],
    }
    allowed_next_element_types: List[str] = bundle_state_machine["__initial__"]

    # This must be saved and added last, since we cache it to preserve timestamps,
    # yet must flush it whenever a new document or query is added to a bundle.
    # The process of deserializing a bundle uses these methods which flush a
    # cached metadata element, and thus, it must be the last BundleElement
    # added during deserialization.
    metadata_bundle_element: Optional[BundleElement] = None

    bundle: Optional[FirestoreBundle] = None
    data: Dict
    for data in _parse_bundle_elements_data(serialized):
        # BundleElements are serialized as JSON objects containing one key
        # outlining the type, with all further data nested under that key.
        # Non-object JSON (lists, numbers, ...) is malformed too and must be
        # reported as ValueError rather than crashing on ``.keys()``.
        if not isinstance(data, dict) or len(data) != 1:
            raise ValueError("Expected serialized BundleElement with one top-level key")

        key: str = next(iter(data))

        if key not in allowed_next_element_types:
            raise ValueError(
                f"Encountered BundleElement of type {key}. "
                f"Expected one of {allowed_next_element_types}"
            )

        # Create and add our BundleElement
        bundle_element: BundleElement
        try:
            bundle_element = BundleElement.from_json(json.dumps(data))
        except AttributeError as e:
            # Some bad serialization formats cannot be universally deserialized.
            if e.args[0] == "'dict' object has no attribute 'find'":  # pragma: NO COVER
                raise ValueError(
                    "Invalid serialization of datetimes. "
                    "Cannot deserialize Bundles created from the NodeJS SDK."
                ) from e
            raise  # pragma: NO COVER

        if bundle is None:
            # The state machine only admits "metadata" first, so this is an
            # internal invariant rather than input validation.
            assert key == "metadata"
            bundle = FirestoreBundle(data[key]["id"])
            metadata_bundle_element = bundle_element

        else:
            bundle._add_bundle_element(bundle_element, client=client, type=key)

        # Update the allowed next BundleElement types
        allowed_next_element_types = bundle_state_machine[key]

    if "__end__" not in allowed_next_element_types:
        raise ValueError("Unexpected end to serialized FirestoreBundle")
    # state machine guarantees bundle and metadata have been populated
    bundle = cast(FirestoreBundle, bundle)
    metadata_bundle_element = cast(BundleElement, metadata_bundle_element)
    # Now, finally add the metadata element
    bundle._add_bundle_element(
        metadata_bundle_element,
        client=client,
        type="metadata",
    )

    return bundle

1265 

1266 

1267def _parse_bundle_elements_data( 

1268 serialized: Union[str, bytes] 

1269) -> Generator[Dict, None, None]: 

1270 """Reads through a serialized FirestoreBundle and yields JSON chunks that 

1271 were created via `BundleElement.to_json(bundle_element)`. 

1272 

1273 Serialized FirestoreBundle instances are length-prefixed JSON objects, and 

1274 so are of the form "123{...}57{...}" 

1275 To correctly and safely read a bundle, we must first detect these length 

1276 prefixes, read that many bytes of data, and attempt to JSON-parse that. 

1277 

1278 Raises: 

1279 ValueError: If a chunk of JSON ever starts without following a length 

1280 prefix. 

1281 """ 

1282 _serialized: Iterator[int] = iter( 

1283 serialized if isinstance(serialized, bytes) else serialized.encode("utf-8") 

1284 ) 

1285 

1286 length_prefix: str = "" 

1287 while True: 

1288 byte: Optional[int] = next(_serialized, None) 

1289 

1290 if byte is None: 

1291 return None 

1292 

1293 _str: str = chr(byte) 

1294 if _str.isnumeric(): 

1295 length_prefix += _str 

1296 else: 

1297 if length_prefix == "": 

1298 raise ValueError("Expected length prefix") 

1299 

1300 _length_prefix = int(length_prefix) 

1301 length_prefix = "" 

1302 _bytes = bytearray([byte]) 

1303 _counter = 1 

1304 while _counter < _length_prefix: 

1305 _bytes.append(next(_serialized)) 

1306 _counter += 1 

1307 

1308 yield json.loads(_bytes.decode("utf-8")) 

1309 

1310 

def _get_documents_from_bundle(
    bundle, *, query_name: Optional[str] = None
) -> Generator["DocumentSnapshot", None, None]:
    """Yield the snapshot of each document stored in ``bundle``.

    Args:
        bundle: The bundle whose ``documents`` mapping is iterated.
        query_name (Optional[str]): When truthy, only documents whose
            ``metadata.queries`` contains this name are yielded.

    Yields:
        DocumentSnapshot: Each matching document's snapshot, in the
        iteration order of ``bundle.documents``.
    """
    from google.cloud.firestore_bundle.bundle import _BundledDocument

    entry: _BundledDocument
    for entry in bundle.documents.values():
        if not query_name or query_name in entry.metadata.queries:
            yield entry.snapshot

1321 

1322 

1323def _get_document_from_bundle( 

1324 bundle, 

1325 *, 

1326 document_id: str, 

1327) -> Optional["DocumentSnapshot"]: 

1328 bundled_doc = bundle.documents.get(document_id) 

1329 if bundled_doc: 

1330 return bundled_doc.snapshot 

1331 else: 

1332 return None