Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/google/cloud/firestore_v1/_helpers.py: 23%

479 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-09 06:27 +0000

1# Copyright 2017 Google LLC All rights reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15"""Common helpers shared across Google Cloud Firestore modules.""" 

16 

17import datetime 

18import json 

19 

20import google 

21from google.api_core.datetime_helpers import DatetimeWithNanoseconds 

22from google.api_core import gapic_v1 

23from google.protobuf import struct_pb2 

24from google.type import latlng_pb2 # type: ignore 

25import grpc # type: ignore 

26 

27from google.cloud import exceptions # type: ignore 

28from google.cloud._helpers import _datetime_to_pb_timestamp # type: ignore 

29from google.cloud.firestore_v1.types.write import DocumentTransform 

30from google.cloud.firestore_v1 import transforms 

31from google.cloud.firestore_v1 import types 

32from google.cloud.firestore_v1.field_path import FieldPath 

33from google.cloud.firestore_v1.field_path import parse_field_path 

34from google.cloud.firestore_v1.types import common 

35from google.cloud.firestore_v1.types import document 

36from google.cloud.firestore_v1.types import write 

37from google.protobuf.timestamp_pb2 import Timestamp # type: ignore 

38from typing import ( 

39 Any, 

40 Dict, 

41 Generator, 

42 Iterator, 

43 List, 

44 NoReturn, 

45 Optional, 

46 Tuple, 

47 Union, 

48) 

49 

# Forward declarations of module-level names that are assigned further down.
_EmptyDict: transforms.Sentinel
_GRPC_ERROR_MAPPING: dict


# Error template used by ``verify_path`` when a path element is not a string.
BAD_PATH_TEMPLATE = "A path element must be a string. Received {}, which is a {}."
# Separator between the segments of a fully-qualified document path.
DOCUMENT_PATH_DELIMITER = "/"
# Messages raised by ``get_transaction_id``.
INACTIVE_TXN = "Transaction not in progress, cannot be used in API requests."
READ_AFTER_WRITE_ERROR = "Attempted read after write in a transaction."
# Raised by ``DocumentReferenceValue`` when a reference has too few segments.
BAD_REFERENCE_ERROR = (
    "Reference value {!r} in unexpected format, expected to be of the form "
    "``projects/{{project}}/databases/{{database}}/"
    "documents/{{document_path}}``."
)
# Raised by ``reference_value_to_document`` when the reference does not
# round-trip through the supplied client.
WRONG_APP_REFERENCE = (
    "Document {!r} does not correspond to the same database " "({!r}) as the client."
)
# Server value used for ``SERVER_TIMESTAMP`` field transforms.
REQUEST_TIME_ENUM = DocumentTransform.FieldTransform.ServerValue.REQUEST_TIME
# Maps gRPC status codes to ``google.cloud.exceptions`` classes.
# NOTE(review): the consumer of this mapping is outside the visible chunk.
_GRPC_ERROR_MAPPING = {
    grpc.StatusCode.ALREADY_EXISTS: exceptions.Conflict,
    grpc.StatusCode.NOT_FOUND: exceptions.NotFound,
}

71 

72 

class GeoPoint(object):
    """Plain holder for a latitude / longitude pair.

    Args:
        latitude (float): Latitude of a point.
        longitude (float): Longitude of a point.
    """

    def __init__(self, latitude, longitude) -> None:
        self.latitude, self.longitude = latitude, longitude

    def to_protobuf(self) -> latlng_pb2.LatLng:
        """Build the protobuf counterpart of this point.

        Returns:
            google.type.latlng_pb2.LatLng: The current point as a protobuf.
        """
        coordinates = {"latitude": self.latitude, "longitude": self.longitude}
        return latlng_pb2.LatLng(**coordinates)

    def __eq__(self, other):
        """Check whether ``other`` is a geo point with the same coordinates.

        Returns:
            Union[bool, NotImplemented]: Result of comparing both coordinates,
            or :data:`NotImplemented` if ``other`` is not a geo point.
        """
        if isinstance(other, GeoPoint):
            same_latitude = self.latitude == other.latitude
            return same_latitude and self.longitude == other.longitude

        return NotImplemented

    def __ne__(self, other):
        """Negation of :meth:`__eq__`.

        Returns:
            Union[bool, NotImplemented]: The inverted equality result, or
            :data:`NotImplemented` if ``other`` is not a geo point.
        """
        outcome = self.__eq__(other)
        return outcome if outcome is NotImplemented else not outcome

119 

120 

def verify_path(path, is_collection) -> None:
    """Validate the shape of a collection or document path.

    Args:
        path (Tuple[str, ...]): The components in a collection or
            document path.
        is_collection (bool): :data:`True` if ``path`` names a collection,
            :data:`False` if it names a document.

    Raises:
        ValueError: if

            * the ``path`` is empty
            * ``is_collection=True`` and there are an even number of elements
            * ``is_collection=False`` and there are an odd number of elements
            * an element is not a string
    """
    count = len(path)
    if not count:
        raise ValueError("Document or collection path cannot be empty")

    # Collections sit at odd depths, documents at even depths.
    odd_length = count % 2 == 1
    if is_collection and not odd_length:
        raise ValueError("A collection must have an odd number of path elements")
    if not is_collection and odd_length:
        raise ValueError("A document must have an even number of path elements")

    for element in path:
        if isinstance(element, str):
            continue
        raise ValueError(BAD_PATH_TEMPLATE.format(element, type(element)))

156 

157 

