IPMatchBase.java

/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2023 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */
package io.undertow.predicate.ip;

import java.io.IOException;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.regex.Pattern;

import org.xnio.Bits;

import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.util.NetworkUtils;

/**
 * Base class for IP related matching.
 * <p>
 * Rules are kept in two lists, one per address family, and are evaluated in the order in which they were
 * added: the first rule that matches an address decides the outcome. When no rule matches, the configured
 * default is returned.
 *
 * @param <T> the concrete subtype, returned by the fluent mutator methods
 *
 * @author Stuart Douglas
 * @author baranowb
 */
public abstract class IPMatchBase<T extends IPMatchBase> {
    /**
     * Standard IP address
     */
    protected static final Pattern IP4_EXACT = Pattern.compile(NetworkUtils.IP4_EXACT);

    /**
     * Standard IP address, with some octets replaced by a '*'
     */
    protected static final Pattern IP4_WILDCARD = Pattern.compile("(?:(?:\\d{1,3}|\\*)\\.){3}(?:\\d{1,3}|\\*)");

    /**
     * IPv4 address with subnet specified via slash notation
     */
    protected static final Pattern IP4_SLASH = Pattern.compile("(?:\\d{1,3}\\.){3}\\d{1,3}\\/\\d\\d?");

    /**
     * Standard full IPv6 address
     */
    protected static final Pattern IP6_EXACT = Pattern.compile(NetworkUtils.IP6_EXACT);

    /**
     * Standard full IPv6 address, with some parts replaced by a '*'
     */
    protected static final Pattern IP6_WILDCARD = Pattern.compile("(?:(?:[a-zA-Z0-9]{1,4}|\\*):){7}(?:[a-zA-Z0-9]{1,4}|\\*)");

    /**
     * Standard full IPv6 address with subnet specified via slash notation
     */
    protected static final Pattern IP6_SLASH = Pattern.compile("(?:[a-zA-Z0-9]{1,4}:){7}[a-zA-Z0-9]{1,4}\\/\\d{1,3}");

    /** Number of bits in an IPv4 address; the largest valid IPv4 prefix length. */
    private static final int IP4_BITS = 32;

    /** Number of bits in an IPv6 address; the largest valid IPv6 prefix length. */
    private static final int IP6_BITS = 128;

    /** Trace-level flag of the predicate logger, captured once when this class is initialized. */
    protected static final boolean traceEnabled;
    /** Debug-level flag of the predicate logger, captured once when this class is initialized. */
    protected static final boolean debugEnabled;

    static {
        traceEnabled = UndertowLogger.PREDICATE_LOGGER.isTraceEnabled();
        debugEnabled = UndertowLogger.PREDICATE_LOGGER.isDebugEnabled();
    }

    /** IPv6 rules, in the order they were added. */
    protected final List<PeerMatch> ipv6Matches = new CopyOnWriteArrayList<>();
    /** IPv4 rules, in the order they were added. */
    protected final List<PeerMatch> ipv4Matches = new CopyOnWriteArrayList<>();
    /** Result of {@link #isAllowed(InetAddress)} when no rule matches. */
    protected volatile boolean defaultAllow = false;

    /**
     * Creates a matcher without any rules.
     *
     * @param defaultAllow the result returned by {@link #isAllowed(InetAddress)} when no rule matches
     */
    public IPMatchBase(final boolean defaultAllow) {
        this.defaultAllow = defaultAllow;
    }

    /**
     * Evaluates the rules of the matching address family against the given address.
     *
     * @param address the address to check
     * @return {@code true} if the first matching rule is an allow rule, {@code false} if it is a deny rule,
     *         and the default setting if no rule matches or the address is neither an IPv4 nor an IPv6 address
     */
    public boolean isAllowed(InetAddress address) {
        final List<PeerMatch> rules;
        final String traceFormat;
        if (address instanceof Inet4Address) {
            rules = ipv4Matches;
            traceFormat = "Comparing rule [%s] to IPv4 address %s.";
        } else if (address instanceof Inet6Address) {
            rules = ipv6Matches;
            traceFormat = "Comparing rule [%s] to IPv6 address %s.";
        } else {
            return defaultAllow;
        }
        for (PeerMatch rule : rules) {
            if (traceEnabled) {
                UndertowLogger.PREDICATE_LOGGER.tracef(traceFormat, rule.toPredicateString(), address.getHostAddress());
            }
            // first match wins
            if (rule.matches(address)) {
                return !rule.isDeny();
            }
        }
        return defaultAllow;
    }

