TestModular.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.examples.pi.math;

import java.math.BigInteger;
import java.util.Random;

import org.apache.hadoop.examples.pi.Util.Timer;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

public class TestModular{
  // Shared, unseeded random source for generating test inputs.
  private static final Random RANDOM = new Random();
  // BigInteger 2: the base in TWO.modPow(E, N) and the exponent in R.modPow(TWO, N).
  private static final BigInteger TWO = BigInteger.valueOf(2);


  // Number of quotient bits div() produces (its bit index starts at DIV_VALID_BIT - 1).
  static final int DIV_VALID_BIT = 32;
  // 2^DIV_VALID_BIT; div() subtracts this once when its result reaches it.
  static final long DIV_LIMIT = 1L << DIV_VALID_BIT;

  /**
   * Binary long division of r by n for n > r > 0.
   *
   * Produces up to DIV_VALID_BIT fractional bits of r/n (most significant
   * bit first), adds them to {@code sum}, and subtracts DIV_LIMIT once if the
   * total reaches it.
   *
   * @param sum running total the quotient bits are added to
   * @param r   numerator, expected to satisfy 0 < r < n
   * @param n   denominator
   * @return sum plus the quotient bits, reduced by DIV_LIMIT at most once
   */
  static long div(long sum, long r, long n) {
    long quotient = 0;
    int bit = DIV_VALID_BIT - 1;

    // Skip leading zero bits: double r until it is at least n.
    r <<= 1;
    while (r < n) {
      bit--;
      r <<= 1;
    }

    while (bit >= 0) {
      // Here r >= n, so the current quotient bit is one.
      r -= n;
      quotient |= 1L << bit;
      if (r <= 0) {
        break;
      }
      // Advance to the next one bit of the quotient.
      while (r < n) {
        bit--;
        r <<= 1;
      }
    }

    final long total = sum + quotient;
    if (total < DIV_LIMIT) {
      return total;
    }
    return total - DIV_LIMIT;
  }

  /**
   * Checks div(0, r, n) against a floating-point computation of
   * r/n scaled by 2^DIV_VALID_BIT, for every 2 <= n < 100 and 1 <= r < n.
   */
  @Test
  void testDiv() {
    for (long n = 2; n < 100; n++) {
      for (long r = 1; r < n; r++) {
        final long actual = div(0, r, n);
        final double ratio = r * 1.0 / n;
        final long expected = (long) (ratio * (1L << DIV_VALID_BIT));
        final String message =
            String.format("r=%d, n=%d, a=%X, b=%X", r, n, actual, expected);
        assertEquals(expected, actual, message);
      }
    }
  }

  /**
   * Generates random (n, r, r^2 mod n) test data for the squaring benchmarks.
   *
   * For each of the {@code nsize} rows, element [0] is {n} with
   * 1 < n <= 0xFFFFFFFFFFFFFFF, and elements [1..rsize] are {r, r*r mod n}
   * with 0 <= r < n, where the expected square is computed with BigInteger.
   *
   * @param nsize number of moduli to generate
   * @param rsize number of residues to generate per modulus
   * @return the generated table, indexed as [modulus][0 or residue index][field]
   */
  static long[][][] generateRN(int nsize, int rsize) {
    final long[][][] rn = new long[nsize][][];

    for (int i = 0; i < rn.length; i++) {
      rn[i] = new long[rsize + 1][];
      long n = RANDOM.nextLong() & 0xFFFFFFFFFFFFFFFL;
      if (n <= 1) {
        n = 0xFFFFFFFFFFFFFFFL - n;
      }
      rn[i][0] = new long[]{n};
      final BigInteger N = BigInteger.valueOf(n);

      for (int j = 1; j < rn[i].length; j++) {
        // Clear the sign bit rather than negating: -Long.MIN_VALUE is still
        // negative, which would let a negative r escape the "r < n" reduction.
        final long r = (RANDOM.nextLong() & Long.MAX_VALUE) % n;
        final BigInteger R = BigInteger.valueOf(r);
        rn[i][j] = new long[]{r, R.multiply(R).mod(N).longValue()};
      }
    }
    return rn;
  }