def encode_value(value) -> types.document.Value:
    """Build the Firestore protobuf ``Value`` for a native Python value.

    Args:
        value: ``None``, a bool, int, float, ``datetime.datetime``, str,
            bytes, an object exposing ``_document_path``, a ``GeoPoint``,
            a list / tuple / set / frozenset, or a dict.

    Returns:
        ~google.cloud.firestore_v1.types.Value: A
        value encoded as a Firestore protobuf.

    Raises:
        TypeError: If the ``value`` is not one of the accepted types.
    """
    if value is None:
        return document.Value(null_value=struct_pb2.NULL_VALUE)

    # ``bool`` is a subclass of ``int``, so it has to be tested first.
    numeric_fields = (
        (bool, "boolean_value"),
        (int, "integer_value"),
        (float, "double_value"),
    )
    for python_type, field_name in numeric_fields:
        if isinstance(value, python_type):
            return document.Value(**{field_name: value})

    # ``DatetimeWithNanoseconds`` is checked ahead of plain ``datetime`` so
    # that its own protobuf conversion is used.
    if isinstance(value, DatetimeWithNanoseconds):
        return document.Value(timestamp_value=value.timestamp_pb())
    if isinstance(value, datetime.datetime):
        return document.Value(timestamp_value=_datetime_to_pb_timestamp(value))

    for python_type, field_name in ((str, "string_value"), (bytes, "bytes_value")):
        if isinstance(value, python_type):
            return document.Value(**{field_name: value})

    # NOTE: Document references are detected by attribute instead of by
    #       isinstance() to avoid import cycles.
    reference_path = getattr(value, "_document_path", None)
    if reference_path is not None:
        return document.Value(reference_value=reference_path)

    if isinstance(value, GeoPoint):
        return document.Value(geo_point_value=value.to_protobuf())

    if isinstance(value, (list, tuple, set, frozenset)):
        encoded_elements = tuple(map(encode_value, value))
        return document.Value(
            array_value=document.ArrayValue(values=encoded_elements)
        )

    if isinstance(value, dict):
        return document.Value(
            map_value=document.MapValue(fields=encode_dict(value))
        )

    raise TypeError(
        "Cannot convert to a Firestore Value", value, "Invalid type", type(value)
    )

220 

221 

def encode_dict(values_dict) -> dict:
    """Encode every value of a dictionary as a protobuf ``Value``.

    Args:
        values_dict (dict): The dictionary to encode as protobuf fields.

    Returns:
        Dict[str, ~google.cloud.firestore_v1.types.Value]: A new dictionary
        with the same keys, each mapped to the result of ``encode_value``.
    """
    encoded = {}
    for field_name, raw_value in values_dict.items():
        encoded[field_name] = encode_value(raw_value)
    return encoded

234 

235 

def document_snapshot_to_protobuf(snapshot: "google.cloud.firestore_v1.base_document.DocumentSnapshot") -> Optional["google.cloud.firestore_v1.types.Document"]:  # type: ignore
    """Convert a document snapshot into a ``Document`` protobuf.

    Args:
        snapshot: The snapshot to convert.

    Returns:
        The ``Document`` built from the snapshot's reference path, encoded
        data and create / update times, or :data:`None` if
        ``snapshot.exists`` is falsy.
    """
    from google.cloud.firestore_v1.types import Document

    if snapshot.exists:
        return Document(
            name=snapshot.reference._document_path,
            fields=encode_dict(snapshot._data),
            create_time=snapshot.create_time,
            update_time=snapshot.update_time,
        )

    return None

248 

249 

class DocumentReferenceValue:
    """Parsed view of a fully-qualified document reference string.

    Usage:
        doc_ref_val = DocumentReferenceValue(
            'projects/my-proj/databases/(default)/documents/my-col/my-doc',
        )
        assert doc_ref_val.project_name == 'my-proj'
        assert doc_ref_val.collection_name == 'my-col'
        assert doc_ref_val.document_id == 'my-doc'
        assert doc_ref_val.database_name == '(default)'

    Raises:
        ValueError: If the supplied value cannot satisfy a complete path.
    """

    def __init__(self, reference_value: str):
        self._reference_value = reference_value

        segments = reference_value.split(DOCUMENT_PATH_DELIMITER)
        # Five fixed segments (projects, {project}, databases, {database},
        # documents) plus at least a collection and a document id.
        if len(segments) < 7:
            raise ValueError(BAD_REFERENCE_ERROR.format(reference_value))

        (
            _,
            self.project_name,
            _,
            self.database_name,
            _,
            self.collection_name,
            *document_segments,
        ) = segments
        # Everything after the first collection belongs to the document id,
        # including any nested sub-collection segments.
        self.document_id = "/".join(document_segments)

    @property
    def full_key(self) -> str:
        """Collection name and document id joined by ``/``."""
        return "/".join((self.collection_name, self.document_id))

    @property
    def full_path(self) -> str:
        """The original reference string, rebuilt from its parts if falsy."""
        if self._reference_value:
            return self._reference_value

        return "/".join(
            (
                "projects",
                self.project_name,
                "databases",
                self.database_name,
                "documents",
                self.collection_name,
                self.document_id,
            )
        )

300 

301 

def reference_value_to_document(reference_value, client) -> Any:
    """Resolve a reference value string into a document reference.

    Args:
        reference_value (str): A document reference value.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        :class:`~google.cloud.firestore_v1.document.DocumentReference`:
        The document corresponding to ``reference_value``.

    Raises:
        ValueError: If the ``reference_value`` is not of the expected
            format: ``projects/{project}/databases/{database}/documents/...``.
        ValueError: If the document built by ``client`` does not have
            ``reference_value`` as its path, i.e. it belongs to another
            project / database combination.
    """
    from google.cloud.firestore_v1.base_document import BaseDocumentReference

    parsed = DocumentReferenceValue(reference_value)

    doc_ref: BaseDocumentReference = client.document(parsed.full_key)
    # The client prepends its own project / database; a mismatch means the
    # reference points somewhere else.
    if doc_ref._document_path == reference_value:
        return doc_ref

    raise ValueError(
        WRONG_APP_REFERENCE.format(reference_value, client._database_string)
    )

330 

331 