    /**
     * Parses the given peer expression and appends the resulting rule to the list of its address family.
     * <p>
     * An expression that matches none of the supported patterns, an IPv4 octet greater than 255 or a prefix
     * length larger than the address size is rejected with the exception created by
     * {@code UndertowMessages.MESSAGES.notAValidIpPattern(peer)}.
     *
     * @param peer the peer expression
     * @param deny {@code true} to add a deny rule, {@code false} to add an allow rule
     * @return this instance
     */
    protected T addRule(final String peer, final boolean deny) {
        if (IP4_EXACT.matcher(peer).matches()) {
            addIpV4ExactMatch(peer, deny);
        } else if (IP4_WILDCARD.matcher(peer).matches()) {
            addIpV4WildcardMatch(peer, deny);
        } else if (IP4_SLASH.matcher(peer).matches()) {
            addIpV4SlashPrefix(peer, deny);
        } else if (IP6_EXACT.matcher(peer).matches()) {
            addIpV6ExactMatch(peer, deny);
        } else if (IP6_WILDCARD.matcher(peer).matches()) {
            addIpV6WildcardMatch(peer, deny);
        } else if (IP6_SLASH.matcher(peer).matches()) {
            addIpV6SlashPrefix(peer, deny);
        } else {
            throw UndertowMessages.MESSAGES.notAValidIpPattern(peer);
        }
        return self();
    }

    /**
     * Adds an allowed peer to the ACL list
     * <p>
     * Peer can take several forms:
     * <p>
     * a.b.c.d = Literal IPv4 Address
     * a:b:c:d:e:f:g:h = Literal IPv6 Address
     * a.b.*.* = Wildcard IPv4 Address (all four components must be present)
     * a:b:*:*:*:*:*:* = Wildcard IPv6 Address (all eight components must be present)
     * a.b.c.0/24 = Classless wildcard IPv4 address
     * a:b:c:d:e:f:g:0/120 = Classless wildcard IPv6 address
     *
     * @param peer The peer to add to the ACL
     * @return this instance
     */
    public T addAllow(final String peer) {
        return addRule(peer, false);
    }

    /**
     * Adds a denied peer to the ACL list
     * <p>
     * Peer can take several forms:
     * <p>
     * a.b.c.d = Literal IPv4 Address
     * a:b:c:d:e:f:g:h = Literal IPv6 Address
     * a.b.*.* = Wildcard IPv4 Address (all four components must be present)
     * a:b:*:*:*:*:*:* = Wildcard IPv6 Address (all eight components must be present)
     * a.b.c.0/24 = Classless wildcard IPv4 address
     * a:b:c:d:e:f:g:0/120 = Classless wildcard IPv6 address
     *
     * @param peer The peer to add to the ACL
     * @return this instance
     */
    public T addDeny(final String peer) {
        return addRule(peer, true);
    }

    /**
     * Removes all IPv4 and IPv6 rules. The default setting is left untouched.
     *
     * @return this instance
     */
    public T clearRules() {
        this.ipv4Matches.clear();
        this.ipv6Matches.clear();
        return self();
    }

    /**
     * Sets the result returned by {@link #isAllowed(InetAddress)} when no rule matches.
     *
     * @param defaultAllow the new default
     * @return this instance
     */
    public T setDefaultAllow(final boolean defaultAllow) {
        this.defaultAllow = defaultAllow;
        return self();
    }

    /**
     * @return the result returned by {@link #isAllowed(InetAddress)} when no rule matches
     */
    public boolean isDefaultAllow() {
        return defaultAllow;
    }

    /**
     * Single place for the unchecked self-cast needed by the fluent methods.
     */
    @SuppressWarnings("unchecked")
    private T self() {
        return (T) this;
    }

    /**
     * Parses one decimal IPv4 octet. The wildcard and slash patterns declared in this class only limit the
     * number of digits, so values above 255 have to be rejected here; otherwise they would spill into the
     * neighbouring octet or be truncated by a byte cast.
     */
    private static int parseIpV4Octet(final String octet, final String peer) {
        final int value = Integer.parseInt(octet);
        if (value > 255) {
            throw UndertowMessages.MESSAGES.notAValidIpPattern(peer);
        }
        return value;
    }

