1# Copyright 2018, Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Schedulers provide means to *schedule* callbacks asynchronously.
16
17These are used by the subscriber to call the user-provided callback to process
18each message.
19"""
20
21import abc
22import concurrent.futures
23import queue
24import sys
25import typing
26from typing import Callable, List, Optional
27import warnings
28
29if typing.TYPE_CHECKING: # pragma: NO COVER
30 from google.cloud import pubsub_v1
31
32
class Scheduler(metaclass=abc.ABCMeta):
    """Abstract interface for objects that run callbacks asynchronously.

    Concrete subclasses must provide :attr:`queue`, :meth:`schedule` and
    :meth:`shutdown`; because all three are abstract, instantiating this class
    directly raises ``TypeError``.
    """

    @property
    @abc.abstractmethod
    def queue(self) -> "queue.Queue":  # pragma: NO COVER
        """Queue: A concurrency-safe queue matching the concurrency model
        of the concrete scheduler.

        The scheduling actor receives messages sent *back* to it through
        this queue.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def schedule(self, callback: Callable, *args, **kwargs) -> None:  # pragma: NO COVER
        """Arrange for ``callback(*args, **kwargs)`` to run asynchronously.

        Args:
            callback: The callable to invoke.
            args: Positional arguments forwarded to ``callback``.
            kwargs: Keyword arguments forwarded to ``callback``.

        Returns:
            None
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def shutdown(
        self, await_msg_callbacks: bool = False
    ) -> List["pubsub_v1.subscriber.message.Message"]:  # pragma: NO COVER
        """Shut down the scheduler and drop every callback still pending.

        Args:
            await_msg_callbacks:
                When ``True``, block until the callbacks that are already
                executing have finished. When ``False`` (the default), return
                without waiting for them.

        Returns:
            The messages that were handed to the scheduler but never
            dispatched to their callbacks. Each message is assumed to have
            been passed to the scheduler as the first positional argument of
            its callback.
        """
        raise NotImplementedError()
82
83
84def _make_default_thread_pool_executor() -> concurrent.futures.ThreadPoolExecutor:
85 return concurrent.futures.ThreadPoolExecutor(
86 max_workers=10, thread_name_prefix="ThreadPoolExecutor-ThreadScheduler"
87 )
88
89
class ThreadScheduler(Scheduler):
    """A thread pool-based scheduler. It must not be shared across
    SubscriberClients.

    This scheduler is useful in typical I/O-bound message processing.

    Args:
        executor:
            An optional executor to use. If not specified, a default one
            will be created.
    """

    def __init__(
        self, executor: Optional[concurrent.futures.ThreadPoolExecutor] = None
    ):
        self._queue: queue.Queue = queue.Queue()
        if executor is None:
            self._executor = _make_default_thread_pool_executor()
        else:
            self._executor = executor

    @property
    def queue(self) -> "queue.Queue":
        """Queue: A thread-safe queue used for communication between callbacks
        and the scheduling thread."""
        return self._queue

    def schedule(self, callback: Callable, *args, **kwargs) -> None:
        """Schedule the callback to be called asynchronously in a thread pool.

        If submitting to the executor raises ``RuntimeError`` (e.g. because the
        executor has already been shut down), the callback is not run and a
        ``RuntimeWarning`` is emitted instead of propagating the error.

        Args:
            callback: The function to call.
            args: Positional arguments passed to the callback.
            kwargs: Key-word arguments passed to the callback.

        Returns:
            None
        """
        try:
            self._executor.submit(callback, *args, **kwargs)
        except RuntimeError:
            warnings.warn(
                "Scheduling a callback after executor shutdown.",
                category=RuntimeWarning,
                stacklevel=2,
            )

    def shutdown(
        self, await_msg_callbacks: bool = False
    ) -> List["pubsub_v1.subscriber.message.Message"]:
        """Shut down the scheduler and drop all pending callbacks.

        Work items still waiting in the executor's queue are removed and never
        run; their futures are cancelled so nothing waiting on them blocks
        forever.

        Args:
            await_msg_callbacks:
                If ``True``, the method will block until all currently executing
                executor threads are done processing. If ``False`` (default), the
                method will not wait for the currently running threads to complete.

        Returns:
            The messages submitted to the scheduler that were not yet dispatched
            to their callbacks.
            It is assumed that each message was submitted to the scheduler as the
            first positional argument to the provided callback.
        """
        dropped_messages = []

        # Drain all pending items from the executor's internal work queue. Without
        # this, the executor would also try to process any pending work items
        # before termination, which is undesirable.
        #
        # TODO: Replace the logic below by passing `cancel_futures=True` to shutdown()
        # once we only need to support Python 3.9+.
        while True:
            try:
                work_item = self._executor._work_queue.get_nowait()
            except queue.Empty:
                break

            if work_item is None:  # Wake-up sentinel of an executor in shutdown mode.
                continue

            # A drained item will never run, so its future would otherwise stay
            # PENDING forever and hang anyone waiting on it (e.g. work submitted
            # directly to a user-supplied executor). Cancelling it mirrors what
            # ``shutdown(cancel_futures=True)`` does.
            future = getattr(work_item, "future", None)
            if future is not None:
                future.cancel()

            dropped_message = self._extract_message(work_item)
            if dropped_message is not None:
                dropped_messages.append(dropped_message)

        self._executor.shutdown(wait=await_msg_callbacks)
        return dropped_messages

    @staticmethod
    def _extract_message(
        work_item,
    ) -> Optional["pubsub_v1.subscriber.message.Message"]:
        """Return the first positional argument of a pending executor work item.

        The attribute layout of executor work items is an implementation detail
        that changed in Python 3.14, hence the version check. Returns ``None``
        when the expected attribute is missing or there are no positional
        arguments.
        """
        if sys.version_info < (3, 14):
            # For Python < 3.14, work_item.args is a tuple of positional arguments.
            args = getattr(work_item, "args", None)
        else:
            # For Python >= 3.14, work_item.task is (fn, args, kwargs).
            task = getattr(work_item, "task", None)
            args = task[1] if task is not None and len(task) == 3 else None
        # The message is expected to be the first positional argument.
        return args[0] if args else None