def decode_value(
    value, client
) -> Union[None, bool, int, float, list, datetime.datetime, str, bytes, dict, GeoPoint]:
    """Turn a Firestore protobuf ``Value`` into its native Python counterpart.

    Args:
        value (google.cloud.firestore_v1.types.Value): A
            Firestore protobuf to be decoded / parsed / converted.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        The decoded value: ``None``, a bool, int, float, timestamp, str,
        bytes, list, dict, ``GeoPoint``, or (for ``reference_value``) the
        result of ``reference_value_to_document``.

    Raises:
        ValueError: If the ``value_type`` is not one of the handled kinds.
    """
    # Unwrap to the raw protobuf when the value carries one as ``_pb``.
    pb = getattr(value, "_pb", value)
    kind = pb.WhichOneof("value_type")

    if kind == "null_value":
        return None

    # These kinds map directly onto the protobuf attribute of the same name.
    passthrough_kinds = (
        "boolean_value",
        "integer_value",
        "double_value",
        "string_value",
        "bytes_value",
    )
    if kind in passthrough_kinds:
        return getattr(pb, kind)

    if kind == "timestamp_value":
        return DatetimeWithNanoseconds.from_timestamp_pb(pb.timestamp_value)

    if kind == "reference_value":
        return reference_value_to_document(pb.reference_value, client)

    if kind == "geo_point_value":
        point_pb = pb.geo_point_value
        return GeoPoint(point_pb.latitude, point_pb.longitude)

    if kind == "array_value":
        return [decode_value(item, client) for item in pb.array_value.values]

    if kind == "map_value":
        return decode_dict(pb.map_value.fields, client)

    raise ValueError("Unknown ``value_type``", kind)

383 

384 

def decode_dict(value_fields, client) -> dict:
    """Decode a protobuf map of Firestore ``Value``-s into a plain dict.

    Args:
        value_fields (google.protobuf.pyext._message.MessageMapContainer): A
            protobuf map of Firestore ``Value``-s.
        client (:class:`~google.cloud.firestore_v1.client.Client`):
            A client that has a document factory.

    Returns:
        dict: The same keys, each mapped to the result of ``decode_value``.
    """
    # Unwrap to the raw protobuf map when the container carries one as ``_pb``.
    fields_pb = getattr(value_fields, "_pb", value_fields)

    decoded = {}
    for field_name, field_value in fields_pb.items():
        decoded[field_name] = decode_value(field_value, client)
    return decoded

402 

403 

def get_doc_id(document_pb, expected_prefix) -> str:
    """Extract the trailing document ID from a document protobuf's name.

    Args:
        document_pb: A document protobuf; only its ``name`` is read.
        expected_prefix (str): The expected collection prefix for the
            fully-qualified document name.

    Returns:
        str: The document ID from the protobuf.

    Raises:
        ValueError: If everything before the last path delimiter in the name
            differs from ``expected_prefix``.
    """
    full_name = document_pb.name
    # Split once from the right: collection path on the left, ID on the right.
    prefix, document_id = full_name.rsplit(DOCUMENT_PATH_DELIMITER, 1)
    if prefix == expected_prefix:
        return document_id

    raise ValueError(
        "Unexpected document name",
        full_name,
        "Expected to begin with",
        expected_prefix,
    )

430 

431 

# Sentinel yielded by ``extract_fields`` in place of an empty mapping;
# ``set_field_value`` converts it back into a real ``{}``.
_EmptyDict = transforms.Sentinel("Marker for an empty dict value")

433 

434 

def extract_fields(
    document_data, prefix_path: FieldPath, expand_dots=False
) -> Generator[Tuple[Any, Any], Any, None]:
    """Walk ``document_data`` depth-first, yielding ``(field_path, value)``.

    Args:
        document_data (dict): The (possibly nested) mapping to walk.
        prefix_path (FieldPath): Path prepended to every yielded field path.
        expand_dots (bool): If truthy, keys at this level are parsed with
            ``FieldPath.from_string``; otherwise each key is a single path
            element. Nested levels always use the single-element form.

    Yields:
        Tuple[FieldPath, Any]: One pair per leaf value, in sorted key order.
        A falsy ``document_data`` yields ``(prefix_path, _EmptyDict)``.
    """
    if not document_data:
        yield prefix_path, _EmptyDict
        return

    make_sub_path = FieldPath.from_string if expand_dots else FieldPath

    for key, value in sorted(document_data.items()):
        child_path = FieldPath(*(prefix_path.parts + make_sub_path(key).parts))

        if isinstance(value, dict):
            yield from extract_fields(value, child_path)
        else:
            yield child_path, value

455 

456 

def set_field_value(document_data, field_path, value) -> None:
    """Store ``value`` in ``document_data`` at the nested ``field_path``.

    Intermediate dictionaries are created as needed. The ``_EmptyDict``
    sentinel is stored as a fresh empty dict.
    """
    parts = field_path.parts

    node = document_data
    for key in parts[:-1]:
        node = node.setdefault(key, {})

    node[parts[-1]] = {} if value is _EmptyDict else value

465 

466 

def get_field_value(document_data, field_path) -> Any:
    """Look up the value stored at the nested ``field_path``.

    Args:
        document_data (dict): The (possibly nested) mapping to read.
        field_path: An object whose ``parts`` is the sequence of keys to follow.

    Returns:
        Any: The value found after following every key in order.

    Raises:
        ValueError: If ``field_path`` has no parts.
        KeyError: If a key along the path is missing from a dict.
    """
    keys = field_path.parts
    if not keys:
        raise ValueError("Empty path")

    node = document_data
    for key in keys:
        node = node[key]
    return node

475 

476 

