OAuth2Error.java

/*
 * Copyright 2022 Red Hat, Inc. and/or its affiliates
 * and other contributors as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.keycloak.utils;

import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;
import jakarta.ws.rs.BadRequestException;
import jakarta.ws.rs.ForbiddenException;
import jakarta.ws.rs.InternalServerErrorException;
import jakarta.ws.rs.NotAuthorizedException;
import jakarta.ws.rs.WebApplicationException;
import jakarta.ws.rs.core.MediaType;
import jakarta.ws.rs.core.Response;

import org.keycloak.OAuthErrorException;
import org.keycloak.models.RealmModel;
import org.keycloak.representations.idm.OAuth2ErrorRepresentation;
import org.keycloak.services.resources.Cors;

import static jakarta.ws.rs.core.HttpHeaders.WWW_AUTHENTICATE;

/**
 * Fluent builder for OAuth 2.0 error responses, delivered as a JAX-RS {@link WebApplicationException}.
 * <p>
 * Set the OAuth error code with {@link #error(String)}, which also selects the HTTP status. Optionally set
 * a description, CORS headers, the realm and the output format, then call {@link #build()}. In JSON mode
 * (the default) the error is sent as an {@link OAuth2ErrorRepresentation} entity. Otherwise it is sent as
 * a {@code Bearer} challenge in the {@code WWW-Authenticate} header with an empty body.
 * <p>
 * Instances are mutable and perform no synchronization.
 *
 * @author <a href="mailto:dmitryt@backbase.com">Dmitry Telegin</a>
 */
public class OAuth2Error {

    /** Maps an HTTP status to the JAX-RS exception subclass used to carry a response with that status. */
    private static final Map<Response.Status, Class<? extends WebApplicationException>> STATUS_MAP = new HashMap<>();

    private RealmModel realm;
    private String error;
    private String errorDescription;
    private Optional<Cors> cors = Optional.empty();

    private Response.Status status;
    private boolean json = true;

    static {
        STATUS_MAP.put(Response.Status.BAD_REQUEST, BadRequestException.class);
        STATUS_MAP.put(Response.Status.UNAUTHORIZED, NotAuthorizedException.class);
        STATUS_MAP.put(Response.Status.FORBIDDEN, ForbiddenException.class);
        STATUS_MAP.put(Response.Status.INTERNAL_SERVER_ERROR, InternalServerErrorException.class);
    }

    /**
     * Sets the realm whose name is reported in the {@code realm} attribute of the
     * {@code WWW-Authenticate} challenge. This only matters in non-JSON mode.
     *
     * @param realm the realm; if {@code null}, the {@code realm} attribute is omitted
     * @return this builder
     */
    public OAuth2Error realm(RealmModel realm) {
        this.realm = realm;
        return this;
    }

    /**
     * Sets the OAuth 2.0 error code and derives the matching HTTP status from it.
     * <p>
     * The mapping is:
     * <ul>
     * <li>400 for {@code INVALID_GRANT}, {@code INVALID_REQUEST}, {@code UNAUTHORIZED_CLIENT},
     * {@code UNSUPPORTED_GRANT_TYPE} and {@code INVALID_SCOPE}</li>
     * <li>401 for {@code INVALID_CLIENT} and {@code INVALID_TOKEN}</li>
     * <li>403 for {@code INSUFFICIENT_SCOPE}</li>
     * <li>500 for {@code SERVER_ERROR}</li>
     * </ul>
     * The names refer to {@link OAuthErrorException} constants. A later call to
     * {@link #status(Response.Status)} overrides the derived status.
     *
     * @param error one of the {@link OAuthErrorException} error codes listed above
     * @return this builder
     * @throws IllegalArgumentException if the error code is not recognized (the code is still stored)
     * @throws NullPointerException if {@code error} is {@code null} (switch on a null string)
     */
    public OAuth2Error error(String error) {

        this.error = error;

        switch (error) {
            case OAuthErrorException.INVALID_GRANT:
            case OAuthErrorException.INVALID_REQUEST:
            case OAuthErrorException.UNAUTHORIZED_CLIENT:
            case OAuthErrorException.UNSUPPORTED_GRANT_TYPE:
            case OAuthErrorException.INVALID_SCOPE:
                status = Response.Status.BAD_REQUEST;
                break;
            case OAuthErrorException.INVALID_CLIENT:
            case OAuthErrorException.INVALID_TOKEN:
                status = Response.Status.UNAUTHORIZED;
                break;
            case OAuthErrorException.INSUFFICIENT_SCOPE:
                status = Response.Status.FORBIDDEN;
                break;
            case OAuthErrorException.SERVER_ERROR:
                status = Response.Status.INTERNAL_SERVER_ERROR;
                break;
            default:
                throw new IllegalArgumentException("Unrecognized OAuth 2.0 error: " + error);
        }

        return this;
    }

    /**
     * Sets the human-readable error description.
     *
     * @param errorDescription the description; {@code null} omits it from the challenge header
     * @return this builder
     */
    public OAuth2Error errorDescription(String errorDescription) {
        this.errorDescription = errorDescription;
        return this;
    }

