1# Copyright 2017 Google LLC All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14from __future__ import annotations
15
16import collections
17import functools
18import logging
19import threading
20from enum import Enum
21
22import grpc # type: ignore
23from google.api_core import exceptions
24from google.api_core.bidi import BackgroundConsumer, ResumableBidiRpc
25
26from google.cloud.firestore_v1 import _helpers
27from google.cloud.firestore_v1.types.firestore import (
28 ListenRequest,
29 Target,
30 TargetChange,
31)
32
# Shorthand for the enum of target-change kinds nested inside ``TargetChange``.
TargetChangeType = TargetChange.TargetChangeType

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# The single target ID this module registers with the server.  Responses
# carrying any other target ID are rejected by ``Watch``.
WATCH_TARGET_ID = 0x5079  # "Py"

# gRPC status names mapped to their numeric codes.
# NOTE: "UNAUTHENTICATED" (16) is listed out of numeric order; only the
# dict's iteration order is affected, not the values.
GRPC_STATUS_CODE = {
    "OK": 0,
    "CANCELLED": 1,
    "UNKNOWN": 2,
    "INVALID_ARGUMENT": 3,
    "DEADLINE_EXCEEDED": 4,
    "NOT_FOUND": 5,
    "ALREADY_EXISTS": 6,
    "PERMISSION_DENIED": 7,
    "UNAUTHENTICATED": 16,
    "RESOURCE_EXHAUSTED": 8,
    "FAILED_PRECONDITION": 9,
    "ABORTED": 10,
    "OUT_OF_RANGE": 11,
    "UNIMPLEMENTED": 12,
    "INTERNAL": 13,
    "UNAVAILABLE": 14,
    "DATA_LOSS": 15,
    "DO_NOT_USE": -1,
}
# Name given to the helper threads that run ``Watch.close``.
_RPC_ERROR_THREAD_NAME = "Thread-OnRpcTerminated"
# Exception types for which ``_should_recover`` returns True.
_RECOVERABLE_STREAM_EXCEPTIONS = (
    exceptions.Aborted,
    exceptions.Cancelled,
    exceptions.Unknown,
    exceptions.DeadlineExceeded,
    exceptions.ResourceExhausted,
    exceptions.InternalServerError,
    exceptions.ServiceUnavailable,
    exceptions.Unauthenticated,
)
# Exception types for which ``_should_terminate`` returns True.
# NOTE(review): ``Cancelled`` appears in both tuples; which predicate wins is
# decided by ``ResumableBidiRpc`` -- confirm against its documentation.
_TERMINATING_STREAM_EXCEPTIONS = (exceptions.Cancelled,)
71
# Entry stored per key in ``WatchDocTree``: the caller's value plus the
# insertion counter that was current when the key was inserted.
DocTreeEntry = collections.namedtuple("DocTreeEntry", ["value", "index"])


class WatchDocTree(object):
    """Persistent-style mapping from key to ``DocTreeEntry``.

    ``insert`` and ``remove`` never modify the receiver; each returns a new
    ``WatchDocTree`` holding a shallow copy of the data with the change
    applied.
    """

    # TODO: Currently this uses a dict. Other implementations use a rbtree.
    # The performance of this implementation should be investigated and may
    # require modifying the underlying datastructure to a rbtree.
    def __init__(self):
        self._dict = {}
        # Counter handed out as the ``index`` of the next inserted entry.
        self._index = 0

    def keys(self):
        """Return the keys as a new list."""
        return list(self._dict)

    def _copy(self):
        """Return a new tree with a shallow copy of this tree's state."""
        clone = WatchDocTree()
        clone._dict = dict(self._dict)
        clone._index = self._index
        return clone

    def insert(self, key, value):
        """Return a copy of this tree with ``key`` mapped to ``value``."""
        clone = self._copy()
        clone._dict[key] = DocTreeEntry(value, clone._index)
        clone._index += 1
        return clone

    def find(self, key):
        """Return the ``DocTreeEntry`` for ``key``; raises ``KeyError`` if absent."""
        return self._dict[key]

    def remove(self, key):
        """Return a copy of this tree without ``key``; raises ``KeyError`` if absent."""
        clone = self._copy()
        del clone._dict[key]
        return clone

    def __iter__(self):
        yield from self._dict

    def __len__(self):
        return len(self._dict)

    def __contains__(self, k):
        return k in self._dict