class DocumentExtractor(object):
    """Split document data into plain field values and transforms.

    Special values such as ``DELETE_FIELD`` and ``SERVER_TIMESTAMP``, and the
    array / numeric transform wrappers, are collected separately from the
    regular values, which end up in ``set_fields``.

    Args:
        document_data (dict):
            Property names and values to use for sending a change to
            a document.
    """

    def __init__(self, document_data) -> None:
        self.document_data = document_data
        self.field_paths = []
        self.deleted_fields = []
        self.server_timestamps = []
        self.array_removes = {}
        self.array_unions = {}
        self.increments = {}
        self.minimums = {}
        self.maximums = {}
        self.set_fields = {}
        self.empty_document = False

        root_path = FieldPath()

        for field_path, value in self._get_document_iterator(root_path):
            # The sentinel at the root means the whole document is empty.
            if field_path == root_path and value is _EmptyDict:
                self.empty_document = True
                continue

            if value is transforms.DELETE_FIELD:
                self.deleted_fields.append(field_path)
                continue

            if value is transforms.SERVER_TIMESTAMP:
                self.server_timestamps.append(field_path)
                continue

            if isinstance(value, transforms.ArrayRemove):
                self.array_removes[field_path] = value.values
                continue

            if isinstance(value, transforms.ArrayUnion):
                self.array_unions[field_path] = value.values
                continue

            if isinstance(value, transforms.Increment):
                self.increments[field_path] = value.value
                continue

            if isinstance(value, transforms.Maximum):
                self.maximums[field_path] = value.value
                continue

            if isinstance(value, transforms.Minimum):
                self.minimums[field_path] = value.value
                continue

            # Anything else is ordinary data.
            self.field_paths.append(field_path)
            set_field_value(self.set_fields, field_path, value)

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        """Return the ``(field_path, value)`` iterator over the document data."""
        return extract_fields(self.document_data, prefix_path)

    @property
    def has_transforms(self):
        """bool: Whether any field transform was found in the data."""
        transform_groups = (
            self.server_timestamps,
            self.array_removes,
            self.array_unions,
            self.increments,
            self.maximums,
            self.minimums,
        )
        return any(transform_groups)

    @property
    def transform_paths(self):
        """list: Sorted field paths of every transform, rebuilt on each access."""
        paths = list(self.server_timestamps)
        keyed_groups = (
            self.array_removes,
            self.array_unions,
            self.increments,
            self.maximums,
            self.minimums,
        )
        for group in keyed_groups:
            paths.extend(group)
        paths.sort()
        return paths

    def _get_update_mask(self, allow_empty_mask=False) -> None:
        """Return the update mask for the write; this base class uses none."""
        return None

    def get_update_pb(
        self, document_path, exists=None, allow_empty_mask=False
    ) -> types.write.Write:
        """Build the ``Write`` protobuf carrying the plain field values.

        Args:
            document_path (str): A fully-qualified document path.
            exists (Optional[bool]): If not :data:`None`, attached to the
                write as an ``exists`` precondition.
            allow_empty_mask (bool): Forwarded to ``_get_update_mask``.

        Returns:
            google.cloud.firestore_v1.types.Write: The update write.
        """
        precondition = None if exists is None else common.Precondition(exists=exists)
        document_pb = document.Document(
            name=document_path, fields=encode_dict(self.set_fields)
        )

        return write.Write(
            update=document_pb,
            update_mask=self._get_update_mask(allow_empty_mask),
            current_document=precondition,
        )

    def get_field_transform_pbs(
        self, document_path
    ) -> List[types.write.DocumentTransform.FieldTransform]:
        """Build one ``FieldTransform`` per collected transform.

        Args:
            document_path (str): Not used by this method.

        Returns:
            List[FieldTransform]: The transforms, ordered by field path.
        """
        field_transform = write.DocumentTransform.FieldTransform

        def as_array_value(values):
            return document.ArrayValue(
                values=[encode_value(element) for element in values]
            )

        # Collect ``(path, transform)`` pairs so they can be ordered by path.
        pairs = [
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(),
                    set_to_server_value=REQUEST_TIME_ENUM,
                ),
            )
            for path in self.server_timestamps
        ]
        pairs.extend(
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(),
                    remove_all_from_array=as_array_value(values),
                ),
            )
            for path, values in self.array_removes.items()
        )
        pairs.extend(
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(),
                    append_missing_elements=as_array_value(values),
                ),
            )
            for path, values in self.array_unions.items()
        )
        pairs.extend(
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(), increment=encode_value(value)
                ),
            )
            for path, value in self.increments.items()
        )
        pairs.extend(
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(), maximum=encode_value(value)
                ),
            )
            for path, value in self.maximums.items()
        )
        pairs.extend(
            (
                path,
                field_transform(
                    field_path=path.to_api_repr(), minimum=encode_value(value)
                ),
            )
            for path, value in self.minimums.items()
        )

        pairs.sort()
        return [transform for _, transform in pairs]

    def get_transform_pb(self, document_path, exists=None) -> types.write.Write:
        """Build a ``Write`` protobuf holding only the field transforms.

        Args:
            document_path (str): A fully-qualified document path.
            exists (Optional[bool]): If not :data:`None`, attached to the
                write as an ``exists`` precondition.

        Returns:
            google.cloud.firestore_v1.types.Write: The transform write.
        """
        document_transform = write.DocumentTransform(
            document=document_path,
            field_transforms=self.get_field_transform_pbs(document_path),
        )
        transform_pb = write.Write(transform=document_transform)

        if exists is not None:
            transform_pb._pb.current_document.CopyFrom(
                common.Precondition(exists=exists)._pb
            )

        return transform_pb

662 

663 

def pbs_for_create(document_path, document_data) -> List[types.write.Write]:
    """Build the ``Write`` protobuf list for ``create()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            creating a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single ``Write``
        with an ``exists=False`` precondition; any field transforms are
        attached to it as ``update_transforms``.

    Raises:
        ValueError: If ``document_data`` contains ``DELETE_FIELD``.
    """
    parsed = DocumentExtractor(document_data)

    if parsed.deleted_fields:
        raise ValueError("Cannot apply DELETE_FIELD in a create request.")

    write_pb = parsed.get_update_pb(document_path, exists=False)

    if parsed.has_transforms:
        write_pb.update_transforms.extend(
            parsed.get_field_transform_pbs(document_path)
        )

    return [write_pb]

688 

689 

