1# Copyright 2023 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import copy
16import logging
17import threading
18
19import google.auth.exceptions as e
20
# Module-level logger; RefreshThread.run uses it to report failed background refreshes.
_LOGGER = logging.getLogger(__name__)
22
23
class RefreshThreadManager:
    """
    Organizes exactly one background job that refreshes a token.

    At most one worker thread is tracked at a time. If that worker recorded an
    error, no new background refresh is started until clear_error() is called.
    """

    def __init__(self):
        """Initializes the manager."""

        self._worker = None
        self._lock = threading.Lock()  # protects access to worker threads.

    def start_refresh(self, cred, request):
        """Starts a refresh thread for the given credentials.
        The credentials are refreshed using the request parameter.
        request and cred MUST not be None

        Returns True if a background refresh was kicked off or one is already
        in progress. Returns False if the previous background refresh failed
        and its error has not yet been cleared with clear_error().

        Args:
            cred: A credentials object.
            request: A request object. The background thread receives
                ``copy.deepcopy(request)`` rather than the caller's object.
        Returns:
            bool
        Raises:
            google.auth.exceptions.InvalidValue: If cred or request is None.
        """
        if cred is None or request is None:
            raise e.InvalidValue(
                "Unable to start refresh. cred and request must be valid and instantiated objects."
            )

        with self._lock:
            # A failed refresh blocks further background attempts until the
            # error is explicitly cleared.
            if self._worker is not None and self._worker._error_info is not None:
                return False

            # Start a new thread only when none is running; an in-flight
            # refresh is left alone and still reported as True below.
            if self._worker is None or not self._worker.is_alive():  # pragma: NO COVER
                self._worker = RefreshThread(cred=cred, request=copy.deepcopy(request))
                self._worker.start()
        return True

    def clear_error(self):
        """
        Removes any errors that were stored from previous background refreshes.
        """
        with self._lock:
            if self._worker is not None:
                self._worker._error_info = None

    def __getstate__(self):
        """Pickle helper that excludes the unpicklable _lock and _worker attributes.

        Both are stored as None. As a consequence, an unpickled manager does not
        remember any error recorded by a previous background refresh.
        """
        state = self.__dict__.copy()
        state["_lock"] = None
        # A threading.Thread holds internal locks and cannot be pickled, so
        # keeping it here made pickling fail after any refresh was started.
        state["_worker"] = None
        return state

    def __setstate__(self, state):
        """Pickle helper that restores state and recreates the _lock attribute."""
        state["_lock"] = threading.Lock()
        self.__dict__.update(state)
80
81
class RefreshThread(threading.Thread):
    """
    Background thread that performs a single credential refresh.

    An exception raised by the refresh does not escape the thread. It is
    logged and kept in ``_error_info`` so that code holding the thread can
    inspect it afterwards.
    """

    def __init__(self, cred, request, **kwargs):
        """Sets up the thread without starting it.

        Args:
            cred: The credentials whose ``refresh`` method will be invoked.
            request: The request object passed to ``cred.refresh``.
            **kwargs: Forwarded unchanged to ``threading.Thread.__init__``.
        """
        super().__init__(**kwargs)
        self._cred, self._request = cred, request
        # Stays None unless run() catches an exception.
        self._error_info = None

    def run(self):
        """
        Refreshes the credentials once, recording any exception raised.
        """
        try:
            refresh = self._cred.refresh
            refresh(self._request)
        except Exception as exc:  # pragma: NO COVER
            _LOGGER.error(f"Background refresh failed due to: {exc}")
            self._error_info = exc