116
117
class ChangeType(Enum):
    """Kinds of change recorded for a document between two snapshots."""

    ADDED = 1
    # REMOVED is also stored directly in ``Watch.change_map`` as the marker
    # for a document that was deleted or removed from the target.
    REMOVED = 2
    MODIFIED = 3
122
123
class DocumentChange(object):
    """Record of a single document change between two snapshots."""

    def __init__(self, type, document, old_index, new_index):
        """Store the description of one change.

        Args:
            type (ChangeType): The kind of change.
            document (document.DocumentSnapshot): The affected document.
            old_index (int): Index of the document before the change.
            new_index (int): Index of the document after the change.
        """
        # TODO: spec indicated an isEqual param also
        self.type, self.document = type, document
        self.old_index, self.new_index = old_index, new_index
139
140
class WatchResult(object):
    """Plain holder for a snapshot, its name and its change type."""

    def __init__(self, snapshot, name, change_type):
        self.snapshot, self.name, self.change_type = snapshot, name, change_type
146
147
def _maybe_wrap_exception(exception):
    """Convert a ``grpc.RpcError`` via ``exceptions.from_grpc_error``.

    Anything that is not a ``grpc.RpcError`` is returned unchanged.
    """
    if not isinstance(exception, grpc.RpcError):
        return exception
    return exceptions.from_grpc_error(exception)
153
154
def document_watch_comparator(doc1, doc2):
    """Comparator for single-document watches.

    Asserts that both arguments are equal (a document watch only ever holds
    one document) and returns 0.
    """
    same_document = doc1 == doc2
    assert same_document, "Document watches only support one document."
    return 0
158
159
def _should_recover(exception):
    """Return True if ``exception`` is one of the recoverable stream errors.

    The exception is first passed through ``_maybe_wrap_exception``.
    """
    return isinstance(
        _maybe_wrap_exception(exception), _RECOVERABLE_STREAM_EXCEPTIONS
    )
163
164
def _should_terminate(exception):
    """Return True if ``exception`` is one of the terminating stream errors.

    The exception is first passed through ``_maybe_wrap_exception``.
    """
    return isinstance(
        _maybe_wrap_exception(exception), _TERMINATING_STREAM_EXCEPTIONS
    )