def pbs_for_set_no_merge(document_path, document_data) -> List[types.write.Write]:
    """Build the ``Write`` protobuf list for ``set()`` without merging.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single ``Write``;
        any field transforms are attached to it as ``update_transforms``.

    Raises:
        ValueError: If ``document_data`` contains ``DELETE_FIELD``.
    """
    parsed = DocumentExtractor(document_data)

    if parsed.deleted_fields:
        raise ValueError(
            "Cannot apply DELETE_FIELD in a set request without "
            "specifying 'merge=True' or 'merge=[field_paths]'."
        )

    write_pb = parsed.get_update_pb(document_path)

    if parsed.has_transforms:
        write_pb.update_transforms.extend(
            parsed.get_field_transform_pbs(document_path)
        )

    return [write_pb]

717 

718 

class DocumentExtractorForMerge(DocumentExtractor):
    """Break document data up into actual data and transforms for a merge.

    After construction, call :meth:`apply_merge` to restrict the extracted
    data and transforms to the fields being merged.
    """

    def __init__(self, document_data) -> None:
        super(DocumentExtractorForMerge, self).__init__(document_data)
        self.data_merge = []
        self.transform_merge = []
        self.merge = []

    def _apply_merge_all(self) -> None:
        """Merge every field and every transform found in the document."""
        self.data_merge = sorted(self.field_paths + self.deleted_fields)
        # ``transform_paths`` re-sorts on every access; read it once.
        transform_paths = self.transform_paths
        self.transform_merge = transform_paths
        self.merge = sorted(self.data_merge + transform_paths)

    def _construct_merge_paths(self, merge) -> Generator[Any, Any, None]:
        """Yield each merge entry as a ``FieldPath``, parsing strings."""
        for merge_field in merge:
            if isinstance(merge_field, FieldPath):
                yield merge_field
            else:
                yield FieldPath(*parse_field_path(merge_field))

    def _normalize_merge_paths(self, merge) -> list:
        """Return the sorted, validated list of merge paths.

        Raises:
            ValueError: If one merge path equals or is a parent of the next
                one in sorted order, or if a merge path is missing from the
                document data (and is not a deleted field).
        """
        merge_paths = sorted(self._construct_merge_paths(merge))

        # Raise if any merge path is a parent of another. Leverage sorting
        # to avoid quadratic behavior.
        for lhs, rhs in zip(merge_paths, merge_paths[1:]):
            if lhs.eq_or_parent(rhs):
                raise ValueError("Merge paths overlap: {}, {}".format(lhs, rhs))

        for merge_path in merge_paths:
            if merge_path in self.deleted_fields:
                continue
            try:
                get_field_value(self.document_data, merge_path)
            except KeyError:
                raise ValueError("Invalid merge path: {}".format(merge_path))

        return merge_paths

    def _apply_merge_paths(self, merge) -> None:
        """Restrict data and transforms to the given merge paths.

        Raises:
            ValueError: If the document is empty, a merge path is invalid or
                overlaps another, or a deleted field is not itself merged.
        """
        if self.empty_document:
            raise ValueError("Cannot merge specific fields with empty document.")

        merge_paths = self._normalize_merge_paths(merge)

        del self.data_merge[:]
        del self.transform_merge[:]
        self.merge = merge_paths

        # ``transform_paths`` rebuilds and sorts a list on every access, so
        # take one snapshot before any transform container is pruned.
        transform_paths = self.transform_paths

        for merge_path in merge_paths:
            if merge_path in transform_paths:
                self.transform_merge.append(merge_path)

            for field_path in self.field_paths:
                if merge_path.eq_or_parent(field_path):
                    self.data_merge.append(field_path)

        # Clear out data for fields not merged.
        merged_set_fields = {}
        for field_path in self.data_merge:
            value = get_field_value(self.document_data, field_path)
            set_field_value(merged_set_fields, field_path, value)
        self.set_fields = merged_set_fields

        unmerged_deleted_fields = [
            field_path
            for field_path in self.deleted_fields
            if field_path not in self.merge
        ]
        if unmerged_deleted_fields:
            raise ValueError(
                "Cannot delete unmerged fields: {}".format(unmerged_deleted_fields)
            )
        self.data_merge = sorted(self.data_merge + self.deleted_fields)

        # Keep only transforms which are within merge.
        merged_transform_paths = {
            transform_path
            for transform_path in transform_paths
            if any(
                merge_path.eq_or_parent(transform_path) for merge_path in self.merge
            )
        }

        def keep_merged(mapping):
            return {
                path: payload
                for path, payload in mapping.items()
                if path in merged_transform_paths
            }

        self.server_timestamps = [
            path for path in self.server_timestamps if path in merged_transform_paths
        ]
        self.array_removes = keep_merged(self.array_removes)
        self.array_unions = keep_merged(self.array_unions)
        # Numeric transforms are pruned as well; leaving them unfiltered
        # would send transforms for fields outside the requested merge paths.
        self.increments = keep_merged(self.increments)
        self.maximums = keep_merged(self.maximums)
        self.minimums = keep_merged(self.minimums)

    def apply_merge(self, merge) -> None:
        """Apply the merge option.

        Args:
            merge: :data:`True` to merge all fields, otherwise an iterable
                of field paths (``FieldPath`` or strings) to merge.
        """
        if merge is True:  # merge all fields
            self._apply_merge_all()
        else:
            self._apply_merge_paths(merge)

    def _get_update_mask(
        self, allow_empty_mask=False
    ) -> Optional[types.common.DocumentMask]:
        """Return a mask of the merged paths that are not exact transforms."""
        # Mask uses dotted / quoted paths.
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.merge
            if field_path not in self.transform_merge
        ]

        return common.DocumentMask(field_paths=mask_paths)

840 

841 

