1# Copyright 2023 Google LLC 
    2# 
    3# Licensed under the Apache License, Version 2.0 (the "License"); 
    4# you may not use this file except in compliance with the License. 
    5# You may obtain a copy of the License at 
    6# 
    7#      http://www.apache.org/licenses/LICENSE-2.0 
    8# 
    9# Unless required by applicable law or agreed to in writing, software 
    10# distributed under the License is distributed on an "AS IS" BASIS, 
    11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
    12# See the License for the specific language governing permissions and 
    13# limitations under the License. 
    14 
    15import copy 
    16import logging 
    17import threading 
    18 
    19import google.auth.exceptions as e 
    20 
# Module-level logger; RefreshThread.run reports background refresh failures through it.
_LOGGER = logging.getLogger(__name__)
    22 
    23 
    24class RefreshThreadManager: 
    25    """ 
    26    Organizes exactly one background job that refresh a token. 
    27    """ 
    28 
    29    def __init__(self): 
    30        """Initializes the manager.""" 
    31 
    32        self._worker = None 
    33        self._lock = threading.Lock()  # protects access to worker threads. 
    34 
    35    def start_refresh(self, cred, request): 
    36        """Starts a refresh thread for the given credentials. 
    37        The credentials are refreshed using the request parameter. 
    38        request and cred MUST not be None 
    39 
    40        Returns True if a background refresh was kicked off. False otherwise. 
    41 
    42        Args: 
    43            cred: A credentials object. 
    44            request: A request object. 
    45        Returns: 
    46          bool 
    47        """ 
    48        if cred is None or request is None: 
    49            raise e.InvalidValue( 
    50                "Unable to start refresh. cred and request must be valid and instantiated objects." 
    51            ) 
    52 
    53        with self._lock: 
    54            if self._worker is not None and self._worker._error_info is not None: 
    55                return False 
    56 
    57            if self._worker is None or not self._worker.is_alive():  # pragma: NO COVER 
    58                self._worker = RefreshThread(cred=cred, request=copy.deepcopy(request)) 
    59                self._worker.start() 
    60        return True 
    61 
    62    def clear_error(self): 
    63        """ 
    64      Removes any errors that were stored from previous background refreshes. 
    65      """ 
    66        with self._lock: 
    67            if self._worker: 
    68                self._worker._error_info = None 
    69 
    70    def __getstate__(self): 
    71        """Pickle helper that serializes the _lock attribute.""" 
    72        state = self.__dict__.copy() 
    73        state["_lock"] = None 
    74        return state 
    75 
    76    def __setstate__(self, state): 
    77        """Pickle helper that deserializes the _lock attribute.""" 
    78        state["_lock"] = threading.Lock() 
    79        self.__dict__.update(state) 
    80 
    81 
    82class RefreshThread(threading.Thread): 
    83    """ 
    84    Thread that refreshes credentials. 
    85    """ 
    86 
    87    def __init__(self, cred, request, **kwargs): 
    88        """Initializes the thread. 
    89 
    90        Args: 
    91            cred: A Credential object to refresh. 
    92            request: A Request object used to perform a credential refresh. 
    93            **kwargs: Additional keyword arguments. 
    94        """ 
    95 
    96        super().__init__(**kwargs) 
    97        self._cred = cred 
    98        self._request = request 
    99        self._error_info = None 
    100 
    101    def run(self): 
    102        """ 
    103        Perform the credential refresh. 
    104        """ 
    105        try: 
    106            self._cred.refresh(self._request) 
    107        except Exception as err:  # pragma: NO COVER 
    108            _LOGGER.error(f"Background refresh failed due to: {err}") 
    109            self._error_info = err