AbstractMatrixTest.java

/**
 * Copyright (c) 2017, RTE (http://www.rte-france.com)
 * This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/.
 * SPDX-License-Identifier: MPL-2.0
 */
package com.powsybl.math.matrix;

import com.google.common.collect.ImmutableList;
import com.google.common.testing.EqualsTester;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.PrintStream;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Shared test suite for {@link Matrix} implementations.
 *
 * <p>Concrete subclasses provide the factory under test through {@link #getMatrixFactory()} and a
 * second factory through {@link #getOtherMatrixFactory()}, which is used by the conversion test.
 *
 * @author Geoffroy Jamgotchian {@literal <geoffroy.jamgotchian at rte-france.com>}
 */
abstract class AbstractMatrixTest {

    /** Absolute tolerance used when comparing floating point results. */
    protected static final double EPSILON = Math.pow(10, -15);

    /**
     * @return the factory creating the matrix implementation under test
     */
    protected abstract MatrixFactory getMatrixFactory();

    /**
     * @return a second factory, used as conversion target in {@link #toTest()}
     */
    protected abstract MatrixFactory getOtherMatrixFactory();

    /**
     * Builds the 3x2 reference matrix used by several tests.
     *
     * <pre>
     * 1 0
     * 0 3
     * 2 0
     * </pre>
     *
     * @param matrixFactory factory used to instantiate the matrix
     * @return the populated matrix
     */
    protected Matrix createA(MatrixFactory matrixFactory) {
        Matrix a = matrixFactory.create(3, 2, 3);
        a.set(0, 0, 1);
        a.set(2, 0, 2);
        a.set(1, 1, 3);
        return a;
    }

    /**
     * Checks that negative dimensions and out of range element accesses are rejected with a
     * {@link MatrixException}.
     */
    @Test
    void checkBoundsTest() {
        MatrixFactory matrixFactory = getMatrixFactory();
        assertThrows(MatrixException.class, () -> matrixFactory.create(-1, 1, 1));
        assertThrows(MatrixException.class, () -> matrixFactory.create(1, -1, 1));

        Matrix a = matrixFactory.create(1, 1, 1);
        assertThrows(MatrixException.class, () -> a.set(-1, 0, 0));
        assertThrows(MatrixException.class, () -> a.set(0, -1, 0));
        assertThrows(MatrixException.class, () -> a.set(2, 0, 0));
        assertThrows(MatrixException.class, () -> a.set(0, 1, 0));
        assertThrows(MatrixException.class, () -> a.add(2, 0, 0));
        assertThrows(MatrixException.class, () -> a.add(0, 1, 0));
    }

    /**
     * Multiplies the reference matrix by a 2x1 column, with and without the extra scalar argument.
     */
    @Test
    void testMultiplication() {
        Matrix a = createA(getMatrixFactory());
        Matrix b = getMatrixFactory().create(2, 1, 2);
        b.set(0, 0, 4);
        b.set(1, 0, 5);

        DenseMatrix c = a.times(b).toDense();

        assertEquals(3, c.getRowCount());
        assertEquals(1, c.getColumnCount());
        assertEquals(4, c.get(0, 0), EPSILON);
        assertEquals(15, c.get(1, 0), EPSILON);
        assertEquals(8, c.get(2, 0), EPSILON);

        // same product, but every expected value is doubled by the scalar argument
        DenseMatrix c2 = a.times(b, 2).toDense();

        assertEquals(3, c2.getRowCount());
        assertEquals(1, c2.getColumnCount());
        assertEquals(8, c2.get(0, 0), EPSILON);
        assertEquals(30, c2.get(1, 0), EPSILON);
        assertEquals(16, c2.get(2, 0), EPSILON);
    }

    /**
     * Adds two 3x2 matrices and checks every element of the sum.
     */
    @Test
    void testAddition() {
        /*
        1 0
        0 3
        2 0
         */
        Matrix a = createA(getMatrixFactory());
        /*
        4 0
        5 0
        0 0
         */
        Matrix b = getMatrixFactory().create(3, 2, 3);
        b.set(0, 0, 4);
        b.set(1, 0, 5);

        Matrix cs = a.add(b);
        DenseMatrix c = cs.toDense();

        assertEquals(3, c.getRowCount());
        assertEquals(2, c.getColumnCount());
        assertEquals(5, c.get(0, 0), EPSILON);
        assertEquals(5, c.get(1, 0), EPSILON);
        assertEquals(2, c.get(2, 0), EPSILON);
        assertEquals(0, c.get(0, 1), EPSILON);
        assertEquals(3, c.get(1, 1), EPSILON);
        assertEquals(0, c.get(2, 1), EPSILON);

        // for a sparse result, the values array must hold exactly 4 entries
        // (the non-zero elements of the sum)
        if (cs instanceof SparseMatrix) {
            assertEquals(4, ((SparseMatrix) cs).getValues().length);
        }
    }

    /**
     * Adds two 3x3 matrices whose middle column is empty in both operands.
     */
    @Test
    void testAdditionWithEmptyColumnInTheMiddle() {
        /*
        1 0 0
        0 0 3
        2 0 0
         */
        Matrix a = getMatrixFactory().create(3, 3, 3);
        a.set(0, 0, 1);
        a.set(2, 0, 2);
        a.set(1, 2, 3);

        /*
        4 0 6
        5 0 0
        0 0 0
         */
        Matrix b = getMatrixFactory().create(3, 3, 3);
        b.set(0, 0, 4);
        b.set(1, 0, 5);
        b.set(0, 2, 6);

        DenseMatrix c = a.add(b).toDense();

        assertEquals(3, c.getRowCount());
        assertEquals(3, c.getColumnCount());
        assertEquals(5, c.get(0, 0), EPSILON);
        assertEquals(5, c.get(1, 0), EPSILON);
        assertEquals(2, c.get(2, 0), EPSILON);
        assertEquals(0, c.get(0, 1), EPSILON);
        assertEquals(0, c.get(1, 1), EPSILON);
        assertEquals(0, c.get(2, 1), EPSILON);
        assertEquals(6, c.get(0, 2), EPSILON);
        assertEquals(3, c.get(1, 2), EPSILON);
        assertEquals(0, c.get(2, 2), EPSILON);
    }

    /**
     * Iterates over the non-zero elements of the reference matrix and checks that exactly the
     * three expected elements are reported, each with the right value.
     */
    @Test
    void testIterateNonZeroValue() {
        Matrix a = createA(getMatrixFactory());
        List<String> visited = new ArrayList<>();
        a.iterateNonZeroValue((i, j, value) -> {
            visited.add(i + "," + j);
            if (i == 0 && j == 0) {
                assertEquals(1d, value, 0d);
            } else if (i == 1 && j == 1) {
                assertEquals(3d, value, 0d);
            } else if (i == 2 && j == 0) {
                assertEquals(2d, value, 0d);
            } else {
                fail();
            }
        });
        // without this check the test would pass even if the handler were never invoked
        assertEquals(3, visited.size());
    }

    /**
     * Iterates over the non-zero elements of the first column of the reference matrix.
     */
    @Test
    void testIterateNonZeroValueOfColumn() {
        Matrix a = createA(getMatrixFactory());
        List<Double> nonZeroValues = new ArrayList<>();
        a.iterateNonZeroValueOfColumn(0, (i, j, value) -> nonZeroValues.add(value));
        assertEquals(ImmutableList.of(1d, 2d), nonZeroValues);
    }

    /**
     * Prints a matrix with row and column names and returns the output as a string.
     *
     * @param matrix the matrix to print
     * @param rowNames names passed to {@link Matrix#print} for the rows
     * @param columnNames names passed to {@link Matrix#print} for the columns
     * @return the printed text, decoded as UTF-8
     * @throws IOException if the UTF-8 encoding is not supported
     */
    protected String print(Matrix matrix, List<String> rowNames, List<String> columnNames) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // encode with the same charset used for decoding below; closing the stream flushes it
        try (PrintStream out = new PrintStream(bos, false, StandardCharsets.UTF_8.name())) {
            matrix.print(out, rowNames, columnNames);
        }
        return bos.toString(StandardCharsets.UTF_8.name());
    }

    /**
     * Prints a matrix and returns the output as a string.
     *
     * @param matrix the matrix to print
     * @return the printed text, decoded as UTF-8
     * @throws IOException if the UTF-8 encoding is not supported
     */
    protected String print(Matrix matrix) throws IOException {
        ByteArrayOutputStream bos = new ByteArrayOutputStream();
        // encode with the same charset used for decoding below; closing the stream flushes it
        try (PrintStream out = new PrintStream(bos, false, StandardCharsets.UTF_8.name())) {
            matrix.print(out);
        }
        return bos.toString(StandardCharsets.UTF_8.name());
    }

    /**
     * Builds a 5x2 dense matrix whose two columns both hold the right-hand side
     * {@code {8, 45, -3, 3, 19}} used by the decomposition tests.
     */
    private static DenseMatrix createTwoColumnRightHandSide() {
        double[] rhs = {8, 45, -3, 3, 19};
        DenseMatrix b = new DenseMatrix(rhs.length, 2);
        for (int j = 0; j < 2; j++) {
            for (int i = 0; i < rhs.length; i++) {
                b.set(i, j, rhs[i]);
            }
        }
        return b;
    }

    /**
     * Asserts that both columns of the given 5x2 matrix equal {@code {1, 2, 3, 4, 5}}.
     */
    private static void assertTwoColumnSolution(DenseMatrix x) {
        double[] expected = {1, 2, 3, 4, 5};
        for (int j = 0; j < 2; j++) {
            for (int i = 0; i < expected.length; i++) {
                assertEquals(expected[i], x.get(i, j), EPSILON);
            }
        }
    }

    /**
     * Solves a 5x5 system through LU decomposition, for an array and for a two column dense
     * right-hand side, then modifies one element and solves again after updating the decomposition.
     */
    @Test
    void testDecompose() {
        // 2  3  0  0  0
        // 3  0  4  0  6
        // 0 -1 -3  2  0
        // 0  0  1  0  0
        // 0  4  2  0  1
        Matrix matrix = getMatrixFactory().create(5, 5, 12);

        matrix.set(0, 0, 2);
        matrix.set(1, 0, 3);

        matrix.set(0, 1, 3);
        matrix.set(2, 1, -1);
        matrix.set(4, 1, 4);

        matrix.set(1, 2, 4);
        matrix.set(2, 2, -3);
        // keep a handle on element (3, 2) so that it can be modified after the decomposition
        Matrix.Element e = matrix.addAndGetElement(3, 2, 1);
        matrix.set(4, 2, 2);

        matrix.set(2, 3, 2);

        matrix.set(1, 4, 6);
        matrix.set(4, 4, 1);

        try (LUDecomposition decomposition = matrix.decomposeLU()) {
            // solve works in place: the right-hand side array receives the solution
            double[] x = {8, 45, -3, 3, 19};
            decomposition.solve(x);
            assertArrayEquals(new double[]{1, 2, 3, 4, 5}, x, EPSILON);

            DenseMatrix x2 = createTwoColumnRightHandSide();
            decomposition.solve(x2);
            assertTwoColumnSolution(x2);

            // change element (3, 2) through its handle, then refresh the decomposition
            e.set(4);
            e.add(1);
            decomposition.update();
            double[] x3 = {8, 45, -3, 3, 19};
            decomposition.solve(x3);
            assertArrayEquals(new double[]{-0.010526315789474902, 2.673684210526316, 0.6, 0.7368421052631579, 7.105263157894737}, x3, EPSILON);
        }
    }

    /**
     * Checks that decomposing and solving a 5x5 matrix in which no element has been set fails
     * with a {@link MatrixException}.
     */
    @Test
    void testDecompositionFailure() {
        Matrix matrix = getMatrixFactory().create(5, 5, 12);
        assertThrows(MatrixException.class, () -> {
            try (LUDecomposition decomposition = matrix.decomposeLU()) {
                double[] x = {0, 0, 0, 0, 0};
                decomposition.solve(x);
            }
        });
    }

    /**
     * Solves a 5x5 system with {@code solveTransposed}, for an array and for a two column dense
     * right-hand side.
     */
    @Test
    void testTransposedDecompose() {
        // 2  3  0  0  0
        // 3  0 -1  0  4
        // 0  4 -3  1  2
        // 0  0  2  0  0
        // 0  6  0  0  1
        Matrix matrix = getMatrixFactory().create(5, 5, 12);

        matrix.set(0, 0, 2);
        matrix.set(1, 0, 3);

        matrix.set(0, 1, 3);
        matrix.set(2, 1, 4);
        matrix.set(4, 1, 6);

        matrix.set(1, 2, -1);
        matrix.set(2, 2, -3);
        matrix.set(3, 2, 2);

        matrix.set(2, 3, 1);

        matrix.set(1, 4, 4);
        matrix.set(2, 4, 2);
        matrix.set(4, 4, 1);

        try (LUDecomposition decomposition = matrix.decomposeLU()) {
            double[] x = {8, 45, -3, 3, 19};
            decomposition.solveTransposed(x);
            assertArrayEquals(new double[]{1, 2, 3, 4, 5}, x, EPSILON);

            DenseMatrix x2 = createTwoColumnRightHandSide();
            decomposition.solveTransposed(x2);
            assertTwoColumnSolution(x2);
        }
    }

    /**
     * Decomposes the matrix and solves in place for the given right-hand side, closing the
     * decomposition afterwards.
     */
    private static void decomposeThenSolve(Matrix matrix, double[] b) {
        // LUDecomposition is used as a closeable resource elsewhere in this class; close it here too
        try (LUDecomposition luDecomposition = matrix.decomposeLU()) {
            luDecomposition.solve(b);
        }
    }

    /**
     * Checks that decomposing then solving a non-square (1x2) matrix fails with a
     * {@link MatrixException}.
     */
    @Test
    void testDecomposeNonSquare() {
        Matrix matrix = getMatrixFactory().create(1, 2, 4);
        assertThrows(MatrixException.class, () -> decomposeThenSolve(matrix, new double[] {}));
    }

    /**
     * Checks the equals/hashCode contract: identically built matrices are equal, and matrices
     * from different groups are not.
     */
    @Test
    void testDenseEquals() {
        Matrix a1 = createA(getMatrixFactory());
        Matrix a2 = createA(getMatrixFactory());
        Matrix b1 = getMatrixFactory().create(5, 5, 0);
        Matrix b2 = getMatrixFactory().create(5, 5, 0);
        new EqualsTester()
                .addEqualityGroup(a1, a2)
                .addEqualityGroup(b1, b2)
                .testEquals();
    }

    /**
     * Converts the reference matrix to the other factory and back, and checks that converting
     * with the factory that produced a matrix returns the very same instance.
     */
    @Test
    void toTest() {
        Matrix a = createA(getMatrixFactory());
        Matrix a2 = a.to(getOtherMatrixFactory());
        Matrix a3 = a2.to(getMatrixFactory());
        assertEquals(a, a3);
        assertSame(a.to(getMatrixFactory()), a);
        assertSame(a2.to(getOtherMatrixFactory()), a2);
    }

    /**
     * Checks that repeated {@code add} calls on the same element accumulate.
     */
    @Test
    void testAddValue() {
        Matrix a = getMatrixFactory().create(2, 2, 2);
        a.add(0, 0, 1d);
        a.add(0, 0, 1d);
        a.add(1, 1, 1d);
        a.add(1, 1, 2d);

        DenseMatrix b = a.toDense();
        assertEquals(2d, b.get(0, 0), 0d);
        assertEquals(0d, b.get(1, 0), 0d);
        assertEquals(0d, b.get(0, 1), 0d);
        assertEquals(3d, b.get(1, 1), 0d);
    }

    /**
     * Checks {@code add} on two elements of the same row located in consecutive columns.
     */
    @Test
    void testAddValue2() {
        Matrix a = getMatrixFactory().create(2, 2, 2);
        a.add(0, 0, 1d);
        a.add(0, 1, 1d);

        DenseMatrix b = a.toDense();
        assertEquals(1d, b.get(0, 0), 0d);
        assertEquals(0d, b.get(1, 0), 0d);
        assertEquals(1d, b.get(0, 1), 0d);
        assertEquals(0d, b.get(1, 1), 0d);
    }

    /**
     * Checks the dense conversion of a matrix whose last column is left empty.
     */
    @Test
    void testIssueWithEmptyColumns() {
        Matrix a = getMatrixFactory().create(2, 2, 2);
        a.set(0, 0, 1d);
        // second column is empty
        assertEquals(1, a.toDense().get(0, 0), 0d);
    }

    /**
     * Checks that {@code reset} zeroes the values and that element handles obtained before the
     * reset can still be used to set values afterwards.
     */
    @Test
    void testReset() {
        Matrix a = getMatrixFactory().create(3, 3, 3);
        // 1 0 4
        // 0 2 0
        // 0 3 0
        Matrix.Element e1 = a.addAndGetElement(0, 0, 1d);
        Matrix.Element e2 = a.addAndGetElement(1, 1, 2d);
        Matrix.Element e3 = a.addAndGetElement(2, 1, 3d);
        Matrix.Element e4 = a.addAndGetElement(0, 2, 4d);

        a.reset();

        DenseMatrix afterReset = a.toDense();
        assertEquals(0d, afterReset.get(0, 0), 0d);
        assertEquals(0d, afterReset.get(1, 1), 0d);
        assertEquals(0d, afterReset.get(2, 1), 0d);
        assertEquals(0d, afterReset.get(0, 2), 0d);

        e1.set(5d);
        e2.set(6d);
        e3.set(7d);
        e4.set(8d);

        // take a new dense view once all the handles have been written
        DenseMatrix afterSet = a.toDense();
        assertEquals(5d, afterSet.get(0, 0), 0d);
        assertEquals(6d, afterSet.get(1, 1), 0d);
        assertEquals(7d, afterSet.get(2, 1), 0d);
        assertEquals(8d, afterSet.get(0, 2), 0d);
    }

    /**
     * Checks that the legacy-named accessors ({@code getM}, {@code getN}, {@code setValue},
     * {@code addValue}) agree with their current counterparts.
     */
    @Test
    void testDeprecated() {
        Matrix a = getMatrixFactory().create(2, 2, 2);
        assertEquals(a.getRowCount(), a.getM());
        assertEquals(a.getColumnCount(), a.getN());
        a.setValue(0, 0, 1d);
        a.setValue(0, 1, 1d);
        assertEquals(1d, a.toDense().get(0, 1), 0d);
        a.addValue(0, 1, 2d);
        assertEquals(3d, a.toDense().get(0, 1), 0d);
    }

    /**
     * Checks that the indexes returned by {@code addAndGetIndex} can be used with
     * {@code setAtIndex} / {@code addAtIndex}, and that an index that was never returned is
     * rejected.
     */
    @Test
    void testAddAndGetIndex() {
        Matrix a = getMatrixFactory().create(3, 3, 3);
        // 1 0 4
        // 0 2 0
        // 0 3 0
        int index1 = a.addAndGetIndex(0, 0, 1d);
        int index2 = a.addAndGetIndex(1, 1, 2d);
        int index3 = a.addAndGetIndex(2, 1, 3d);
        int index4 = a.addAndGetIndex(0, 2, 4d);

        DenseMatrix initial = a.toDense();
        assertEquals(1d, initial.get(0, 0), 0d);
        assertEquals(0d, initial.get(1, 0), 0d);
        assertEquals(0d, initial.get(2, 0), 0d);
        assertEquals(0d, initial.get(0, 1), 0d);
        assertEquals(2d, initial.get(1, 1), 0d);
        assertEquals(3d, initial.get(2, 1), 0d);
        assertEquals(4d, initial.get(0, 2), 0d);
        assertEquals(0d, initial.get(1, 2), 0d);
        assertEquals(0d, initial.get(2, 2), 0d);

        a.setAtIndex(index1, 9);
        a.addAtIndex(index2, 1);
        a.setAtIndex(index3, 10);
        a.addAtIndex(index4, 1);

        // take a new dense view once all the indexed updates have been applied
        DenseMatrix updated = a.toDense();
        assertEquals(9d, updated.get(0, 0), 0d);
        assertEquals(3d, updated.get(1, 1), 0d);
        assertEquals(10d, updated.get(2, 1), 0d);
        assertEquals(5d, updated.get(0, 2), 0d);

        assertThrows(MatrixException.class, () -> a.setAtIndex(10, 0));
    }

    /**
     * Transposes the 3x2 reference matrix and checks every element of the 2x3 result.
     */
    @Test
    void testTranspose() {
        Matrix a = createA(getMatrixFactory());
        DenseMatrix at = a.transpose().toDense();
        assertEquals(2, at.getRowCount());
        assertEquals(3, at.getColumnCount());
        assertEquals(1d, at.get(0, 0), 0d);
        assertEquals(0d, at.get(0, 1), 0d);
        assertEquals(2d, at.get(0, 2), 0d);
        assertEquals(0d, at.get(1, 0), 0d);
        assertEquals(3d, at.get(1, 1), 0d);
        assertEquals(0d, at.get(1, 2), 0d);
    }
}