def pbs_for_set_with_merge(
    document_path, document_data, merge
) -> List[types.write.Write]:
    """Build the ``Write`` protobuf list for ``set()`` with a merge option.

    Args:
        document_path (str): A fully-qualified document path.
        document_data (dict): Property names and values to use for
            replacing a document.
        merge (Optional[bool] or Optional[List<apispec>]):
            If True, merge all fields; else, merge only the named fields.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single ``Write``;
        any remaining field transforms are attached to it as
        ``update_transforms``.
    """
    parsed = DocumentExtractorForMerge(document_data)
    parsed.apply_merge(merge)

    write_pb = parsed.get_update_pb(document_path)

    if parsed.transform_paths:
        write_pb.update_transforms.extend(
            parsed.get_field_transform_pbs(document_path)
        )

    return [write_pb]

868 

869 

class DocumentExtractorForUpdate(DocumentExtractor):
    """Break update data up into actual data and transforms.

    Top-level keys are parsed as dotted field paths.

    Raises:
        ValueError: If a top-level path has another top-level path in its
            lineage, or if a deleted field is not a top-level path.
    """

    def __init__(self, document_data) -> None:
        super().__init__(document_data)
        self.top_level_paths = sorted(
            FieldPath.from_string(key) for key in document_data
        )
        known_tops = set(self.top_level_paths)

        for top_level_path in self.top_level_paths:
            for ancestor in top_level_path.lineage():
                if ancestor not in known_tops:
                    continue
                raise ValueError(
                    "Conflicting field path: {}, {}".format(top_level_path, ancestor)
                )

        for deleted_path in self.deleted_fields:
            if deleted_path in known_tops:
                continue
            raise ValueError(
                "Cannot update with nest delete: {}".format(deleted_path)
            )

    def _get_document_iterator(
        self, prefix_path: FieldPath
    ) -> Generator[Tuple[Any, Any], Any, None]:
        """Iterate the data with top-level keys expanded as dotted paths."""
        return extract_fields(self.document_data, prefix_path, expand_dots=True)

    def _get_update_mask(self, allow_empty_mask=False) -> types.common.DocumentMask:
        """Return a mask of the top-level paths that are not transforms."""
        transform_paths = self.transform_paths
        mask_paths = [
            field_path.to_api_repr()
            for field_path in self.top_level_paths
            if field_path not in transform_paths
        ]

        return common.DocumentMask(field_paths=mask_paths)

906 

907 

def pbs_for_update(document_path, field_updates, option) -> List[types.write.Write]:
    """Build the ``Write`` protobuf list for ``update()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        field_updates (dict): Field names or paths to update and values
            to update with.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        List[google.cloud.firestore_v1.types.Write]: A single ``Write``;
        any field transforms are attached to it as ``update_transforms``.

    Raises:
        ValueError: If ``field_updates`` is empty.
    """
    parsed = DocumentExtractorForUpdate(field_updates)

    if parsed.empty_document:
        raise ValueError("Cannot update with an empty document.")

    # Without an explicit option, require that the document already exists.
    precondition = ExistsOption(exists=True) if option is None else option

    write_pb = parsed.get_update_pb(document_path)
    precondition.modify_write(write_pb)

    if parsed.has_transforms:
        write_pb.update_transforms.extend(
            parsed.get_field_transform_pbs(document_path)
        )

    return [write_pb]

939 

940 

def pb_for_delete(document_path, option) -> types.write.Write:
    """Build the ``Write`` protobuf for ``delete()`` methods.

    Args:
        document_path (str): A fully-qualified document path.
        option (optional[:class:`~google.cloud.firestore_v1.client.WriteOption`]):
            A write option to make assertions / preconditions on the server
            state of the document before applying changes.

    Returns:
        google.cloud.firestore_v1.types.Write: A
        ``Write`` protobuf instance for the ``delete()``.
    """
    delete_pb = write.Write(delete=document_path)
    if option is None:
        return delete_pb

    option.modify_write(delete_pb)
    return delete_pb

959 

960 

class ReadAfterWriteError(Exception):
    """Signals that a read was attempted after a write.

    "Read" methods that run inside a transaction raise this error.
    """

966 

967 

def get_transaction_id(transaction, read_operation=True) -> Union[bytes, None]:
    """Return the ID of ``transaction``, validating that it can be used.

    Args:
        transaction (Optional[:class:`~google.cloud.firestore_v1.transaction.Transaction`]):
            An existing transaction that this query will run in.
        read_operation (Optional[bool]): Whether the ID is going to be used
            for a read operation. Defaults to :data:`True`.

    Returns:
        Optional[bytes]: The transaction's ID, or :data:`None` when
        ``transaction`` itself is :data:`None`.

    Raises:
        ValueError: If a transaction was supplied but is not in progress.
        ReadAfterWriteError: If the transaction already has writes stored
            on it and ``read_operation`` is :data:`True`.
    """
    # No transaction means there is nothing to validate.
    if transaction is None:
        return None

    if not transaction.in_progress:
        raise ValueError(INACTIVE_TXN)

    # Writes are only inspected when the ID will be used for a read.
    if read_operation and len(transaction._write_pbs) > 0:
        raise ReadAfterWriteError(READ_AFTER_WRITE_ERROR)

    return transaction.id

996 

997 

def metadata_with_prefix(prefix: str, **kw) -> List[Tuple[str, str]]:
    """Build the RPC metadata entry that carries a resource prefix.

    Args:
        prefix (str): appropriate resource path.
        **kw: Accepted but not used.

    Returns:
        List[Tuple[str, str]]: RPC metadata with supplied prefix
    """
    prefix_entry = ("google-cloud-resource-prefix", prefix)
    return [prefix_entry]

1008 

1009 

class WriteOption:
    """Base class for options asserting a condition on a write operation."""

    def modify_write(self, write, no_create_msg=None) -> NoReturn:
        """Apply this option's precondition to a ``Write`` protobuf.

        Virtual method: subclasses are expected to provide the behaviour.

        Args:
            write (google.cloud.firestore_v1.types.Write): The ``Write``
                protobuf instance that a subclass would modify with its
                precondition.
            no_create_msg (Optional[str]): A message to use to indicate
                that a create operation is not allowed.

        Raises:
            NotImplementedError: Always, this method is virtual.
        """
        raise NotImplementedError()