  /**
   * Computes z*z mod n by shift-and-add, one bit of z at a time.
   *
   * Only conditional subtraction is used for the reduction, so the result is
   * the true residue when 0 <= z < n and 2n fits in a long. A non-positive z
   * yields 0 because the loop body never runs.
   *
   * @param z value to square
   * @param n modulus
   * @return the accumulated sum described above
   */
  static long square_slow(long z, final long n) {
    long acc = 0;
    long addend = z; // holds z * 2^k mod n for the k-th bit
    while (z > 0) {
      if ((z & 1L) != 0) {
        acc += addend;
        if (acc >= n) {
          acc -= n;
        }
      }

      addend <<= 1;
      if (addend >= n) {
        addend -= n;
      }
      z >>= 1;
    }
    return acc;
  }

  /**
   * Computes r*r mod n for 0 <= r < n < max/2.
   *
   * Small inputs (r <= Modular.MAX_SQRT_LONG) are squared directly. Larger
   * inputs are split as r = high * 2^half + low, and the three partial
   * products are reduced separately, using doubling with conditional
   * subtraction to apply the powers of two.
   *
   * @param r     value to square
   * @param n     modulus
   * @param r2p64 not used by this implementation
   * @return r*r mod n
   */
  static long square(long r, final long n, long r2p64) {
    if (r <= Modular.MAX_SQRT_LONG) {
      long direct = r * r;
      if (direct >= n) {
        direct %= n;
      }
      return direct;
    }

    final int halfBits = (63 - Long.numberOfLeadingZeros(n)) >> 1;
    final int fullBits = halfBits << 1;
    final long lowMask = (1 << halfBits) - 1;

    final long high = r >>> halfBits;
    final long low = r & lowMask;

    // low^2 term
    long lowSquared = low * low;
    if (lowSquared >= n) {
      lowSquared %= n;
    }
    if (high == 0) {
      return lowSquared;
    }

    // high^2 * 2^fullBits term
    long highSquared = high * high;
    if (highSquared >= n) {
      highSquared %= n;
    }
    for (int k = 0; k < fullBits; k++) {
      highSquared <<= 1;
      if (highSquared >= n) {
        highSquared -= n;
      }
    }
    if (low == 0) {
      return highSquared;
    }

    // 2 * high * low * 2^halfBits term: halfBits + 1 doublings in total
    long cross = high * low;
    if (cross >= n) {
      cross %= n;
    }
    for (int k = 0; k <= halfBits; k++) {
      cross <<= 1;
      if (cross >= n) {
        cross -= n;
      }
    }

    long result = lowSquared + highSquared;
    if (result >= n) {
      result -= n;
    }
    result += cross;
    if (result >= n) {
      result -= n;
    }
    return result;
  }