    /**
     * Sets the CORS helper whose headers are added to the built response.
     *
     * @param cors the CORS helper, or {@code null} to add no CORS headers
     * @return this builder
     */
    public OAuth2Error cors(Cors cors) {
        this.cors = Optional.ofNullable(cors);
        return this;
    }

    /**
     * Explicitly sets the HTTP status, overriding any status derived by {@link #error(String)}.
     *
     * @param status the response status
     * @return this builder
     */
    public OAuth2Error status(Response.Status status) {
        this.status = status;
        return this;
    }

    /**
     * Selects the output format.
     *
     * @param json {@code true} (the default) for a JSON body; {@code false} for a {@code WWW-Authenticate}
     *             Bearer challenge header with an empty body
     * @return this builder
     */
    public OAuth2Error json(boolean json) {
        this.json = json;
        return this;
    }

    /**
     * Builds the error response from the current state and wraps it in an exception.
     * <p>
     * The exception class is the one registered for the current status in {@code STATUS_MAP}, or plain
     * {@link WebApplicationException} for any other status. CORS headers, if configured, are added
     * before the error payload.
     *
     * @return the exception carrying the built response
     * @throws InternalServerErrorException if the exception class cannot be instantiated reflectively
     */
    public WebApplicationException build() {
        // Only needed while building, so it is local rather than builder state.
        Class<? extends WebApplicationException> clazz = STATUS_MAP.getOrDefault(status, WebApplicationException.class);
        Response.ResponseBuilder builder = Response.status(status);

        try {
            // Resolved reflectively: every candidate class is instantiated through its (Response) constructor.
            Constructor<? extends WebApplicationException> constructor = clazz.getConstructor(Response.class);
            cors.ifPresent(_cors -> { _cors.build(builder::header); });

            if (json) {
                OAuth2ErrorRepresentation errorRep = new OAuth2ErrorRepresentation(error, errorDescription);
                builder.entity(errorRep).type(MediaType.APPLICATION_JSON_TYPE);
            } else {
                WWWAuthenticate.BearerChallenge bearer = new WWWAuthenticate.BearerChallenge();
                // The realm is optional. Omit the attribute rather than failing with a NullPointerException.
                if (realm != null) {
                    bearer.setRealm(realm.getName());
                }
                bearer.setError(error);
                bearer.setErrorDescription(errorDescription);
                new WWWAuthenticate(bearer).build(builder::header);
                builder.entity("");
            }

            return constructor.newInstance(builder.build());
        } catch (NoSuchMethodException | SecurityException | InstantiationException | IllegalAccessException | IllegalArgumentException | InvocationTargetException ex) {
            throw new InternalServerErrorException(ex);
        }
    }

    /**
     * Builds an {@code INSUFFICIENT_SCOPE} (403) error.
     *
     * @param errorDescription the human-readable description
     * @return the exception carrying the built response
     */
    public WebApplicationException insufficientScope(String errorDescription) {
        return this.error(OAuthErrorException.INSUFFICIENT_SCOPE).errorDescription(errorDescription).build();
    }

    /**
     * Builds an {@code INVALID_TOKEN} (401) error.
     *
     * @param errorDescription the human-readable description
     * @return the exception carrying the built response
     */
    public WebApplicationException invalidToken(String errorDescription) {
        return this.error(OAuthErrorException.INVALID_TOKEN).errorDescription(errorDescription).build();
    }

    /**
     * Builds an {@code INVALID_REQUEST} (400) error.
     *
     * @param errorDescription the human-readable description
     * @return the exception carrying the built response
     */
    public WebApplicationException invalidRequest(String errorDescription) {
        return this.error(OAuthErrorException.INVALID_REQUEST).errorDescription(errorDescription).build();
    }

    /**
     * Builds a 401 response using whatever error code, description and format are currently set.
     *
     * @return the exception carrying the built response
     */
    public WebApplicationException unauthorized() {
        return this.status(Response.Status.UNAUTHORIZED).build();
    }

    /**
     * Model of a {@code WWW-Authenticate} header holding one or more authentication challenges.
     * <p>
     * {@link #build(BiConsumer)} emits all challenges, either joined with {@code ", "} into a single
     * header (the default) or as one header per challenge. A "master" challenge is tracked, but it is not
     * consulted by {@code build}.
     */
    private static class WWWAuthenticate {

        private final List<Challenge> challenges;
        private Challenge master;
        private boolean singleHeader = true;

        /**
         * Creates the header with at least one challenge. The first challenge becomes the master.
         */
        public WWWAuthenticate(Challenge challenge, Challenge... moreChallenges) {
            challenges = new ArrayList<>(1 + ((moreChallenges == null) ? 0 : moreChallenges.length));
            challenges.add(challenge);
            if (moreChallenges != null) {
                challenges.addAll(Arrays.asList(moreChallenges));
            }
            master = challenge;
        }

        public void addChallenge(Challenge challenge) {
            challenges.add(challenge);
        }