1029 

1030 

class LastUpdateOption(WriteOption):
    """Write option asserting a "last update" condition.

    Instances are typically created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        last_update_time (google.protobuf.timestamp_pb2.Timestamp): A
            timestamp. When set, the target document must exist and have
            been last updated at that time. Protobuf ``update_time``
            timestamps are typically returned from methods that perform
            write operations as part of a "write result" protobuf or
            directly.
    """

    def __init__(self, last_update_time) -> None:
        self._last_update_time = last_update_time

    def __eq__(self, other):
        # Only instances of this option's own class are comparable.
        if isinstance(other, self.__class__):
            return self._last_update_time == other._last_update_time
        return NotImplemented

    def modify_write(self, write, **unused_kwargs) -> None:
        """Attach an "update time" precondition to a ``Write`` protobuf.

        The stored ``last_update_time`` becomes the ``update_time`` of the
        write's ``current_document`` precondition. When set, the target
        document must exist and have been last updated at that time.

        Args:
            write (google.cloud.firestore_v1.types.Write): The ``Write``
                protobuf instance to modify in place.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        precondition = types.Precondition(update_time=self._last_update_time)
        write._pb.current_document.CopyFrom(precondition._pb)

1069 

1070 

class ExistsOption(WriteOption):
    """Write option asserting whether the target document exists.

    Instances are typically created by
    :meth:`~google.cloud.firestore_v1.client.Client.write_option`.

    Args:
        exists (bool): Indicates if the document being modified
            should already exist.
    """

    def __init__(self, exists) -> None:
        self._exists = exists

    def __eq__(self, other):
        # Only instances of this option's own class are comparable.
        if isinstance(other, self.__class__):
            return self._exists == other._exists
        return NotImplemented

    def modify_write(self, write, **unused_kwargs) -> None:
        """Attach an existence precondition to a ``Write`` protobuf.

        * ``exists=True`` adds a precondition that requires existence
        * ``exists=False`` adds a precondition that requires non-existence

        Args:
            write (google.cloud.firestore_v1.types.Write): The ``Write``
                protobuf instance to modify in place.
            unused_kwargs (Dict[str, Any]): Keyword arguments accepted by
                other subclasses that are unused here.
        """
        precondition = types.Precondition(exists=self._exists)
        write._pb.current_document.CopyFrom(precondition._pb)

1107 

1108 

def make_retry_timeout_kwargs(retry, timeout) -> dict:
    """Helper for API methods which take optional 'retry' / 'timeout' args.

    Returns a dict holding ``"retry"`` only when ``retry`` is not
    ``gapic_v1.method.DEFAULT`` and ``"timeout"`` only when ``timeout`` is
    not :data:`None`.
    """
    # (keyword name, supplied value, sentinel meaning "not supplied")
    candidates = (
        ("retry", retry, gapic_v1.method.DEFAULT),
        ("timeout", timeout, None),
    )
    return {
        name: value for name, value, sentinel in candidates if value is not sentinel
    }

1120 

1121 

def build_timestamp(
    dt: Optional[Union[DatetimeWithNanoseconds, datetime.datetime]] = None
) -> Timestamp:
    """Convert ``dt`` to a ``Timestamp``, using the current UTC time if ``dt`` is falsy."""
    if not dt:
        dt = DatetimeWithNanoseconds.now(tz=datetime.timezone.utc)
    return _datetime_to_pb_timestamp(dt)

1129 

1130 

def compare_timestamps(
    ts1: Union[Timestamp, datetime.datetime],
    ts2: Union[Timestamp, datetime.datetime],
) -> int:
    """Three-way comparison of two points in time.

    Arguments that are not already ``Timestamp`` instances are converted
    with ``build_timestamp`` first.

    Args:
        ts1 (Union[Timestamp, datetime.datetime]): Left-hand operand.
        ts2 (Union[Timestamp, datetime.datetime]): Right-hand operand.

    Returns:
        int: ``0`` if both denote the same instant, ``1`` if ``ts1`` is
        later than ``ts2``, ``-1`` if it is earlier.
    """
    if not isinstance(ts1, Timestamp):
        ts1 = build_timestamp(ts1)
    if not isinstance(ts2, Timestamp):
        ts2 = build_timestamp(ts2)

    # Use exact integer arithmetic. Multiplying by the float ``1e9`` would
    # push present-day epoch values (~1.7e18 ns) past the 53-bit mantissa of
    # a double, so timestamps a few nanoseconds apart would compare equal.
    nanos_per_second = 10**9
    ts1_nanos = ts1.seconds * nanos_per_second + ts1.nanos
    ts2_nanos = ts2.seconds * nanos_per_second + ts2.nanos

    if ts1_nanos == ts2_nanos:
        return 0
    return 1 if ts1_nanos > ts2_nanos else -1

1142 

1143 

def deserialize_bundle(
    serialized: Union[str, bytes],
    client: "google.cloud.firestore_v1.client.BaseClient",  # type: ignore
) -> "google.cloud.firestore_bundle.FirestoreBundle":  # type: ignore
    """Rebuild a `FirestoreBundle` from the output of its `build()` method.

    Args:
        serialized (Union[str, bytes]): The result of `FirestoreBundle.build()`.
            Should be a list of dictionaries in string format.
        client (BaseClient): A connected Client instance.

    Returns:
        FirestoreBundle: A bundle equivalent to that which called `build()` and
            initially created the `serialized` value.

    Raises:
        ValueError: If any of the dictionaries in the list contain any more than
            one top-level key.
        ValueError: If any unexpected BundleElement types are encountered.
        ValueError: If the serialized bundle ends before expected.
    """
    from google.cloud.firestore_bundle import BundleElement, FirestoreBundle

    # Legal transitions from one BundleElement type to the next.
    transitions = {
        "__initial__": ["metadata"],
        "metadata": ["namedQuery", "documentMetadata", "__end__"],
        "namedQuery": ["namedQuery", "documentMetadata", "__end__"],
        "documentMetadata": ["document"],
        "document": ["documentMetadata", "__end__"],
    }
    expected_types: List[str] = transitions["__initial__"]

    # The metadata element is held back and added last: it is cached to
    # preserve timestamps, yet adding a document or query to a bundle flushes
    # the cached metadata. Adding it at the very end keeps it intact.
    deferred_metadata: Optional[BundleElement] = None
    bundle: Optional[FirestoreBundle] = None

    for element_data in _parse_bundle_elements_data(serialized):
        # Each BundleElement is serialized as JSON with a single key naming
        # its type; all further data is nested under that key.
        top_level_keys: List[str] = list(element_data.keys())
        if len(top_level_keys) != 1:
            raise ValueError("Expected serialized BundleElement with one top-level key")

        element_type: str = top_level_keys[0]
        if element_type not in expected_types:
            raise ValueError(
                f"Encountered BundleElement of type {element_type}. "
                f"Expected one of {expected_types}"
            )

        # Create the BundleElement for this chunk.
        try:
            element: BundleElement = BundleElement.from_json(json.dumps(element_data))  # type: ignore
        except AttributeError as e:
            # Some bad serialization formats cannot be universally deserialized.
            if e.args[0] == "'dict' object has no attribute 'find'":  # pragma: NO COVER
                raise ValueError(
                    "Invalid serialization of datetimes. "
                    "Cannot deserialize Bundles created from the NodeJS SDK."
                )
            raise e  # pragma: NO COVER

        if bundle is not None:
            bundle._add_bundle_element(element, client=client, type=element_type)
        else:
            # This must be the first bundle type encountered
            assert element_type == "metadata"
            bundle = FirestoreBundle(element_data[element_type]["id"])
            deferred_metadata = element

        # Advance the state machine.
        expected_types = transitions[element_type]

    if "__end__" not in expected_types:
        raise ValueError("Unexpected end to serialized FirestoreBundle")

    # Now, finally add the metadata element
    bundle._add_bundle_element(
        deferred_metadata,
        client=client,
        type="metadata",  # type: ignore
    )

    return bundle

1238 

1239 

def _parse_bundle_elements_data(serialized: Union[str, bytes]) -> Generator[Dict, None, None]:  # type: ignore
    """Reads through a serialized FirestoreBundle and yields JSON chunks that
    were created via `BundleElement.to_json(bundle_element)`.

    Serialized FirestoreBundle instances are length-prefixed JSON objects, and
    so are of the form "123{...}57{...}"
    To correctly and safely read a bundle, we must first detect these length
    prefixes, read that many bytes of data, and attempt to JSON-parse that.

    Args:
        serialized (Union[str, bytes]): The serialized bundle. A ``str`` is
            encoded as UTF-8 first; length prefixes count bytes.

    Yields:
        The value produced by ``json.loads`` for each length-prefixed chunk.

    Raises:
        ValueError: If a chunk of JSON ever starts without following a length
            prefix.
        ValueError: If the data ends before a chunk reaches the length its
            prefix announced.
        ValueError: If a chunk is not valid UTF-8 encoded JSON (raised by the
            decoder / ``json.loads``).
    """
    payload: bytes = (
        serialized if isinstance(serialized, bytes) else serialized.encode("utf-8")
    )
    end = len(payload)
    cursor = 0

    while cursor < end:
        prefix_start = cursor
        # Only ASCII digits form a length prefix. ``str.isnumeric`` would
        # also accept characters such as "²" or "½", which ``int`` rejects.
        while cursor < end and 0x30 <= payload[cursor] <= 0x39:
            cursor += 1

        if cursor == end:
            # Only digits were left: there is no chunk to yield, so stop,
            # as the byte-by-byte reader did.
            return

        if cursor == prefix_start:
            raise ValueError("Expected length prefix")

        chunk_length = int(payload[prefix_start:cursor].decode("ascii"))
        chunk = payload[cursor : cursor + chunk_length]
        if len(chunk) < chunk_length:
            # Raise explicitly: a bare ``next()`` on an exhausted iterator
            # inside a generator would surface as RuntimeError (PEP 479).
            raise ValueError(
                "Unexpected end to serialized FirestoreBundle: expected a chunk "
                f"of {chunk_length} bytes but only {len(chunk)} remain"
            )
        cursor += chunk_length

        # A zero-length prefix leaves an empty chunk, which ``json.loads``
        # rejects with a ``ValueError`` subclass.
        yield json.loads(chunk.decode("utf-8"))

1280 

1281 

def _get_documents_from_bundle(
    bundle, *, query_name: Optional[str] = None
) -> Generator["google.cloud.firestore.DocumentSnapshot", None, None]:  # type: ignore
    """Yield the snapshots of the documents stored in ``bundle``.

    Args:
        bundle: Object whose ``documents`` mapping holds the bundled
            documents.
        query_name (Optional[str]): When truthy, only documents whose
            ``metadata.queries`` contains this name are yielded; otherwise
            every document is yielded.

    Yields:
        The ``snapshot`` attribute of each selected bundled document.
    """
    from google.cloud.firestore_bundle.bundle import _BundledDocument

    entry: _BundledDocument
    for entry in bundle.documents.values():
        # Without a query name every document qualifies.
        if not query_name or query_name in entry.metadata.queries:
            yield entry.snapshot

1292 

1293 

def _get_document_from_bundle(
    bundle,
    *,
    document_id: str,
) -> Optional["google.cloud.firestore.DocumentSnapshot"]:  # type: ignore
    """Look up one document of ``bundle`` and return its snapshot.

    Args:
        bundle: Object whose ``documents`` mapping is queried with ``get``.
        document_id (str): Key of the wanted document in that mapping.

    Returns:
        The ``snapshot`` of the matching entry, or :data:`None` when the
        lookup yields nothing (or a falsy entry).
    """
    entry = bundle.documents.get(document_id)
    return entry.snapshot if entry else None