  /**
   * Runs each modular-squaring implementation over the same random data set
   * and asserts that every result matches the precomputed answer.
   *
   * A Timer tick is emitted after data generation and after each
   * implementation (presumably to report elapsed time; confirm against
   * Util.Timer). The order is: square_slow, square, BigInteger
   * multiply-then-mod, BigInteger modPow.
   */
  static void squareBenchmarks() {
    final Timer timer = new Timer(false);
    timer.tick("squareBenchmarks(), MAX_SQRT=" + Modular.MAX_SQRT_LONG);

    final long[][][] rn = generateRN(1000, 1000);
    timer.tick("generateRN");

    for (final long[][] row : rn) {
      final long n = row[0][0];
      for (int j = 1; j < row.length; j++) {
        final long r = row[j][0];
        final long answer = row[j][1];
        final long s = square_slow(r, n);
        // Compare first so the message is only built on a mismatch.
        if (s != answer) {
          assertEquals(
              answer, s, "r=" + r + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("square_slow");

    for (final long[][] row : rn) {
      final long n = row[0][0];
      long r2p64 = (0x4000000000000000L % n) << 1;
      if (r2p64 >= n) {
        r2p64 -= n;
      }
      for (int j = 1; j < row.length; j++) {
        final long r = row[j][0];
        final long answer = row[j][1];
        final long s = square(r, n, r2p64);
        if (s != answer) {
          assertEquals(
              answer, s, "r=" + r + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("square");

    for (final long[][] row : rn) {
      final long n = row[0][0];
      final BigInteger N = BigInteger.valueOf(n);
      for (int j = 1; j < row.length; j++) {
        final long r = row[j][0];
        final long answer = row[j][1];
        final BigInteger R = BigInteger.valueOf(r);
        final long s = R.multiply(R).mod(N).longValue();
        if (s != answer) {
          assertEquals(
              answer, s, "r=" + r + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("R.multiply(R).mod(N)");

    for (final long[][] row : rn) {
      final long n = row[0][0];
      final BigInteger N = BigInteger.valueOf(n);
      for (int j = 1; j < row.length; j++) {
        final long r = row[j][0];
        final long answer = row[j][1];
        final BigInteger R = BigInteger.valueOf(r);
        final long s = R.modPow(TWO, N).longValue();
        if (s != answer) {
          assertEquals(
              answer, s, "r=" + r + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("R.modPow(TWO, N)");
  }

  /**
   * Generates random (n, e, 2^e mod n) test data for the modular-power
   * benchmarks.
   *
   * For each of the {@code nsize} rows, element [0] is {n} with n odd and
   * 3 <= n <= 0xFFFFFFFFFFFFFFF, and elements [1..esize] are {e, 2^e mod n}
   * with e >= 0, where the expected value is computed with BigInteger.modPow.
   *
   * @param nsize number of moduli to generate
   * @param esize number of exponents to generate per modulus
   * @return the generated table, indexed as [modulus][0 or exponent index][field]
   */
  static long[][][] generateEN(int nsize, int esize) {
    final long[][][] en = new long[nsize][][];

    for (int i = 0; i < en.length; i++) {
      en[i] = new long[esize + 1][];
      long n = (RANDOM.nextLong() & 0xFFFFFFFFFFFFFFFL) | 1L;
      if (n == 1) {
        n = 3;
      }
      en[i][0] = new long[]{n};
      final BigInteger N = BigInteger.valueOf(n);

      for (int j = 1; j < en[i].length; j++) {
        // Clear the sign bit rather than negating: -Long.MIN_VALUE is still
        // negative, and a negative exponent would make modPow compute a
        // modular inverse instead of a power.
        final long e = RANDOM.nextLong() & Long.MAX_VALUE;
        final BigInteger E = BigInteger.valueOf(e);
        en[i][j] = new long[]{e, TWO.modPow(E, N).longValue()};
      }
    }
    return en;
  }

  /**
   * Compute $2^e \mod n$ for e > 0, n > 2.
   *
   * Uses left-to-right square-and-multiply over the bits of e below its
   * highest set bit, starting from r = 2. Squaring is done directly in long
   * arithmetic while r <= Modular.MAX_SQRT_LONG, and through BigInteger
   * otherwise. The "multiply by 2" step is a doubling with one conditional
   * subtraction.
   *
   * @param e exponent
   * @param n modulus
   * @return 2^e mod n; 2 is returned unchanged when e has no bits below its
   *         highest set bit to process (e.g. e == 1) or when e <= 0
   */
  static long modBigInteger(final long e, final long n) {
    final BigInteger N = BigInteger.valueOf(n);
    long r = 2;
    // Long.highestOneBit replaces the hand-rolled binary search for the top
    // bit of e; the leading bit itself is accounted for by the initial r = 2.
    for (long mask = Long.highestOneBit(e) >> 1; mask > 0; mask >>= 1) {
      if (r <= Modular.MAX_SQRT_LONG) {
        r *= r;
        if (r >= n) {
          r %= n;
        }
      } else {
        final BigInteger R = BigInteger.valueOf(r);
        r = R.multiply(R).mod(N).longValue();
      }

      if ((e & mask) != 0) {
        r <<= 1;
        if (r >= n) {
          r -= n;
        }
      }
    }
    return r;
  }

  /**
   * Montgomery variant that adds a second 2^y mod N routine, used by the
   * benchmarks for comparison with the inherited mod().
   */
  static class Montgomery2 extends Montgomery {
    /**
     * Compute 2^y mod N for N odd.
     *
     * Walks the bits of y from the highest set bit downwards, updating the
     * pair (r0, r1) with two product.m calls per bit.
     * NOTE(review): R, N and product come from Montgomery; presumably r0 and
     * r1 start as 1 and 2 in Montgomery form, and the final product.m(r0, 1)
     * converts back. Confirm against Montgomery.
     */
    long mod2(final long y) {
      long r0 = R - N;
      long r1 = r0 << 1;
      if (r1 >= N) {
        r1 -= N;
      }

      long mask = Long.highestOneBit(y);
      while (mask > 0) {
        // Both branches need r0*r1 first, computed from the old values.
        final long cross = product.m(r0, r1);
        if ((mask & y) != 0) {
          r1 = product.m(r1, r1);
          r0 = cross;
        } else {
          r0 = product.m(r0, r0);
          r1 = cross;
        }
        mask >>>= 1;
      }
      return product.m(r0, 1);
    }
  }

  /**
   * Runs each 2^e mod n implementation over the same random data set and
   * asserts that every result matches the precomputed answer.
   *
   * A Timer tick is emitted after data generation and after each
   * implementation (presumably to report elapsed time; confirm against
   * Util.Timer). The order is: Modular.mod, Montgomery2.mod,
   * Montgomery2.mod2, BigInteger.modPow.
   */
  static void modBenchmarks() {
    final Timer timer = new Timer(false);
    timer.tick("modBenchmarks()");

    final long[][][] en = generateEN(10000, 10);
    timer.tick("generateEN");

    for (final long[][] row : en) {
      final long n = row[0][0];
      for (int j = 1; j < row.length; j++) {
        final long e = row[j][0];
        final long answer = row[j][1];
        final long s = Modular.mod(e, n);
        // Compare first so the message is only built on a mismatch.
        if (s != answer) {
          assertEquals(
              answer, s, "e=" + e + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("Modular.mod");

    final Montgomery2 m2 = new Montgomery2();
    for (final long[][] row : en) {
      final long n = row[0][0];
      m2.set(n);
      for (int j = 1; j < row.length; j++) {
        final long e = row[j][0];
        final long answer = row[j][1];
        final long s = m2.mod(e);
        if (s != answer) {
          assertEquals(
              answer, s, "e=" + e + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("montgomery.mod");

    for (final long[][] row : en) {
      final long n = row[0][0];
      m2.set(n);
      for (int j = 1; j < row.length; j++) {
        final long e = row[j][0];
        final long answer = row[j][1];
        final long s = m2.mod2(e);
        if (s != answer) {
          assertEquals(
              answer, s, "e=" + e + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("montgomery.mod2");

    for (final long[][] row : en) {
      final long n = row[0][0];
      final BigInteger N = BigInteger.valueOf(n);
      for (int j = 1; j < row.length; j++) {
        final long e = row[j][0];
        final long answer = row[j][1];
        final long s = TWO.modPow(BigInteger.valueOf(e), N).longValue();
        if (s != answer) {
          assertEquals(
              answer, s, "e=" + e + ", n=" + n + ", answer=" + answer + " but s=" + s);
        }
      }
    }
    timer.tick("BigInteger.modPow(e, n)");
  }

  /**
   * Command-line entry point: runs squareBenchmarks() and then
   * modBenchmarks().
   *
   * @param args not used
   */
  public static void main(String[] args) {
    squareBenchmarks();
    modBenchmarks();
  }
}