        /**
         * Marks an already-added challenge as master.
         *
         * @throws IllegalArgumentException if the challenge was not added
         */
        public void setMasterChallenge(Challenge challenge) {
            if (challenges.contains(challenge)) {
                master = challenge;
            } else {
                throw new IllegalArgumentException("Unknown challenge: " + challenge);
            }
        }

        /**
         * Marks the first added challenge with the given scheme (exact, case-sensitive match) as master.
         *
         * @throws IllegalArgumentException if no challenge has that scheme
         */
        public void setMasterChallenge(String scheme) {
            master = challenges.stream()
                .filter(c -> c.getScheme().equals(scheme))
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("Unknown challenge: " + scheme));
        }

        public Challenge getMasterChallenge() {
            return master;
        }

        public boolean isSingleHeader() {
            return singleHeader;
        }

        public void setSingleHeader(boolean singleHeader) {
            this.singleHeader = singleHeader;
        }

        /** Sets the same attribute on every challenge. A {@code null} value is ignored. */
        public void setAttribute(String attribute, String value) {
            challenges.forEach(c -> c.setAttribute(attribute, value));
        }

        /**
         * Emits the challenges through {@code addHeader} under the {@code WWW-Authenticate} name.
         *
         * @param addHeader sink receiving (header name, header value) pairs
         */
        public void build(BiConsumer<String, Object> addHeader) {
            if (singleHeader) {
                String header = challenges.stream()
                    .map(Challenge::toString)
                    .collect(Collectors.joining(", "));
                addHeader.accept(WWW_AUTHENTICATE, header);
            } else {
                challenges.forEach(c -> addHeader.accept(WWW_AUTHENTICATE, c));
            }
        }

        /**
         * A single challenge: an auth scheme plus ordered {@code name="value"} parameters.
         */
        public static abstract class Challenge {

            private final Map<String, String> attributes = new LinkedHashMap<>();

            /** Sets an attribute, preserving insertion order. A {@code null} value is ignored. */
            public void setAttribute(String attribute, String value) {
                if (value != null) {
                    attributes.put(attribute, value);
                }
            }

            public abstract String getScheme();

            /**
             * Renders the challenge as {@code Scheme name="value", name2="value2"}. Values are
             * emitted as properly escaped quoted-strings.
             */
            @Override
            public String toString() {
                StringBuilder sb = new StringBuilder(getScheme());

                if (!attributes.isEmpty()) {
                    sb.append(" ").append(
                        attributes.entrySet().stream()
                            .map(e -> e.getKey() + "=" + quote(e.getValue()))
                            .collect(Collectors.joining(", "))
                    );
                }

                return sb.toString();
            }

            /**
             * Renders {@code value} as an HTTP quoted-string.
             * <p>
             * Backslashes and double quotes are backslash-escaped, so a value cannot terminate the
             * string early. CR and LF are not allowed in a header value and are replaced with spaces,
             * which also prevents header splitting.
             */
            private static String quote(String value) {
                StringBuilder sb = new StringBuilder(value.length() + 2).append('"');
                for (int i = 0; i < value.length(); i++) {
                    char ch = value.charAt(i);
                    switch (ch) {
                        case '\\':
                        case '"':
                            sb.append('\\').append(ch);
                            break;
                        case '\r':
                        case '\n':
                            sb.append(' ');
                            break;
                        default:
                            sb.append(ch);
                    }
                }
                return sb.append('"').toString();
            }

        }

        /** {@code Basic} challenge supporting the {@code realm} attribute. */
        public static class BasicChallenge extends Challenge {

            private static final String BASIC_SCHEME = "Basic";
            private static final String REALM_ATTRIBUTE = "realm";

            public void setRealm(String realm) {
                setAttribute(REALM_ATTRIBUTE, realm);
            }

            @Override
            public String getScheme() {
                return BASIC_SCHEME;
            }

        }

        /**
         * {@code Bearer} challenge with the error, error_description, error_uri and scope attributes.
         * It extends {@link BasicChallenge} only to reuse {@code setRealm} and overrides the scheme.
         */
        public static class BearerChallenge extends BasicChallenge {

            private static final String BEARER_SCHEME = "Bearer";

            private static final String ERROR_ATTRIBUTE = "error";
            private static final String ERROR_DESCRIPTION_ATTRIBUTE = "error_description";
            private static final String ERROR_URI_ATTRIBUTE = "error_uri";
            private static final String SCOPE_ATTRIBUTE = "scope";

            public void setError(String error) {
                setAttribute(ERROR_ATTRIBUTE, error);
            }

            public void setErrorDescription(String errorDescription) {
                setAttribute(ERROR_DESCRIPTION_ATTRIBUTE, errorDescription);
            }

            public void setErrorUri(String errorUri) {
                setAttribute(ERROR_URI_ATTRIBUTE, errorUri);
            }

            public void setScope(String scope) {
                setAttribute(SCOPE_ATTRIBUTE, scope);
            }

            @Override
            public String getScheme() {
                return BEARER_SCHEME;
            }

        }

    }

}