168
169
class Watch(object):
    """Listen to a document or query and deliver snapshots to a callback.

    A ``Watch`` opens a ``ResumableBidiRpc`` on the ``listen`` transport
    method, consumes responses on a ``BackgroundConsumer``, accumulates the
    reported document changes in ``change_map`` and, once the server signals
    a consistent state (``CURRENT`` followed by ``NO_CHANGE``), invokes the
    snapshot callback via ``push``.
    """

    def __init__(
        self,
        document_reference,
        firestore,
        target,
        comparator,
        snapshot_callback,
        document_snapshot_cls,
    ):
        """
        Args:
            document_reference: The document reference or query being
                watched.  It is stored on the instance.
            firestore: The client.  Its ``_firestore_api``,
                ``_database_string``, ``_rpc_metadata`` and ``document()``
                are used.
            target (dict): The ``add_target`` payload sent in the
                ``ListenRequest``.  ``resume_token`` is added to / removed
                from this dict as the stream is (re)started.
            comparator: Two-argument comparison function used to order
                document snapshots.
            snapshot_callback: Callback method to process snapshots.
                Args:
                    docs (List(DocumentSnapshot)): The ordered list of
                        documents stored in this snapshot.
                    changes (List(DocumentChange)): The changes applied
                        since the last snapshot delivered for this watch.
                    read_time: The read time reported by the server for
                        this snapshot.

            document_snapshot_cls: factory for instances of DocumentSnapshot
        """
        self._document_reference = document_reference
        self._firestore = firestore
        self._targets = target
        self._comparator = comparator
        self._document_snapshot_cls = document_snapshot_cls
        self._snapshot_callback = snapshot_callback
        self._api = firestore._firestore_api
        # Held while close() runs; on_snapshot() drops responses while it is
        # locked.
        self._closing = threading.Lock()
        self._closed = False
        self._set_documents_pfx(firestore._database_string)

        self.resume_token = None

        # Initialize state for on_snapshot
        # The sorted tree of QueryDocumentSnapshots as sent in the last
        # snapshot. We only look at the keys.
        self.doc_tree = WatchDocTree()

        # A map of document names to QueryDocumentSnapshots for the last sent
        # snapshot.
        self.doc_map = {}

        # The accumulated map of document changes (keyed by document name)
        # for the current snapshot.
        self.change_map = {}

        # The current state of the query results.
        self.current = False

        # We need this to track whether we've pushed an initial set of changes,
        # since we should push those even when there are no changes, if there
        # aren't docs.
        self.has_pushed = False

        self._init_stream()

    def _init_stream(self):
        """Create the bidirectional RPC and start consuming responses."""
        rpc_request = self._get_rpc_request

        self._rpc: ResumableBidiRpc | None = ResumableBidiRpc(
            start_rpc=self._api._transport.listen,
            should_recover=_should_recover,
            should_terminate=_should_terminate,
            initial_request=rpc_request,
            metadata=self._firestore._rpc_metadata,
        )

        self._rpc.add_done_callback(self._on_rpc_done)

        # The server assigns and updates the resume token.
        self._consumer: BackgroundConsumer | None = BackgroundConsumer(
            self._rpc, self.on_snapshot
        )
        self._consumer.start()

    @classmethod
    def for_document(
        cls,
        document_ref,
        snapshot_callback,
        document_snapshot_cls,
    ):
        """
        Creates a watch snapshot listener for a document. snapshot_callback
        receives a DocumentChange object, but may also start to get
        targetChange and such soon

        Args:
            document_ref: Reference to Document
            snapshot_callback: callback to be called on snapshot
            document_snapshot_cls: class to make snapshots with

        Returns:
            Watch: a watch targeting the single document ``document_ref``.
        """
        return cls(
            document_ref,
            document_ref._client,
            {
                "documents": {"documents": [document_ref._document_path]},
                "target_id": WATCH_TARGET_ID,
            },
            document_watch_comparator,
            snapshot_callback,
            document_snapshot_cls,
        )

    @classmethod
    def for_query(cls, query, snapshot_callback, document_snapshot_cls):
        """Creates a watch snapshot listener for a query.

        Args:
            query: The query to watch.  Its ``_parent``, ``_client``,
                ``_comparator`` and ``_to_protobuf()`` are used.
            snapshot_callback: callback to be called on snapshot
            document_snapshot_cls: class to make snapshots with

        Returns:
            Watch: a watch targeting the results of ``query``.
        """
        parent_path, _ = query._parent._parent_info()
        query_target = Target.QueryTarget(
            parent=parent_path, structured_query=query._to_protobuf()
        )

        return cls(
            query,
            query._client,
            {"query": query_target._pb, "target_id": WATCH_TARGET_ID},
            query._comparator,
            snapshot_callback,
            document_snapshot_cls,
        )

    def _get_rpc_request(self):
        """Build the initial ``ListenRequest`` for a (re)started stream.

        The stored target dict is updated in place so that it carries the
        current resume token, or none at all when there is no token.
        """
        if self.resume_token is not None:
            self._targets["resume_token"] = self.resume_token
        else:
            self._targets.pop("resume_token", None)

        return ListenRequest(
            database=self._firestore._database_string, add_target=self._targets
        )

    def _set_documents_pfx(self, database_string):
        """Cache the ``<database>/documents/`` prefix and its length."""
        self._documents_pfx = f"{database_string}/documents/"
        self._documents_pfx_len = len(self._documents_pfx)

    @property
    def is_active(self):
        """bool: True if this manager is actively streaming.

        Note that ``False`` does not indicate this is complete shut down,
        just that it stopped getting new messages.
        """
        return self._consumer is not None and self._consumer.is_active

    def close(self, reason=None):
        """Stop consuming messages and shutdown all helper threads.

        This method is idempotent. Additional calls will have no effect.

        Args:
            reason (Any): The reason to close this. If None, this is considered
                an "intentional" shutdown.

        Raises:
            Exception: ``reason`` itself when it is an exception instance.
            RuntimeError: wrapping ``reason`` when it is any other truthy
                value.
        """
        with self._closing:
            if self._closed:
                return

            # Stop consuming messages.
            if self._consumer:
                if self.is_active:
                    _LOGGER.debug("Stopping consumer.")
                    self._consumer.stop()
                self._consumer._on_response = None
            self._consumer = None

            self._snapshot_callback = None
            if self._rpc:
                self._rpc.close()
                self._rpc._initial_request = None
                self._rpc._callbacks = []
            self._rpc = None
            self._closed = True
            _LOGGER.debug("Finished stopping manager.")

        if reason:
            # Raise an exception if a reason is provided
            _LOGGER.debug("reason for closing: %s", reason)
            if isinstance(reason, Exception):
                raise reason
            raise RuntimeError(reason)

    def _on_rpc_done(self, future):
        """Triggered whenever the underlying RPC terminates without recovery.

        This is typically triggered from one of two threads: the background
        consumer thread (when calling ``recv()`` produces a non-recoverable
        error) or the grpc management thread (when cancelling the RPC).

        This method is *non-blocking*. It will start another thread to deal
        with shutting everything down. This is to prevent blocking in the
        background consumer and preventing it from being ``joined()``.
        """
        _LOGGER.info("RPC termination has signaled manager shutdown.")
        future = _maybe_wrap_exception(future)
        thread = threading.Thread(
            name=_RPC_ERROR_THREAD_NAME, target=self.close, kwargs={"reason": future}
        )
        thread.daemon = True
        thread.start()

    def unsubscribe(self):
        """Stop listening; equivalent to ``close()`` with no reason."""
        self.close()

    def _on_snapshot_target_change_no_change(self, target_change):
        """Handle NO_CHANGE: push a snapshot if the state is consistent."""
        _LOGGER.debug("on_snapshot: target change: NO_CHANGE")

        no_target_ids = (
            target_change.target_ids is None or len(target_change.target_ids) == 0
        )
        if no_target_ids and target_change.read_time and self.current:
            # TargetChange.TargetChangeType.CURRENT followed by
            # TargetChange.TargetChangeType.NO_CHANGE
            # signals a consistent state. Invoke the onSnapshot
            # callback as specified by the user.
            self.push(target_change.read_time, target_change.resume_token)

    def _on_snapshot_target_change_add(self, target_change):
        """Handle ADD: verify the server acknowledged our own target ID.

        Raises:
            RuntimeError: if the first target ID is not ``WATCH_TARGET_ID``.
        """
        _LOGGER.debug("on_snapshot: target change: ADD")
        target_id = target_change.target_ids[0]
        if target_id != WATCH_TARGET_ID:
            raise RuntimeError("Unexpected target ID %s sent by server" % target_id)

    def _on_snapshot_target_change_remove(self, target_change):
        """Handle REMOVE: always raises, since the target is gone.

        Raises:
            RuntimeError: carrying the cause's code and message (or a generic
                internal error when the cause has no code), chained from the
                matching ``google.api_core`` exception.
        """
        _LOGGER.debug("on_snapshot: target change: REMOVE")

        if target_change.cause.code:
            code = target_change.cause.code
            message = target_change.cause.message
        else:
            code = GRPC_STATUS_CODE["INTERNAL"]
            message = "internal error"

        error_message = "Error %s: %s" % (code, message)

        raise RuntimeError(error_message) from exceptions.from_grpc_status(
            code, message
        )

    def _on_snapshot_target_change_reset(self, target_change):
        """Handle RESET: discard accumulated state via ``_reset_docs``."""
        # Whatever changes have happened so far no longer matter.
        _LOGGER.debug("on_snapshot: target change: RESET")
        self._reset_docs()

    def _on_snapshot_target_change_current(self, target_change):
        """Handle CURRENT: mark the accumulated results as current."""
        _LOGGER.debug("on_snapshot: target change: CURRENT")
        self.current = True

    # Maps each target-change type to its handler.  The values are plain
    # functions (not bound methods), so callers pass ``self`` explicitly.
    _target_changetype_dispatch = {
        TargetChangeType.NO_CHANGE: _on_snapshot_target_change_no_change,
        TargetChangeType.ADD: _on_snapshot_target_change_add,
        TargetChangeType.REMOVE: _on_snapshot_target_change_remove,
        TargetChangeType.RESET: _on_snapshot_target_change_reset,
        TargetChangeType.CURRENT: _on_snapshot_target_change_current,
    }

    def _strip_document_pfx(self, document_name):
        """Return ``document_name`` without the cached documents prefix."""
        if document_name.startswith(self._documents_pfx):
            document_name = document_name[self._documents_pfx_len :]
        return document_name

    def _restart_stream(self):
        """Shut down the current stream, reset local state and reconnect.

        ``close()`` marks the watch as closed and drops the snapshot
        callback, so both are restored here before the new stream is opened;
        otherwise the restarted watch could neither deliver snapshots nor be
        closed again.
        """
        if self._closed:
            # Already shut down (e.g. by unsubscribe()); do not reopen.
            return

        # Remember the callback before close() clears it.
        snapshot_callback = self._snapshot_callback

        # First, shut down current stream.  close() is run on a helper
        # thread rather than called directly.
        # NOTE(review): presumably because this runs on the consumer thread
        # that close() stops -- confirm against BackgroundConsumer.stop().
        thread = threading.Thread(
            name=_RPC_ERROR_THREAD_NAME,
            target=self.close,
        )
        thread.start()
        thread.join()  # wait for shutdown to complete

        # Then, remove all the current results.
        self._reset_docs()

        # Undo the shutdown bookkeeping so the watch is usable again.
        with self._closing:
            self._snapshot_callback = snapshot_callback
            self._closed = False

        # Finally, restart stream.
        self._init_stream()

    def on_snapshot(self, proto):
        """Process a response from the bi-directional gRPC stream.

        Collect changes and push the changes in a batch to the customer
        when we receive 'current' from the listen response.

        Args:
            proto(`google.cloud.firestore_v1.types.ListenResponse`):
                The response to process.  ``None`` closes the watch.
        """
        if self._closing.locked():
            # don't process on_snapshot responses while spinning down, to prevent deadlock
            return
        if proto is None:
            self.close()
            return

        pb = proto._pb
        which = pb.WhichOneof("response_type")

        if which == "target_change":
            target_change_type = pb.target_change.target_change_type
            _LOGGER.debug("on_snapshot: target change: %s", target_change_type)

            meth = self._target_changetype_dispatch.get(target_change_type)

            if meth is None:
                message = f"Unknown target change type: {target_change_type}"
                _LOGGER.info("on_snapshot: %s", message)
                self.close(reason=ValueError(message))
            else:
                try:
                    # Use 'proto' vs 'pb' for datetime handling
                    meth(self, proto.target_change)
                except Exception as exc2:
                    _LOGGER.debug("meth(proto) exc: %s", exc2)
                    raise

            # NOTE:
            # in other implementations, such as node, the backoff is reset here
            # in this version bidi rpc is just used and will control this.

        elif which == "document_change":
            _LOGGER.debug("on_snapshot: document change")

            # No other target_ids can show up here, but we still need to see
            # if the targetId was in the added list or removed list.
            changed = WATCH_TARGET_ID in pb.document_change.target_ids
            removed = WATCH_TARGET_ID in pb.document_change.removed_target_ids

            # google.cloud.firestore_v1.types.Document
            # Use 'proto' vs 'pb' for datetime handling
            document = proto.document_change.document

            if changed:
                _LOGGER.debug("on_snapshot: document change: CHANGED")

                data = _helpers.decode_dict(document.fields, self._firestore)

                # Create a snapshot. As Document and Query objects can be
                # passed we need to get a Document Reference in a more manual
                # fashion than self._document_reference
                document_name = self._strip_document_pfx(document.name)
                document_ref = self._firestore.document(document_name)

                snapshot = self._document_snapshot_cls(
                    reference=document_ref,
                    data=data,
                    exists=True,
                    read_time=None,
                    create_time=document.create_time,
                    update_time=document.update_time,
                )
                self.change_map[document.name] = snapshot

            elif removed:
                _LOGGER.debug("on_snapshot: document change: REMOVED")
                self.change_map[document.name] = ChangeType.REMOVED

        # NB: document_delete and document_remove (as far as we, the client,
        # are concerned) are functionally equivalent

        elif which == "document_delete":
            _LOGGER.debug("on_snapshot: document change: DELETE")
            name = pb.document_delete.document
            self.change_map[name] = ChangeType.REMOVED

        elif which == "document_remove":
            _LOGGER.debug("on_snapshot: document change: REMOVE")
            name = pb.document_remove.document
            self.change_map[name] = ChangeType.REMOVED

        elif which == "filter":
            _LOGGER.debug("on_snapshot: filter update")
            if pb.filter.count != self._current_size():
                # The server's document count disagrees with ours.
                _LOGGER.info("Filter mismatch -- restarting stream.")
                self._restart_stream()

        else:
            _LOGGER.debug("UNKNOWN TYPE. UHOH")
            message = f"Unknown listen response type: {proto}"
            self.close(reason=ValueError(message))

    def push(self, read_time, next_resume_token):
        """Invoke the callback with a new snapshot

        Build the snapshot from the current set of changes.

        Clear the current changes on completion.

        Args:
            read_time: The read time to stamp on added/updated snapshots and
                to pass to the callback.
            next_resume_token: The resume token to store for later restarts.
        """
        deletes, adds, updates = self._extract_changes(
            self.doc_map, self.change_map, read_time
        )

        updated_tree, updated_map, applied_changes = self._compute_snapshot(
            self.doc_tree, self.doc_map, deletes, adds, updates
        )

        # The first push always fires, even with no changes.
        if not self.has_pushed or len(applied_changes):
            # TODO: It is possible in the future we will have the tree order
            # on insert. For now, we sort here.
            key = functools.cmp_to_key(self._comparator)
            keys = sorted(updated_tree.keys(), key=key)

            self._snapshot_callback(keys, applied_changes, read_time)
            self.has_pushed = True

        self.doc_tree = updated_tree
        self.doc_map = updated_map
        self.change_map.clear()
        self.resume_token = next_resume_token

    @staticmethod
    def _extract_changes(doc_map, changes, read_time):
        """Split accumulated changes into deletes, adds and updates.

        Args:
            doc_map (dict): Document name -> snapshot for the last snapshot.
            changes (dict): Document name -> new snapshot, or
                ``ChangeType.REMOVED`` for removed documents.
            read_time: When not None, assigned to ``read_time`` on every
                added or updated snapshot.

        Returns:
            tuple: ``(deletes, adds, updates)`` where ``deletes`` is a list
            of names present in ``doc_map``, and ``adds`` / ``updates`` are
            lists of snapshots whose names are absent from / present in
            ``doc_map`` respectively.
        """
        deletes = []
        adds = []
        updates = []

        for name, value in changes.items():
            if value == ChangeType.REMOVED:
                # Removals of documents we never reported are ignored.
                if name in doc_map:
                    deletes.append(name)
            elif name in doc_map:
                if read_time is not None:
                    value.read_time = read_time
                updates.append(value)
            else:
                if read_time is not None:
                    value.read_time = read_time
                adds.append(value)

        return (deletes, adds, updates)

    def _compute_snapshot(
        self, doc_tree, doc_map, delete_changes, add_changes, update_changes
    ):
        """Apply the given changes and return the resulting state.

        Neither ``doc_tree`` nor ``doc_map`` is modified; the caller decides
        when to adopt the returned tree and map.

        Args:
            doc_tree (WatchDocTree): Tree of snapshots from the last push.
            doc_map (dict): Document name -> snapshot from the last push.
            delete_changes (list): Names of documents to delete.
            add_changes (list): Snapshots to add.
            update_changes (list): Snapshots replacing existing documents.

        Returns:
            tuple: ``(updated_tree, updated_map, applied_changes)`` where
            ``applied_changes`` is the list of ``DocumentChange`` objects in
            the order deletes, adds, modifications.
        """
        updated_tree = doc_tree
        # Work on a copy so the caller's map is untouched if anything below
        # (or the snapshot callback in push) raises.
        updated_map = dict(doc_map)

        assert len(doc_tree) == len(doc_map), (
            "The document tree and document map should have the same "
            + "number of entries."
        )

        def delete_doc(name, updated_tree, updated_map):
            """
            Applies a document delete to the document tree and document map.
            Returns the corresponding DocumentChange event.
            """
            assert name in updated_map, "Document to delete does not exist"
            old_document = updated_map.get(name)
            # TODO: If a document doesn't exist in the tree this raises
            # KeyError. Handle?
            existing = updated_tree.find(old_document)
            old_index = existing.index
            updated_tree = updated_tree.remove(old_document)
            del updated_map[name]
            return (
                DocumentChange(ChangeType.REMOVED, old_document, old_index, -1),
                updated_tree,
                updated_map,
            )

        def add_doc(new_document, updated_tree, updated_map):
            """
            Applies a document add to the document tree and the document map.
            Returns the corresponding DocumentChange event.
            """
            name = new_document.reference._document_path
            assert name not in updated_map, "Document to add already exists"
            updated_tree = updated_tree.insert(new_document, None)
            new_index = updated_tree.find(new_document).index
            updated_map[name] = new_document
            return (
                DocumentChange(ChangeType.ADDED, new_document, -1, new_index),
                updated_tree,
                updated_map,
            )

        def modify_doc(new_document, updated_tree, updated_map):
            """
            Applies a document modification to the document tree and the
            document map.
            Returns the DocumentChange event for successful modifications,
            or None when the update time is unchanged.
            """
            name = new_document.reference._document_path
            assert name in updated_map, "Document to modify does not exist"
            old_document = updated_map.get(name)
            if old_document.update_time != new_document.update_time:
                remove_change, updated_tree, updated_map = delete_doc(
                    name, updated_tree, updated_map
                )
                add_change, updated_tree, updated_map = add_doc(
                    new_document, updated_tree, updated_map
                )
                return (
                    DocumentChange(
                        ChangeType.MODIFIED,
                        new_document,
                        remove_change.old_index,
                        add_change.new_index,
                    ),
                    updated_tree,
                    updated_map,
                )

            return None, updated_tree, updated_map

        # Process the sorted changes in the order that is expected by our
        # clients (removals, additions, and then modifications). We also need
        # to sort the individual changes to assure that old_index/new_index
        # keep incrementing.
        applied_changes = []

        key = functools.cmp_to_key(self._comparator)

        # Deletes are applied in sorted order of document name.
        delete_changes = sorted(delete_changes)
        for name in delete_changes:
            change, updated_tree, updated_map = delete_doc(
                name, updated_tree, updated_map
            )
            applied_changes.append(change)

        add_changes = sorted(add_changes, key=key)
        _LOGGER.debug("walk over add_changes")
        for snapshot in add_changes:
            _LOGGER.debug("in add_changes")
            change, updated_tree, updated_map = add_doc(
                snapshot, updated_tree, updated_map
            )
            applied_changes.append(change)

        update_changes = sorted(update_changes, key=key)
        for snapshot in update_changes:
            change, updated_tree, updated_map = modify_doc(
                snapshot, updated_tree, updated_map
            )
            if change is not None:
                applied_changes.append(change)

        assert len(updated_tree) == len(updated_map), (
            "The update document tree and document map "
            "should have the same number of entries."
        )
        return (updated_tree, updated_map, applied_changes)

    def _current_size(self):
        """Return the current count of all documents.

        Count includes the changes from the current changeMap.
        """
        deletes, adds, _ = self._extract_changes(self.doc_map, self.change_map, None)
        return len(self.doc_map) + len(adds) - len(deletes)

    def _reset_docs(self):
        """
        Helper to clear the docs on RESET or filter mismatch.
        """
        _LOGGER.debug("resetting documents")
        self.change_map.clear()
        self.resume_token = None

        # Mark each document as deleted. If documents are not deleted
        # they will be sent again by the server.
        for snapshot in self.doc_tree.keys():
            name = snapshot.reference._document_path
            self.change_map[name] = ChangeType.REMOVED

        self.current = False