    /**
     * Parses the prefix length that follows the slash and rejects values larger than the address size.
     */
    private static int parsePrefixLength(final String text, final int maxBits, final String peer) {
        final int length = Integer.parseInt(text);
        if (length > maxBits) {
            throw UndertowMessages.MESSAGES.notAValidIpPattern(peer);
        }
        return length;
    }

    private void addIpV6SlashPrefix(final String peer, final boolean deny) {
        String[] components = peer.split("\\/");
        String[] parts = components[0].split("\\:");
        final int maskLen = parsePrefixLength(components[1], IP6_BITS, peer);
        assert parts.length == 8;

        byte[] pattern = new byte[16];
        byte[] mask = new byte[16];

        for (int i = 0; i < 8; ++i) {
            int val = Integer.parseInt(parts[i], 16);
            pattern[i * 2] = (byte) (val >> 8);
            pattern[i * 2 + 1] = (byte) (val & 0xFF);
        }
        // Fill the mask from the most significant byte; the last byte touched may be only partially set.
        int remaining = maskLen;
        for (int i = 0; i < 16 && remaining > 0; ++i) {
            final int bits = Math.min(remaining, 8);
            mask[i] = (byte) (Bits.intBitMask(8 - bits, 7) & 0xFF);
            remaining -= bits;
        }
        // Drop host bits from the pattern (as the IPv4 variant does); a pattern bit outside the mask
        // would make the rule impossible to match.
        for (int i = 0; i < 16; ++i) {
            pattern[i] &= mask[i];
        }
        ipv6Matches.add(new PrefixIpV6PeerMatch(peer, mask, pattern, deny));
    }

    private void addIpV4SlashPrefix(final String peer, final boolean deny) {
        String[] components = peer.split("\\/");
        String[] parts = components[0].split("\\.");
        final int maskLen = parsePrefixLength(components[1], IP4_BITS, peer);
        // A /0 prefix has an empty mask; it is handled separately because the low bit index passed to
        // Bits.intBitMask would otherwise be 32, which is not a valid bit position of an int.
        final int mask = maskLen == 0 ? 0 : Bits.intBitMask(IP4_BITS - maskLen, 31);
        int prefix = 0;
        for (int i = 0; i < 4; ++i) {
            prefix = (prefix << 8) | parseIpV4Octet(parts[i], peer);
        }
        // ignore host bits given in the rule
        prefix &= mask;

        ipv4Matches.add(new PrefixIpV4PeerMatch(peer, mask, prefix, deny));
    }

    private void addIpV6WildcardMatch(final String peer, boolean deny) {
        byte[] pattern = new byte[16];
        byte[] mask = new byte[16];
        String[] parts = peer.split("\\:");
        assert parts.length == 8;
        for (int i = 0; i < 8; ++i) {
            // a wildcard group keeps both its mask and its pattern bytes at zero
            if (!parts[i].equals("*")) {
                int val = Integer.parseInt(parts[i], 16);
                pattern[i * 2] = (byte) (val >> 8);
                pattern[i * 2 + 1] = (byte) (val & 0xFF);
                mask[i * 2] = (byte) (0xFF);
                mask[i * 2 + 1] = (byte) (0xFF);
            }
        }
        ipv6Matches.add(new PrefixIpV6PeerMatch(peer, mask, pattern, deny));
    }

    private void addIpV4WildcardMatch(final String peer, final boolean deny) {
        String[] parts = peer.split("\\.");
        int mask = 0;
        int prefix = 0;
        for (int i = 0; i < 4; ++i) {
            mask <<= 8;
            prefix <<= 8;
            String part = parts[i];
            // a wildcard octet keeps both its mask and its prefix byte at zero
            if (!part.equals("*")) {
                mask |= 0xFF;
                prefix |= parseIpV4Octet(part, peer);
            }
        }
        ipv4Matches.add(new PrefixIpV4PeerMatch(peer, mask, prefix, deny));
    }

