1""" 
    2oauthlib.oauth2.rfc6749.grant_types 
    3~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 
    4""" 
    5import logging 
    6from itertools import chain 
    7 
    8from oauthlib.common import add_params_to_uri 
    9from oauthlib.oauth2.rfc6749 import errors, utils 
    10from oauthlib.uri_validate import is_absolute_uri 
    11 
    12from ..request_validator import RequestValidator 
    13from ..utils import is_secure_transport 
    14 
# Module-level logger shared by the grant type helpers defined in this module.
log = logging.getLogger(__name__)
    16 
    17 
    18class ValidatorsContainer: 
    19    """ 
    20    Container object for holding custom validator callables to be invoked 
    21    as part of the grant type `validate_authorization_request()` or 
    22    `validate_authorization_request()` methods on the various grant types. 
    23 
    24    Authorization validators must be callables that take a request object and 
    25    return a dict, which may contain items to be added to the `request_info` 
    26    returned from the grant_type after validation. 
    27 
    28    Token validators must be callables that take a request object and 
    29    return None. 
    30 
    31    Both authorization validators and token validators may raise OAuth2 
    32    exceptions if validation conditions fail. 
    33 
    34    Authorization validators added to `pre_auth` will be run BEFORE 
    35    the standard validations (but after the critical ones that raise 
    36    fatal errors) as part of `validate_authorization_request()` 
    37 
    38    Authorization validators added to `post_auth` will be run AFTER 
    39    the standard validations as part of `validate_authorization_request()` 
    40 
    41    Token validators added to `pre_token` will be run BEFORE 
    42    the standard validations as part of `validate_token_request()` 
    43 
    44    Token validators added to `post_token` will be run AFTER 
    45    the standard validations as part of `validate_token_request()` 
    46 
    47    For example: 
    48 
    49    >>> def my_auth_validator(request): 
    50    ...    return {'myval': True} 
    51    >>> auth_code_grant = AuthorizationCodeGrant(request_validator) 
    52    >>> auth_code_grant.custom_validators.pre_auth.append(my_auth_validator) 
    53    >>> def my_token_validator(request): 
    54    ...     if not request.everything_okay: 
    55    ...         raise errors.OAuth2Error("uh-oh") 
    56    >>> auth_code_grant.custom_validators.post_token.append(my_token_validator) 
    57    """ 
    58 
    59    def __init__(self, post_auth, post_token, 
    60                 pre_auth, pre_token): 
    61        self.pre_auth = pre_auth 
    62        self.post_auth = post_auth 
    63        self.pre_token = pre_token 
    64        self.post_token = post_token 
    65 
    66    @property 
    67    def all_pre(self): 
    68        return chain(self.pre_auth, self.pre_token) 
    69 
    70    @property 
    71    def all_post(self): 
    72        return chain(self.post_auth, self.post_token) 
    73 
    74 
    75class GrantTypeBase: 
    76    error_uri = None 
    77    request_validator = None 
    78    default_response_mode = 'fragment' 
    79    refresh_token = True 
    80    response_types = ['code'] 
    81 
    82    def __init__(self, request_validator=None, **kwargs): 
    83        self.request_validator = request_validator or RequestValidator() 
    84 
    85        # Transforms class variables into instance variables: 
    86        self.response_types = self.response_types 
    87        self.refresh_token = self.refresh_token 
    88        self._setup_custom_validators(kwargs) 
    89        self._code_modifiers = [] 
    90        self._token_modifiers = [] 
    91 
    92        for kw, val in kwargs.items(): 
    93            setattr(self, kw, val) 
    94 
    95    def _setup_custom_validators(self, kwargs): 
    96        post_auth = kwargs.get('post_auth', []) 
    97        post_token = kwargs.get('post_token', []) 
    98        pre_auth = kwargs.get('pre_auth', []) 
    99        pre_token = kwargs.get('pre_token', []) 
    100        if not hasattr(self, 'validate_authorization_request'): 
    101            if post_auth or pre_auth: 
    102                msg = ("{} does not support authorization validators. Use " 
    103                       "token validators instead.").format(self.__class__.__name__) 
    104                raise ValueError(msg) 
    105            # Using tuples here because they can't be appended to: 
    106            post_auth, pre_auth = (), () 
    107        self.custom_validators = ValidatorsContainer(post_auth, post_token, 
    108                                                     pre_auth, pre_token) 
    109 
    110    def register_response_type(self, response_type): 
    111        self.response_types.append(response_type) 
    112 
    113    def register_code_modifier(self, modifier): 
    114        self._code_modifiers.append(modifier) 
    115 
    116    def register_token_modifier(self, modifier): 
    117        self._token_modifiers.append(modifier) 
    118 
    119    def create_authorization_response(self, request, token_handler): 
    120        """ 
    121        :param request: OAuthlib request. 
    122        :type request: oauthlib.common.Request 
    123        :param token_handler: A token handler instance, for example of type 
    124                              oauthlib.oauth2.BearerToken. 
    125        """ 
    126        raise NotImplementedError('Subclasses must implement this method.') 
    127 
    128    def create_token_response(self, request, token_handler): 
    129        """ 
    130        :param request: OAuthlib request. 
    131        :type request: oauthlib.common.Request 
    132        :param token_handler: A token handler instance, for example of type 
    133                              oauthlib.oauth2.BearerToken. 
    134        """ 
    135        raise NotImplementedError('Subclasses must implement this method.') 
    136 
    137    def add_token(self, token, token_handler, request): 
    138        """ 
    139        :param token: 
    140        :param token_handler: A token handler instance, for example of type 
    141                              oauthlib.oauth2.BearerToken. 
    142        :param request: OAuthlib request. 
    143        :type request: oauthlib.common.Request 
    144        """ 
    145        # Only add a hybrid access token on auth step if asked for 
    146        if request.response_type not in ["token", "code token", "id_token token", "code id_token token"]: 
    147            return token 
    148 
    149        token.update(token_handler.create_token(request, refresh_token=False)) 
    150        return token 
    151 
    152    def validate_grant_type(self, request): 
    153        """ 
    154        :param request: OAuthlib request. 
    155        :type request: oauthlib.common.Request 
    156        """ 
    157        client_id = getattr(request, 'client_id', None) 
    158        if not self.request_validator.validate_grant_type(client_id, 
    159                                                          request.grant_type, request.client, request): 
    160            log.debug('Unauthorized from %r (%r) access to grant type %s.', 
    161                      request.client_id, request.client, request.grant_type) 
    162            raise errors.UnauthorizedClientError(request=request) 
    163 
    164    def validate_scopes(self, request): 
    165        """ 
    166        :param request: OAuthlib request. 
    167        :type request: oauthlib.common.Request 
    168        """ 
    169        if not request.scopes: 
    170            request.scopes = utils.scope_to_list(request.scope) or utils.scope_to_list( 
    171                self.request_validator.get_default_scopes(request.client_id, request)) 
    172        log.debug('Validating access to scopes %r for client %r (%r).', 
    173                  request.scopes, request.client_id, request.client) 
    174        if not self.request_validator.validate_scopes(request.client_id, 
    175                                                      request.scopes, request.client, request): 
    176            raise errors.InvalidScopeError(request=request) 
    177 
    178    def prepare_authorization_response(self, request, token, headers, body, status): 
    179        """Place token according to response mode. 
    180 
    181        Base classes can define a default response mode for their authorization 
    182        response by overriding the static `default_response_mode` member. 
    183 
    184        :param request: OAuthlib request. 
    185        :type request: oauthlib.common.Request 
    186        :param token: 
    187        :param headers: 
    188        :param body: 
    189        :param status: 
    190        """ 
    191        request.response_mode = request.response_mode or self.default_response_mode 
    192 
    193        if request.response_mode not in ('query', 'fragment'): 
    194            log.debug('Overriding invalid response mode %s with %s', 
    195                      request.response_mode, self.default_response_mode) 
    196            request.response_mode = self.default_response_mode 
    197 
    198        token_items = token.items() 
    199 
    200        if request.response_type == 'none': 
    201            state = token.get('state', None) 
    202            token_items = [('state', state)] if state else [] 
    203 
    204        if request.response_mode == 'query': 
    205            headers['Location'] = add_params_to_uri( 
    206                request.redirect_uri, token_items, fragment=False) 
    207            return headers, body, status 
    208 
    209        if request.response_mode == 'fragment': 
    210            headers['Location'] = add_params_to_uri( 
    211                request.redirect_uri, token_items, fragment=True) 
    212            return headers, body, status 
    213 
    214        raise NotImplementedError( 
    215            'Subclasses must set a valid default_response_mode') 
    216 
    217    def _get_default_headers(self): 
    218        """Create default headers for grant responses.""" 
    219        return { 
    220            'Content-Type': 'application/json', 
    221            'Cache-Control': 'no-store', 
    222            'Pragma': 'no-cache', 
    223        } 
    224 
    225    def _handle_redirects(self, request): 
    226        if request.redirect_uri is not None: 
    227            request.using_default_redirect_uri = False 
    228            log.debug('Using provided redirect_uri %s', request.redirect_uri) 
    229            if not is_absolute_uri(request.redirect_uri): 
    230                raise errors.InvalidRedirectURIError(request=request) 
    231 
    232            # The authorization server MUST verify that the redirection URI 
    233            # to which it will redirect the access token matches a 
    234            # redirection URI registered by the client as described in 
    235            # Section 3.1.2. 
    236            # https://tools.ietf.org/html/rfc6749#section-3.1.2 
    237            if not self.request_validator.validate_redirect_uri( 
    238                    request.client_id, request.redirect_uri, request): 
    239                raise errors.MismatchingRedirectURIError(request=request) 
    240        else: 
    241            request.redirect_uri = self.request_validator.get_default_redirect_uri( 
    242                request.client_id, request) 
    243            request.using_default_redirect_uri = True 
    244            log.debug('Using default redirect_uri %s.', request.redirect_uri) 
    245            if not request.redirect_uri: 
    246                raise errors.MissingRedirectURIError(request=request) 
    247            if not is_absolute_uri(request.redirect_uri): 
    248                raise errors.InvalidRedirectURIError(request=request) 
    249 
    250    def _create_cors_headers(self, request): 
    251        """If CORS is allowed, create the appropriate headers.""" 
    252        if 'origin' not in request.headers: 
    253            return {} 
    254 
    255        origin = request.headers['origin'] 
    256        if not is_secure_transport(origin): 
    257            log.debug('Origin "%s" is not HTTPS, CORS not allowed.', origin) 
    258            return {} 
    259        elif not self.request_validator.is_origin_allowed( 
    260            request.client_id, origin, request): 
    261            log.debug('Invalid origin "%s", CORS not allowed.', origin) 
    262            return {} 
    263        else: 
    264            log.debug('Valid origin "%s", injecting CORS headers.', origin) 
    265            return {'Access-Control-Allow-Origin': origin}