# Copyright 2017 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common helpers shared across Google Cloud Firestore modules."""
from __future__ import annotations
import datetime
import json
from typing import (
    Any,
    Dict,
    Generator,
    Iterator,
    List,
    Optional,
    Sequence,
    Tuple,
    Union,
    cast,
    TYPE_CHECKING,
)

import grpc  # type: ignore
from google.api_core import gapic_v1
from google.api_core import retry as retries
from google.api_core.datetime_helpers import DatetimeWithNanoseconds
from google.cloud._helpers import _datetime_to_pb_timestamp  # type: ignore
from google.protobuf import struct_pb2
from google.protobuf.timestamp_pb2 import Timestamp  # type: ignore
from google.type import latlng_pb2  # type: ignore

import google
from google.cloud import exceptions  # type: ignore
from google.cloud.firestore_v1 import transforms, types
from google.cloud.firestore_v1.field_path import FieldPath, parse_field_path
from google.cloud.firestore_v1.types import common, document, write
from google.cloud.firestore_v1.types.write import DocumentTransform
from google.cloud.firestore_v1.vector import Vector

if TYPE_CHECKING:  # pragma: NO COVER
    from google.cloud.firestore_v1 import DocumentSnapshot

_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict


BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
DOCUMENT_PATH_DELIMITER = "/"
INACTIVE_TXN = "Transaction not in progress, cannot be used in API requests."
READ_AFTER_WRITE_ERROR = "Attempted read after write in a transaction."
BAD_REFERENCE_ERROR = (
    "Reference value {!r} in unexpected format, expected to be of the form "
    "``projects/{{project}}/databases/{{database}}/"
    "documents/{{document_path}}``."
)
WRONG_APP_REFERENCE = (
    "Document {!r} does not correspond to the same database ({!r}) as the client."
)
REQUEST_TIME_ENUM = DocumentTransform.FieldTransform.ServerValue.REQUEST_TIME
_GRPC_ERROR_MAPPING = {
    grpc.StatusCode.ALREADY_EXISTS: exceptions.Conflict,
    grpc.StatusCode.NOT_FOUND: exceptions.NotFound,
}

class GeoPoint(object):
    """Simple container for a geo point value.

    Args:
        latitude (float): Latitude of a point.
        longitude (float): Longitude of a point.
    """

    def __init__(self, latitude, longitude) -> None:
        self.latitude = latitude
        self.longitude = longitude

    def to_protobuf(self) -> latlng_pb2.LatLng:
        """Convert the current object to protobuf.

        Returns:
            google.type.latlng_pb2.LatLng: The current point as a protobuf.
        """
        return latlng_pb2.LatLng(latitude=self.latitude, longitude=self.longitude)

    def __eq__(self, other):
        """Compare two geo points for equality.

        Returns:
            Union[bool, NotImplemented]: :data:`True` if the points compare
            equal, else :data:`False`. (Or :data:`NotImplemented` if
            ``other`` is not a geo point.)
        """
        if not isinstance(other, GeoPoint):
            return NotImplemented

        return self.latitude == other.latitude and self.longitude == other.longitude

    def __ne__(self, other):
        """Compare two geo points for inequality.

        Returns:
            Union[bool, NotImplemented]: :data:`False` if the points compare
            equal, else :data:`True`. (Or :data:`NotImplemented` if
            ``other`` is not a geo point.)
        """
        equality_val = self.__eq__(other)
        if equality_val is NotImplemented:
            return NotImplemented
        else:
            return not equality_val
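# A minimal usage sketch for ``GeoPoint`` (doctest-style comments; nothing
# here runs at import time):
#
#     >>> GeoPoint(35.0, 139.0) == GeoPoint(35.0, 139.0)
#     True
#     >>> GeoPoint(35.0, 139.0).to_protobuf().longitude
#     139.0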

def verify_path(path, is_collection) -> None:
    """Verifies that a ``path`` has the correct form.

    Checks that all of the elements in ``path`` are strings.

    Args:
        path (Tuple[str, ...]): The components in a collection or
            document path.
        is_collection (bool): Indicates if the ``path`` represents
            a document or a collection.

    Raises:
        ValueError: if

            * the ``path`` is empty
            * ``is_collection=True`` and there are an even number of elements
            * ``is_collection=False`` and there are an odd number of elements
            * an element is not a string
    """
    num_elements = len(path)
    if num_elements == 0:
        raise ValueError("Document or collection path cannot be empty")

    if is_collection:
        if num_elements % 2 == 0:
            raise ValueError("A collection must have an odd number of path elements")

    else:
        if num_elements % 2 == 1:
            raise ValueError("A document must have an even number of path elements")

    for element in path:
        if not isinstance(element, str):
            msg = BAD_PATH_TEMPLATE.format(element, type(element))
            raise ValueError(msg)
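# A sketch of the odd/even rule above (doctest-style comments): collection
# paths have an odd number of elements, document paths an even number.
#
#     >>> verify_path(("users",), is_collection=True)
#     >>> verify_path(("users", "alice"), is_collection=False)
#     >>> verify_path(("users",), is_collection=False)
#     Traceback (most recent call last):
#       ...
#     ValueError: A document must have an even number of path elements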

def encode_value(value) -> types.document.Value:
    """Converts a native Python value into a Firestore protobuf ``Value``.

    Args:
        value (Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint, \
            ~google.cloud.firestore_v1.vector.Vector]): A native
            Python value to convert to a protobuf field.

    Returns:
        ~google.cloud.firestore_v1.types.Value: A
        value encoded as a Firestore protobuf.

    Raises:
        TypeError: If the ``value`` is not one of the accepted types.
    """
    if value is None:
        return document.Value(null_value=struct_pb2.NULL_VALUE)

    # Must come before int since ``bool`` is an integer subtype.
    if isinstance(value, bool):
        return document.Value(boolean_value=value)

    if isinstance(value, int):
        return document.Value(integer_value=value)

    if isinstance(value, float):
        return document.Value(double_value=value)

    if isinstance(value, DatetimeWithNanoseconds):
        return document.Value(timestamp_value=value.timestamp_pb())

    if isinstance(value, datetime.datetime):
        return document.Value(timestamp_value=_datetime_to_pb_timestamp(value))

    if isinstance(value, str):
        return document.Value(string_value=value)

    if isinstance(value, bytes):
        return document.Value(bytes_value=value)

    # NOTE: We avoid doing an isinstance() check for a Document
    #       here to avoid import cycles.
    document_path = getattr(value, "_document_path", None)
    if document_path is not None:
        return document.Value(reference_value=document_path)

    if isinstance(value, GeoPoint):
        return document.Value(geo_point_value=value.to_protobuf())

    if isinstance(value, (list, tuple, set, frozenset)):
        value_list = tuple(encode_value(element) for element in value)
        value_pb = document.ArrayValue(values=value_list)
        return document.Value(array_value=value_pb)

    if isinstance(value, Vector):
        return encode_value(value.to_map_value())

    if isinstance(value, dict):
        value_dict = encode_dict(value)
        value_pb = document.MapValue(fields=value_dict)
        return document.Value(map_value=value_pb)

    raise TypeError(
        "Cannot convert to a Firestore Value", value, "Invalid type", type(value)
    )
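# A sketch of the type dispatch above (doctest-style comments). Note that
# ``bool`` is tested before ``int``, since ``bool`` is an ``int`` subtype:
#
#     >>> encode_value(3)._pb.WhichOneof("value_type")
#     'integer_value'
#     >>> encode_value(True)._pb.WhichOneof("value_type")
#     'boolean_value'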

def encode_dict(values_dict) -> dict:
    """Encode a dictionary into protobuf ``Value``-s.

    Args:
        values_dict (dict): The dictionary to encode as protobuf fields.

    Returns:
        Dict[str, ~google.cloud.firestore_v1.types.Value]: A
        dictionary of string keys and ``Value`` protobufs as dictionary
        values.
    """
    return {key: encode_value(value) for key, value in values_dict.items()}

def document_snapshot_to_protobuf(
    snapshot: "DocumentSnapshot",
) -> Optional["google.cloud.firestore_v1.types.Document"]:
    """Build a ``Document`` protobuf from a snapshot, returning
    :data:`None` if the snapshot's document does not exist."""
    from google.cloud.firestore_v1.types import Document

    if not snapshot.exists:
        return None

    return Document(
        name=snapshot.reference._document_path,
        fields=encode_dict(snapshot._data),
        create_time=snapshot.create_time,
        update_time=snapshot.update_time,
    )


class DocumentReferenceValue:
    """DocumentReference path container with accessors for each relevant chunk.

    Usage:
        doc_ref_val = DocumentReferenceValue(
            'projects/my-proj/databases/(default)/documents/my-col/my-doc',
        )
        assert doc_ref_val.project_name == 'my-proj'
        assert doc_ref_val.collection_name == 'my-col'
        assert doc_ref_val.document_id == 'my-doc'
        assert doc_ref_val.database_name == '(default)'

    Raises:
        ValueError: If the supplied value cannot satisfy a complete path.
    """

    def __init__(self, reference_value: str):
        self._reference_value = reference_value

        # The first 5 parts are
        # projects, {project}, databases, {database}, documents
        parts = reference_value.split(DOCUMENT_PATH_DELIMITER)
        if len(parts) < 7:
            msg = BAD_REFERENCE_ERROR.format(reference_value)
            raise ValueError(msg)

        self.project_name = parts[1]
        self.collection_name = parts[5]
        self.database_name = parts[3]
        self.document_id = "/".join(parts[6:])

    @property
    def full_key(self) -> str:
        """Computed property for a DocumentReference's collection_name and
        document ID."""
        return "/".join([self.collection_name, self.document_id])

    @property
    def full_path(self) -> str:
        return self._reference_value or "/".join(
            [
                "projects",
                self.project_name,
                "databases",
                self.database_name,
                "documents",
                self.collection_name,
                self.document_id,
            ]
        )

def reference_value_to_document(reference_value, client) -> Any:
    """Convert a reference value string to a document.

    Args:
        reference_value (str): A document reference value.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        :class:`~google.cloud.firestore_v1.document.DocumentReference`:
        The document corresponding to ``reference_value``.

    Raises:
        ValueError: If the ``reference_value`` is not of the expected
            format: ``projects/{project}/databases/{database}/documents/...``.
        ValueError: If the ``reference_value`` does not come from the same
            project / database combination as the ``client``.
    """
    from google.cloud.firestore_v1.base_document import BaseDocumentReference

    doc_ref_value = DocumentReferenceValue(reference_value)

    document: BaseDocumentReference = client.document(doc_ref_value.full_key)
    if document._document_path != reference_value:
        msg = WRONG_APP_REFERENCE.format(reference_value, client._database_string)
        raise ValueError(msg)

    return document


def decode_value(
    value, client
) -> Union[
    None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint, Vector
]:
    """Converts a Firestore protobuf ``Value`` to a native Python value.

    Args:
        value (google.cloud.firestore_v1.types.Value): A
            Firestore protobuf to be decoded / parsed / converted.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint]: A native
        Python value converted from the ``value``.

    Raises:
        ValueError: If the ``value_type`` is unknown.
    """
    value_pb = getattr(value, "_pb", value)
    value_type = value_pb.WhichOneof("value_type")

    if value_type == "null_value":
        return None
    elif value_type == "boolean_value":
        return value_pb.boolean_value
    elif value_type == "integer_value":
        return value_pb.integer_value
    elif value_type == "double_value":
        return value_pb.double_value
    elif value_type == "timestamp_value":
        return DatetimeWithNanoseconds.from_timestamp_pb(value_pb.timestamp_value)
    elif value_type == "string_value":
        return value_pb.string_value
    elif value_type == "bytes_value":
        return value_pb.bytes_value
    elif value_type == "reference_value":
        return reference_value_to_document(value_pb.reference_value, client)
    elif value_type == "geo_point_value":
        return GeoPoint(
            value_pb.geo_point_value.latitude, value_pb.geo_point_value.longitude
        )
    elif value_type == "array_value":
        return [
            decode_value(element, client) for element in value_pb.array_value.values
        ]
    elif value_type == "map_value":
        return decode_dict(value_pb.map_value.fields, client)
    else:
        raise ValueError("Unknown ``value_type``", value_type)
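# A round-trip sketch (doctest-style comments): ``decode_value`` inverts
# ``encode_value`` for plain values; ``client`` is only consulted for
# ``reference_value`` payloads, so ``None`` suffices here:
#
#     >>> decode_value(encode_value({"a": [1, 2.5, None]}), client=None)
#     {'a': [1, 2.5, None]}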

def decode_dict(value_fields, client) -> Union[dict, Vector]:
    """Converts a protobuf map of Firestore ``Value``-s.

    Args:
        value_fields (google.protobuf.pyext._message.MessageMapContainer): A
            protobuf map of Firestore ``Value``-s.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        Dict[str, Union[NoneType, bool, int, float, datetime.datetime, \
            str, bytes, dict, ~google.cloud.Firestore.GeoPoint]]: A dictionary
        of native Python values converted from the ``value_fields``.
    """
    value_fields_pb = getattr(value_fields, "_pb", value_fields)
    res = {key: decode_value(value, client) for key, value in value_fields_pb.items()}

    if res.get("__type__", None) == "__vector__":
        # The vector data type is represented as a mapping:
        # {"__type__": "__vector__", "value": [1.0, 2.0, 3.0]}.
        values = cast(Sequence[float], res["value"])
        return Vector(values)

    return res
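# A sketch of the vector special case above (doctest-style comments): a map
# shaped like {"__type__": "__vector__", "value": [...]} decodes to a Vector:
#
#     >>> fields = encode_value(
#     ...     {"__type__": "__vector__", "value": [1.0, 2.0]}
#     ... ).map_value.fields
#     >>> isinstance(decode_dict(fields, client=None), Vector)
#     True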

def get_doc_id(document_pb, expected_prefix) -> str:
    """Parse a document ID from a document protobuf.

    Args:
        document_pb (google.cloud.firestore_v1.\
            document.Document): A protobuf for a document that
            was created in a ``CreateDocument`` RPC.
        expected_prefix (str): The expected collection prefix for the
            fully-qualified document name.

    Returns:
        str: The document ID from the protobuf.

    Raises:
        ValueError: If the name does not begin with the prefix.
    """
    prefix, document_id = document_pb.name.rsplit(DOCUMENT_PATH_DELIMITER, 1)
    if prefix != expected_prefix:
        raise ValueError(
            "Unexpected document name",
            document_pb.name,
            "Expected to begin with",
            expected_prefix,
        )

    return document_id
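# A sketch using a hypothetical stand-in object (doctest-style comments);
# ``_Doc`` is illustrative only, supplying the single ``name`` attribute
# that ``get_doc_id`` reads:
#
#     >>> class _Doc:
#     ...     name = "projects/p/databases/(default)/documents/users/alice"
#     >>> get_doc_id(_Doc(), "projects/p/databases/(default)/documents/users")
#     'alice'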

_EmptyDict = transforms.Sentinel("Marker for an empty dict value")


def extract_fields(
    document_data, prefix_path: FieldPath, expand_dots=False
) -> Generator[Tuple[Any, Any], Any, None]:
    """Do a depth-first walk of the tree, yielding ``(field_path, value)`` pairs."""
    if not document_data:
        yield prefix_path, _EmptyDict
    else:
        for key, value in sorted(document_data.items()):
            if expand_dots:
                sub_key = FieldPath.from_string(key)
            else:
                sub_key = FieldPath(key)

            field_path = FieldPath(*(prefix_path.parts + sub_key.parts))

            if isinstance(value, dict):
                for s_path, s_value in extract_fields(value, field_path):
                    yield s_path, s_value
            else:
                yield field_path, value
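# A sketch of the depth-first walk (doctest-style comments): nested dicts
# become dotted field paths, and keys are visited in sorted order:
#
#     >>> data = {"c": 2, "a": {"b": 1}}
#     >>> [(p.to_api_repr(), v) for p, v in extract_fields(data, FieldPath())]
#     [('a.b', 1), ('c', 2)]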

def set_field_value(document_data, field_path, value) -> None:
    """Set a value into a document for a ``field_path``, creating
    intermediate dicts as needed."""
    current = document_data
    for element in field_path.parts[:-1]:
        current = current.setdefault(element, {})
    if value is _EmptyDict:
        value = {}
    current[field_path.parts[-1]] = value


def get_field_value(document_data, field_path) -> Any:
    """Get the value for a ``field_path`` from a document; raises
    ``KeyError`` if any element of the path is missing."""
    if not field_path.parts:
        raise ValueError("Empty path")

    current = document_data
    for element in field_path.parts[:-1]:
        current = current[element]
    return current[field_path.parts[-1]]
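# A sketch of the two helpers above (doctest-style comments):
#
#     >>> data = {}
#     >>> set_field_value(data, FieldPath("a", "b"), 1)
#     >>> data
#     {'a': {'b': 1}}
#     >>> get_field_value(data, FieldPath("a", "b"))
#     1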

class DocumentExtractor(object):
    """Break document data up into actual data and transforms.

    Handle special values such as ``DELETE_FIELD``, ``SERVER_TIMESTAMP``.

    Args:
        document_data (dict):
            Property names and values to use for sending a change to
            a document.
    """

    def __init__(self, document_data) -> None:
        self.document_data = document_data
        self.field_paths = []
        self.deleted_fields = []
        self.server_timestamps = []
        self.array_removes = {}
        self.array_unions = {}
        self.increments = {}
        self.minimums = {}
        self.maximums = {}
        self.set_fields: dict = {}
        self.empty_document = False

        prefix_path = FieldPath()
        iterator = self._get_document_iterator(prefix_path)

        for field_path, value in iterator:
            if field_path == prefix_path and value is _EmptyDict:
                self.empty_document = True

            elif value is transforms.DELETE_FIELD:
                self.deleted_fields.append(field_path)

            elif value is transforms.SERVER_TIMESTAMP:
                self.server_timestamps.append(field_path)

            elif isinstance(value, transforms.ArrayRemove):
                self.array_removes[field_path] = value.values

            elif isinstance(value, transforms.ArrayUnion):
                self.array_unions[field_path] = value.values

            elif isinstance(value, transforms.Increment):
                self.increments[field_path] = value.value

            elif isinstance(value, transforms.Maximum):
                self.maximums[field_path] = value.value

            elif isinstance(value, transforms.Minimum):
                self.minimums[field_path] = value.value

            else:
                self.field_paths.append(field_path)
                set_field_value(self.set_fields, field_path, value)

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path)

    @property
    def has_transforms(self):
        return bool(
            self.server_timestamps
            or self.array_removes
            or self.array_unions
            or self.increments
            or self.maximums
            or self.minimums
        )

    @property
    def transform_paths(self):
        return sorted(
            self.server_timestamps
            + list(self.array_removes)
            + list(self.array_unions)
            + list(self.increments)
            + list(self.maximums)
            + list(self.minimums)
        )

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        return None

    def get_update_pb(
        self, document_path, exists=None, allow_empty_mask=False
    ) -> types.write.Write:
        if exists is not None:
            current_document = common.Precondition(exists=exists)
        else:
            current_document = None

        update_pb = write.Write(
            update=document.Document(
                name=document_path, fields=encode_dict(self.set_fields)
            ),
            update_mask=self._get_update_mask(allow_empty_mask),
            current_document=current_document,
        )

        return update_pb

    def get_field_transform_pbs(
        self, document_path
    ) -> List[types.write.DocumentTransform.FieldTransform]:
        def make_array_value(values):
            value_list = [encode_value(element) for element in values]
            return document.ArrayValue(values=value_list)

        path_field_transforms = (
            [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        set_to_server_value=REQUEST_TIME_ENUM,
                    ),
                )
                for path in self.server_timestamps
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        remove_all_from_array=make_array_value(values),
                    ),
                )
                for path, values in self.array_removes.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(),
                        append_missing_elements=make_array_value(values),
                    ),
                )
                for path, values in self.array_unions.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), increment=encode_value(value)
                    ),
                )
                for path, value in self.increments.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), maximum=encode_value(value)
                    ),
                )
                for path, value in self.maximums.items()
            ]
            + [
                (
                    path,
                    write.DocumentTransform.FieldTransform(
                        field_path=path.to_api_repr(), minimum=encode_value(value)
                    ),
                )
                for path, value in self.minimums.items()
            ]
        )
        return [transform for path, transform in sorted(path_field_transforms)]

    def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
        field_transforms = self.get_field_transform_pbs(document_path)
        transform_pb = write.Write(
            transform=write.DocumentTransform(
                document=document_path, field_transforms=field_transforms
            )
        )
        if exists is not None:
            transform_pb._pb.current_document.CopyFrom(
                common.Precondition(exists=exists)._pb
            )

        return transform_pb
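# A sketch of how the extractor splits data from transforms (doctest-style
# comments): plain values land in ``set_fields``, while sentinels are routed
# to their own buckets:
#
#     >>> ex = DocumentExtractor({"a": 1, "b": transforms.SERVER_TIMESTAMP})
#     >>> ex.set_fields
#     {'a': 1}
#     >>> [p.to_api_repr() for p in ex.server_timestamps]
#     ['b']
#     >>> ex.has_transforms
#     True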

def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``create()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            creating a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing the
        single ``Write`` protobuf instance for ``create()``.
    """
    extractor = DocumentExtractor(document_data)

    if extractor.deleted_fields:
        raise ValueError("Cannot apply DELETE_FIELD in a create request.")

    create_pb = extractor.get_update_pb(document_path, exists=False)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        create_pb.update_transforms.extend(field_transform_pbs)

    return [create_pb]
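# A sketch of the precondition attached above (doctest-style comments): a
# create is an update that asserts the document does not exist yet:
#
#     >>> create_pb, = pbs_for_create(
#     ...     "projects/p/databases/(default)/documents/users/alice", {"a": 1}
#     ... )
#     >>> create_pb.current_document.exists
#     False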

def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing the
        single ``Write`` protobuf instance for ``set()``.
    """
    extractor = DocumentExtractor(document_data)

    if extractor.deleted_fields:
        raise ValueError(
            "Cannot apply DELETE_FIELD in a set request without "
            "specifying 'merge=True' or 'merge=[field_paths]'."
        )

    set_pb = extractor.get_update_pb(document_path)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        set_pb.update_transforms.extend(field_transform_pbs)

    return [set_pb]

class DocumentExtractorForMerge(DocumentExtractor):
    """Break document data up into actual data and transforms."""

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForMerge, self).__init__(document_data)
        self.data_merge: list = []
        self.transform_merge: list = []
        self.merge: list = []

    def _apply_merge_all(self) -> None:
        self.data_merge = sorted(self.field_paths + self.deleted_fields)
        # TODO: other transforms
        self.transform_merge = self.transform_paths
        self.merge = sorted(self.data_merge + self.transform_paths)

    def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]:
        for merge_field in merge:
            if isinstance(merge_field, FieldPath):
                yield merge_field
            else:
                yield FieldPath(*parse_field_path(merge_field))

    def _normalize_merge_paths(self, merge) -> list:
        merge_paths = sorted(self._construct_merge_paths(merge))

        # Raise if any merge path is a parent of another.  Leverage sorting
        # to avoid quadratic behavior.
        for index in range(len(merge_paths) - 1):
            lhs, rhs = merge_paths[index], merge_paths[index + 1]
            if lhs.eq_or_parent(rhs):
                raise ValueError("Merge paths overlap: {}, {}".format(lhs, rhs))

        for merge_path in merge_paths:
            if merge_path in self.deleted_fields:
                continue
            try:
                get_field_value(self.document_data, merge_path)
            except KeyError:
                raise ValueError("Invalid merge path: {}".format(merge_path))

        return merge_paths

    def _apply_merge_paths(self, merge) -> None:
        if self.empty_document:
            raise ValueError("Cannot merge specific fields with empty document.")

        merge_paths = self._normalize_merge_paths(merge)

        del self.data_merge[:]
        del self.transform_merge[:]
        self.merge = merge_paths

        for merge_path in merge_paths:
            if merge_path in self.transform_paths:
                self.transform_merge.append(merge_path)

            for field_path in self.field_paths:
                if merge_path.eq_or_parent(field_path):
                    self.data_merge.append(field_path)

        # Clear out data for fields not merged.
        merged_set_fields: dict = {}
        for field_path in self.data_merge:
            value = get_field_value(self.document_data, field_path)
            set_field_value(merged_set_fields, field_path, value)
        self.set_fields = merged_set_fields

        unmerged_deleted_fields = [
            field_path
            for field_path in self.deleted_fields
            if field_path not in self.merge
        ]
        if unmerged_deleted_fields:
            raise ValueError(
                "Cannot delete unmerged fields: {}".format(unmerged_deleted_fields)
            )
        self.data_merge = sorted(self.data_merge + self.deleted_fields)

        # Keep only transforms which are within the merge.
        merged_transform_paths = set()
        for merge_path in self.merge:
            transform_merge_paths = [
                transform_path
                for transform_path in self.transform_paths
                if merge_path.eq_or_parent(transform_path)
            ]
            merged_transform_paths.update(transform_merge_paths)

        self.server_timestamps = [
            path for path in self.server_timestamps if path in merged_transform_paths
        ]

        self.array_removes = {
            path: values
            for path, values in self.array_removes.items()
            if path in merged_transform_paths
        }

        self.array_unions = {
            path: values
            for path, values in self.array_unions.items()
            if path in merged_transform_paths
        }

    def apply_merge(self, merge) -> None:
        if merge is True:  # merge all fields
            self._apply_merge_all()
        else:
            self._apply_merge_paths(merge)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        # Mask uses dotted / quoted paths.
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.merge
            if field_path not in self.transform_merge
        ]

        return common.DocumentMask(field_paths=mask_paths)

def pbs_for_set_with_merge(
    document_path, document_data, merge
) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``set()`` methods with merge semantics.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.
        merge (Union[bool, List[str], List[FieldPath]]):
            If True, merge all fields; else, merge only the named field paths.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing the
        single ``Write`` protobuf instance for ``set()``.
    """
    extractor = DocumentExtractorForMerge(document_data)
    extractor.apply_merge(merge)

    set_pb = extractor.get_update_pb(document_path)

    if extractor.transform_paths:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        set_pb.update_transforms.extend(field_transform_pbs)

    return [set_pb]
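# A sketch of merge semantics (doctest-style comments): with an explicit
# merge list, only the named fields end up in the update mask and payload:
#
#     >>> set_pb, = pbs_for_set_with_merge(
#     ...     "projects/p/databases/(default)/documents/users/alice",
#     ...     {"a": 1, "b": 2},
#     ...     merge=["a"],
#     ... )
#     >>> list(set_pb.update_mask.field_paths)
#     ['a']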

class DocumentExtractorForUpdate(DocumentExtractor):
    """Break document data up into actual data and transforms."""

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForUpdate, self).__init__(document_data)
        self.top_level_paths = sorted(
            [FieldPath.from_string(key) for key in document_data]
        )
        tops = set(self.top_level_paths)
        for top_level_path in self.top_level_paths:
            for ancestor in top_level_path.lineage():
                if ancestor in tops:
                    raise ValueError(
                        "Conflicting field path: {}, {}".format(
                            top_level_path, ancestor
                        )
                    )

        for field_path in self.deleted_fields:
            if field_path not in tops:
                raise ValueError(
                    "Cannot update with nested delete: {}".format(field_path)
                )

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        return extract_fields(self.document_data, prefix_path, expand_dots=True)

    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
        mask_paths = []
        for field_path in self.top_level_paths:
            if field_path not in self.transform_paths:
                mask_paths.append(field_path.to_api_repr())

        return common.DocumentMask(field_paths=mask_paths)

def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]:
    """Make ``Write`` protobufs for ``update()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        field_updates (dict): Field names or paths to update and values
            to update with.
        option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A list containing the
        single ``Write`` protobuf instance for ``update()``.
    """
    extractor = DocumentExtractorForUpdate(field_updates)

    if extractor.empty_document:
        raise ValueError("Cannot update with an empty document.")

    if option is None:  # Default is to use ``exists=True``.
        option = ExistsOption(exists=True)

    update_pb = extractor.get_update_pb(document_path)
    option.modify_write(update_pb)

    if extractor.has_transforms:
        field_transform_pbs = extractor.get_field_transform_pbs(document_path)
        update_pb.update_transforms.extend(field_transform_pbs)

    return [update_pb]
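# A sketch of the defaults above (doctest-style comments): dotted keys are
# expanded into nested paths, and ``option=None`` implies ``exists=True``:
#
#     >>> update_pb, = pbs_for_update(
#     ...     "projects/p/databases/(default)/documents/users/alice",
#     ...     {"a.b": 1},
#     ...     None,
#     ... )
#     >>> list(update_pb.update_mask.field_paths)
#     ['a.b']
#     >>> update_pb.current_document.exists
#     True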

def pb_for_delete(document_path, option) -> types.write.Write:
    """Make a ``Write`` protobuf for ``delete()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        option (Optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        google.cloud.firestore_v1.types.Write: A
        ``Write`` protobuf instance for the ``delete()``.
    """
    write_pb = write.Write(delete=document_path)
    if option is not None:
        option.modify_write(write_pb)

    return write_pb

class ReadAfterWriteError(Exception):
    """Raised when a read is attempted after a write.

    Raised by "read" methods that use transactions.
    """


def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
    """Get the transaction ID from a ``Transaction`` object.

    Args:
        transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.\
            Transaction`]):
            An existing transaction that this query will run in.
        read_operation (Optional[bool]): Indicates if the transaction ID
            will be used in a read operation. Defaults to :data:`True`.

    Returns:
        Optional[bytes]: The ID of the transaction, or :data:`None` if the
        ``transaction`` is :data:`None`.

    Raises:
        ValueError: If the ``transaction`` is not in progress (only if
            ``transaction`` is not :data:`None`).
        ReadAfterWriteError: If the ``transaction`` has writes stored on
            it and ``read_operation`` is :data:`True`.
    """
    if transaction is None:
        return None
    else:
        if not transaction.in_progress:
            raise ValueError(INACTIVE_TXN)
        if read_operation and len(transaction._write_pbs) > 0:
            raise ReadAfterWriteError(READ_AFTER_WRITE_ERROR)
        return transaction.id

def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
    """Create RPC metadata containing a prefix.

    Args:
        prefix (str): Appropriate resource path.

    Returns:
        List[Tuple[str, str]]: RPC metadata with the supplied prefix.
    """
    return [("google-cloud-resource-prefix", prefix)]
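# A minimal sketch (doctest-style comments):
#
#     >>> metadata_with_prefix("projects/p/databases/(default)")
#     [('google-cloud-resource-prefix', 'projects/p/databases/(default)')]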

class WriteOption(object):
    """Option used to assert a condition on a write operation."""

    def modify_write(self, write, no_create_msg=None) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        This is a virtual method intended to be implemented by subclasses.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            no_create_msg (Optional[str]): A message to use to indicate that
                a create operation is not allowed.

        Raises:
            NotImplementedError: Always, this method is virtual.
        """
        raise NotImplementedError


class LastUpdateOption(WriteOption):
    """Option used to assert a "last update" condition on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        last_update_time (google.protobuf.timestamp_pb2.Timestamp): A
            timestamp. When set, the target document must exist and have
            been last updated at that time. Protobuf ``update_time`` timestamps
            are typically returned from methods that perform write operations
            as part of a "write result" protobuf or directly.
    """

    def __init__(self, last_update_time) -> None:
        self._last_update_time = last_update_time

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._last_update_time == other._last_update_time

    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        The ``last_update_time`` is added to ``write`` as an "update time"
        precondition. When set, the target document must exist and have been
        last updated at that time.

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(update_time=self._last_update_time)
        write._pb.current_document.CopyFrom(current_doc._pb)


class ExistsOption(WriteOption):
    """Option used to assert existence on a write operation.

    This will typically be created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        exists (bool): Indicates if the document being modified
            should already exist.
    """

    def __init__(self, exists) -> None:
        self._exists = exists

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return NotImplemented
        return self._exists == other._exists

    def modify_write(self, write, *unused_args, **unused_kwargs) -> None:
        """Modify a ``Write`` protobuf based on the state of this write option.

        If:

        * ``exists=True``, adds a precondition that requires existence
        * ``exists=False``, adds a precondition that requires non-existence

        Args:
            write (google.cloud.firestore_v1.types.Write): A
                ``Write`` protobuf instance to be modified with a precondition
                determined by the state of this option.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        current_doc = types.Precondition(exists=self._exists)
        write._pb.current_document.CopyFrom(current_doc._pb)

def make_retry_timeout_kwargs(
    retry: retries.Retry | retries.AsyncRetry | object | None, timeout: float | None
) -> dict:
    """Helper for API methods which take optional 'retry' / 'timeout' args."""
    kwargs = {}

    if retry is not gapic_v1.method.DEFAULT:
        kwargs["retry"] = retry

    if timeout is not None:
        kwargs["timeout"] = timeout

    return kwargs
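# A sketch of the filtering above (doctest-style comments): the sentinel
# ``gapic_v1.method.DEFAULT`` and a ``None`` timeout are both omitted:
#
#     >>> make_retry_timeout_kwargs(gapic_v1.method.DEFAULT, None)
#     {}
#     >>> make_retry_timeout_kwargs(gapic_v1.method.DEFAULT, 5.0)
#     {'timeout': 5.0}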

def build_timestamp(
    dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None
) -> Timestamp:
    """Returns the supplied datetime (or "now") as a Timestamp."""
    return _datetime_to_pb_timestamp(
        dt or DatetimeWithNanoseconds.now(tz=datetime.timezone.utc)
    )

def compare_timestamps(
    ts1: Union[Timestamp, datetime.datetime],
    ts2: Union[Timestamp, datetime.datetime],
) -> int:
    """Compare two timestamps, returning -1, 0, or 1 (``cmp``-style)."""
    ts1 = build_timestamp(ts1) if not isinstance(ts1, Timestamp) else ts1
    ts2 = build_timestamp(ts2) if not isinstance(ts2, Timestamp) else ts2
    # Use integer arithmetic here: multiplying seconds by a float ``1e9``
    # can lose nanosecond precision once the product exceeds 2**53.
    ts1_nanos = ts1.nanos + ts1.seconds * 10**9
    ts2_nanos = ts2.nanos + ts2.seconds * 10**9
    if ts1_nanos == ts2_nanos:
        return 0
    return 1 if ts1_nanos > ts2_nanos else -1
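# A minimal sketch (doctest-style comments):
#
#     >>> compare_timestamps(Timestamp(seconds=1), Timestamp(seconds=2))
#     -1
#     >>> compare_timestamps(Timestamp(seconds=1), Timestamp(seconds=1))
#     0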

def deserialize_bundle(
    serialized: Union[str, bytes],
    client: "google.cloud.firestore_v1.client.BaseClient",
) -> "google.cloud.firestore_bundle.FirestoreBundle":
    """Inverse operation to a `FirestoreBundle` instance's `build()` method.

    Args:
        serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`.
            Should be a list of dictionaries in string format.
        client (BaseClient): A connected Client instance.

    Returns:
        FirestoreBundle: A bundle equivalent to that which called `build()` and
        initially created the `serialized` value.

    Raises:
        ValueError: If any of the dictionaries in the list contain any more than
            one top-level key.
        ValueError: If any unexpected BundleElement types are encountered.
        ValueError: If the serialized bundle ends before expected.
    """
    from google.cloud.firestore_bundle import BundleElement, FirestoreBundle

    # Outlines the legal transitions from one BundleElement to another.
    bundle_state_machine = {
        "__initial__": ["metadata"],
        "metadata": ["namedQuery", "documentMetadata", "__end__"],
        "namedQuery": ["namedQuery", "documentMetadata", "__end__"],
        "documentMetadata": ["document"],
        "document": ["documentMetadata", "__end__"],
    }
    allowed_next_element_types: List[str] = bundle_state_machine["__initial__"]

    # This must be saved and added last, since we cache it to preserve timestamps,
    # yet must flush it whenever a new document or query is added to a bundle.
    # The process of deserializing a bundle uses these methods which flush a
    # cached metadata element, and thus, it must be the last BundleElement
    # added during deserialization.
    metadata_bundle_element: Optional[BundleElement] = None

    bundle: Optional[FirestoreBundle] = None
    data: Dict
    for data in _parse_bundle_elements_data(serialized):
        # BundleElements are serialized as JSON containing one key outlining
        # the type, with all further data nested under that key.
        keys: List[str] = list(data.keys())

        if len(keys) != 1:
            raise ValueError("Expected serialized BundleElement with one top-level key")

        key: str = keys[0]

        if key not in allowed_next_element_types:
            raise ValueError(
                f"Encountered BundleElement of type {key}. "
                f"Expected one of {allowed_next_element_types}"
            )

        # Create and add our BundleElement
        bundle_element: BundleElement
        try:
            bundle_element = BundleElement.from_json(json.dumps(data))
        except AttributeError as e:
            # Some bad serialization formats cannot be universally deserialized.
            if e.args[0] == "'dict' object has no attribute 'find'":  # pragma: NO COVER
                raise ValueError(
                    "Invalid serialization of datetimes. "
                    "Cannot deserialize Bundles created from the NodeJS SDK."
                )
            raise e  # pragma: NO COVER

        if bundle is None:
            # This must be the first bundle type encountered
            assert key == "metadata"
            bundle = FirestoreBundle(data[key]["id"])
            metadata_bundle_element = bundle_element

        else:
            bundle._add_bundle_element(bundle_element, client=client, type=key)

        # Update the allowed next BundleElement types
        allowed_next_element_types = bundle_state_machine[key]

    if "__end__" not in allowed_next_element_types:
        raise ValueError("Unexpected end to serialized FirestoreBundle")
    # The state machine guarantees that ``bundle`` and
    # ``metadata_bundle_element`` have been populated by this point.
    bundle = cast(FirestoreBundle, bundle)
    metadata_bundle_element = cast(BundleElement, metadata_bundle_element)
    # Now, finally add the metadata element.
    bundle._add_bundle_element(
        metadata_bundle_element,
        client=client,
        type="metadata",
    )

    return bundle

def _parse_bundle_elements_data(
    serialized: Union[str, bytes]
) -> Generator[Dict, None, None]:
    """Reads through a serialized FirestoreBundle and yields JSON chunks that
    were created via `BundleElement.to_json(bundle_element)`.

    Serialized FirestoreBundle instances are length-prefixed JSON objects, and
    so are of the form "123{...}57{...}".
    To correctly and safely read a bundle, we must first detect these length
    prefixes, read that many bytes of data, and attempt to JSON-parse that.

    Raises:
        ValueError: If a chunk of JSON ever starts without following a length
            prefix.
    """
    _serialized: Iterator[int] = iter(
        serialized if isinstance(serialized, bytes) else serialized.encode("utf-8")
    )

    length_prefix: str = ""
    while True:
        byte: Optional[int] = next(_serialized, None)

        if byte is None:
            return None

        _str: str = chr(byte)
        if _str.isnumeric():
            length_prefix += _str
        else:
            if length_prefix == "":
                raise ValueError("Expected length prefix")

            _length_prefix = int(length_prefix)
            length_prefix = ""
            _bytes = bytearray([byte])
            _counter = 1
            while _counter < _length_prefix:
                _bytes.append(next(_serialized))
                _counter += 1

            yield json.loads(_bytes.decode("utf-8"))
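# A sketch of the length-prefixed format (doctest-style comments): each JSON
# chunk is preceded by its byte length, here 14 bytes for the chunk shown:
#
#     >>> list(_parse_bundle_elements_data('14{"a": "12345"}'))
#     [{'a': '12345'}]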

def _get_documents_from_bundle(
    bundle, *, query_name: Optional[str] = None
) -> Generator["DocumentSnapshot", None, None]:
    from google.cloud.firestore_bundle.bundle import _BundledDocument

    bundled_doc: _BundledDocument
    for bundled_doc in bundle.documents.values():
        if query_name and query_name not in bundled_doc.metadata.queries:
            continue
        yield bundled_doc.snapshot


def _get_document_from_bundle(
    bundle,
    *,
    document_id: str,
) -> Optional["DocumentSnapshot"]:
    bundled_doc = bundle.documents.get(document_id)
    if bundled_doc:
        return bundled_doc.snapshot
    else:
        return None