    private void addIpV6ExactMatch(final String peer, final boolean deny) {
        byte[] bytes;
        try {
            bytes = NetworkUtils.parseIpv6AddressToBytes(peer);
            ipv6Matches.add(new ExactIpV6PeerMatch(peer, bytes, deny));
        } catch (IOException e) {
            throw UndertowMessages.MESSAGES.invalidACLAddress(e);
        }
    }

    private void addIpV4ExactMatch(final String peer, final boolean deny) {
        String[] parts = peer.split("\\.");
        byte[] bytes = new byte[4];
        for (int i = 0; i < 4; ++i) {
            bytes[i] = (byte) parseIpV4Octet(parts[i], peer);
        }
        ipv4Matches.add(new ExactIpV4PeerMatch(peer, bytes, deny));
    }

    /**
     * A single allow or deny rule.
     */
    protected abstract static class PeerMatch {

        /** The peer expression this rule was created from. */
        private final String pattern;
        protected final boolean deny;

        protected PeerMatch(final String pattern, final boolean deny) {
            this.pattern = pattern;
            this.deny = deny;
        }

        /**
         * @param address the address to test
         * @return {@code true} if this rule applies to the given address
         */
        public abstract boolean matches(InetAddress address);

        /**
         * @return {@code true} if this is a deny rule, {@code false} if it is an allow rule
         */
        public boolean isDeny() {
            return this.deny;
        }

        @Override
        public String toString() {
            return getClass().getSimpleName() + "{"
                    + "deny=" + deny
                    + ", pattern='" + pattern + '\''
                    + '}';
        }

        /**
         * @return the peer expression followed by {@code " deny"} or {@code " allow"}
         */
        public String toPredicateString() {
            return pattern + " " + (deny ? "deny" : "allow");
        }
    }

    /**
     * Rule that matches exactly one IPv4 address.
     */
    protected static class ExactIpV4PeerMatch extends PeerMatch {

        private final byte[] address;

        protected ExactIpV4PeerMatch(final String pattern, final byte[] address, final boolean deny) {
            super(pattern, deny);
            this.address = address;
        }

        @Override
        public boolean matches(final InetAddress address) {
            return Arrays.equals(address.getAddress(), this.address);
        }
    }

    /**
     * Rule that matches exactly one IPv6 address.
     */
    protected static class ExactIpV6PeerMatch extends PeerMatch {

        private final byte[] address;

        protected ExactIpV6PeerMatch(final String pattern, final byte[] address, final boolean deny) {
            super(pattern, deny);
            this.address = address;
        }

        @Override
        public boolean matches(final InetAddress address) {
            return Arrays.equals(address.getAddress(), this.address);
        }
    }

    /**
     * Rule that matches every IPv4 address whose masked value equals the stored prefix.
     */
    protected static class PrefixIpV4PeerMatch extends PeerMatch {

        private final int mask;
        private final int prefix;

        protected PrefixIpV4PeerMatch(final String pattern, final int mask, final int prefix, final boolean deny) {
            super(pattern, deny);
            this.mask = mask;
            this.prefix = prefix;
        }

        @Override
        public boolean matches(final InetAddress address) {
            byte[] bytes = address.getAddress();
            // anything that is not a four byte address cannot match (mirrors the IPv6 length check)
            if (bytes == null || bytes.length != 4) {
                return false;
            }
            int addressInt = ((bytes[0] & 0xFF) << 24) | ((bytes[1] & 0xFF) << 16) | ((bytes[2] & 0xFF) << 8) | (bytes[3] & 0xFF);
            return (addressInt & mask) == prefix;
        }
    }

    /**
     * Rule that matches every IPv6 address whose masked bytes equal the stored prefix bytes.
     */
    protected static class PrefixIpV6PeerMatch extends PeerMatch {

        private final byte[] mask;
        private final byte[] prefix;

        protected PrefixIpV6PeerMatch(final String pattern, final byte[] mask, final byte[] prefix, final boolean deny) {
            super(pattern, deny);
            this.mask = mask;
            this.prefix = prefix;
            assert mask.length == prefix.length;
        }

        @Override
        public boolean matches(final InetAddress address) {
            byte[] bytes = address.getAddress();
            if (bytes == null) {
                return false;
            }
            if (bytes.length != mask.length) {
                return false;
            }
            for (int i = 0; i < mask.length; ++i) {
                if ((bytes[i] & mask[i]) != prefix[i]) {
                    return false;
                }
            }
            return true;
        }
    }

}