/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.convexOptimization.quadraticProgram;

import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixRMaj;
import org.ejml.data.Matrix;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.factory.LinearSolverFactory_DDRM;
import org.ejml.interfaces.linsol.LinearSolverDense;

/**
 * A square, row-major matrix made of square blocks laid out along the main diagonal.
 *
 * <p>Block {@code i} has side length {@code blockSizes[i]} and occupies rows and columns
 * {@code [blockStarts[i], blockStarts[i + 1])}. {@link #mult}, {@link #multTransB} and the
 * {@code packInverse} methods only read the diagonal blocks, so any entries outside them are
 * treated as zero.
 *
 * <p>Instances reuse internal scratch matrices, so a single instance must not be used from several
 * threads at once.
 */
public class BlockDiagSquareMatrix extends DMatrixRMaj {
    private static final long serialVersionUID = 8813856249678942997L;

    /** Side length of each diagonal block, in order. */
    int[] blockSizes;
    /** First row/column of each block; the extra final entry is the total matrix size. */
    int[] blockStarts;
    /** One scratch matrix per block, sized to that block. */
    DMatrixRMaj[] tmpMatrix;
    /** Scratch holding the rows of {@code b} that belong to one block in {@link #mult}. */
    DMatrixRMaj multTempB = new DMatrixRMaj(0);
    /** Scratch holding one block's rows of the product in {@link #mult}. */
    DMatrixRMaj multTempC = new DMatrixRMaj(0);
    /**
     * Scratch output for a single block inverse. It is kept separate from the block handed to
     * {@code solver.setA}: a solver may keep a reference to that matrix or decompose it in place
     * (see {@code LinearSolver#modifiesA}), so writing the inverse over it is unsafe.
     */
    private DMatrixRMaj inverseTemp = new DMatrixRMaj(0);

    /**
     * Creates a zero matrix with the given diagonal block sizes.
     *
     * @param blockSizes side length of each diagonal block; the matrix is
     *     {@code sum(blockSizes)} x {@code sum(blockSizes)}. The array is copied.
     * @throws IllegalArgumentException if any block size is negative
     */
    public BlockDiagSquareMatrix(int ... blockSizes) {
        super(0);
        // Copy so later mutation of the caller's array cannot desynchronize sizes and starts.
        this.blockSizes = blockSizes.clone();
        // Use a local rather than the overridable getNumBlocks() while still constructing.
        int numBlocks = this.blockSizes.length;
        this.blockStarts = new int[numBlocks + 1];
        this.tmpMatrix = new DMatrixRMaj[numBlocks];
        int matrixRows = 0;
        for (int i = 0; i < numBlocks; ++i) {
            int size = this.blockSizes[i];
            if (size < 0) {
                throw new IllegalArgumentException("Block " + i + " has negative size " + size);
            }
            this.tmpMatrix[i] = new DMatrixRMaj(size, size);
            this.blockStarts[i] = matrixRows;
            matrixRows += size;
        }
        this.blockStarts[numBlocks] = matrixRows;
        super.reshape(matrixRows, matrixRows);
    }

    /** @return the number of diagonal blocks */
    public int getNumBlocks() {
        return this.blockSizes.length;
    }

    /**
     * Copies {@code srcBlock} into diagonal block {@code blockId} of this matrix.
     *
     * @param srcBlock block contents; must be {@code blockSizes[blockId]} square
     * @param blockId index of the diagonal block to overwrite
     */
    public void setBlock(DMatrixRMaj srcBlock, int blockId) {
        this.setBlock(srcBlock, blockId, this);
    }

    /**
     * Reshapes {@code dstMatrix} to this matrix's dimensions and copies {@code srcBlock} into the
     * position of diagonal block {@code blockId}. Entries outside that block are not cleared.
     *
     * @param srcBlock block contents; must be {@code blockSizes[blockId]} square
     * @param blockId index of the diagonal block (per this matrix's layout)
     * @param dstMatrix matrix to write into
     * @throws IllegalArgumentException if {@code srcBlock} does not match the block size
     */
    public void setBlock(DMatrixRMaj srcBlock, int blockId, DMatrixRMaj dstMatrix) {
        int blockSize = this.blockSizes[blockId];
        if (srcBlock.numRows != blockSize || srcBlock.numCols != blockSize) {
            throw new IllegalArgumentException("Block " + blockId + " must be " + blockSize + "x" + blockSize
                    + " but srcBlock is " + srcBlock.numRows + "x" + srcBlock.numCols);
        }
        dstMatrix.reshape(this.numRows, this.numCols);
        int startIndex = this.blockStarts[blockId];
        CommonOps_DDRM.insert(srcBlock, dstMatrix, startIndex, startIndex);
    }

    /**
     * Copies diagonal block {@code blockId} into {@code dstBlock} with its top-left corner at
     * ({@code destRow0}, {@code destCol0}). {@code dstBlock} is not reshaped here.
     *
     * @param dstBlock destination matrix
     * @param blockId index of the diagonal block to copy
     * @param destRow0 destination row of the block's first row
     * @param destCol0 destination column of the block's first column
     */
    public void packBlock(DMatrixRMaj dstBlock, int blockId, int destRow0, int destCol0) {
        int startIndex = this.blockStarts[blockId];
        int endIndex = this.blockStarts[blockId + 1];
        CommonOps_DDRM.extract(this, startIndex, endIndex, startIndex, endIndex, dstBlock, destRow0, destCol0);
    }

    /**
     * Inverts each diagonal block and writes it into the same block of {@code matrixToPack}.
     * {@code matrixToPack} may be {@code this}. Off-diagonal entries of {@code matrixToPack} are not
     * cleared. Its block layout must match this matrix's layout.
     *
     * <p>NOTE(review): the result of {@code solver.setA} is ignored. If a block cannot be
     * decomposed (e.g. it is singular), its "inverse" is not meaningful.
     *
     * @param solver solver used to invert each block
     * @param matrixToPack destination for the block inverses
     */
    public void packInverse(LinearSolverDense<DMatrixRMaj> solver, BlockDiagSquareMatrix matrixToPack) {
        for (int i = 0; i < this.blockSizes.length; ++i) {
            this.invertBlock(solver, i);
            matrixToPack.setBlock(this.inverseTemp, i);
        }
    }

    /**
     * Writes the inverse of this block-diagonal matrix into a dense {@code matrixToPack}. The
     * destination is reshaped to this matrix's size and zeroed before the block inverses are
     * written.
     *
     * <p>NOTE(review): the result of {@code solver.setA} is ignored. If a block cannot be
     * decomposed (e.g. it is singular), its "inverse" is not meaningful.
     *
     * @param solver solver used to invert each block
     * @param matrixToPack destination; may be {@code this}
     */
    public void packInverse(LinearSolverDense<DMatrixRMaj> solver, DMatrixRMaj matrixToPack) {
        if (matrixToPack == this) {
            // Zeroing first would wipe the blocks about to be inverted; invert block by block in place.
            this.packInverse(solver, this);
            return;
        }
        // Reshape before zeroing: zero() only clears the current logical size, so a later reshape
        // that grows within the existing data array could otherwise expose stale values.
        matrixToPack.reshape(this.numRows, this.numCols);
        matrixToPack.zero();
        for (int i = 0; i < this.blockSizes.length; ++i) {
            this.invertBlock(solver, i);
            this.setBlock(this.inverseTemp, i, matrixToPack);
        }
    }

    /** Copies block {@code blockId} into its scratch matrix and stores its inverse in {@link #inverseTemp}. */
    private void invertBlock(LinearSolverDense<DMatrixRMaj> solver, int blockId) {
        int size = this.blockSizes[blockId];
        DMatrixRMaj block = this.tmpMatrix[blockId];
        block.reshape(size, size);
        this.packBlock(block, blockId, 0, 0);
        solver.setA(block);
        this.inverseTemp.reshape(size, size);
        solver.invert(this.inverseTemp);
    }

    /**
     * Computes {@code c = this * b^T}, reading only the diagonal blocks of this matrix.
     *
     * @param b matrix with {@code numCols} columns
     * @param c output, reshaped to {@code numRows x b.numRows}; must not be {@code this} or {@code b}
     * @throws IllegalArgumentException if {@code b.numCols != numCols}
     */
    public void multTransB(DMatrixRMaj b, DMatrixRMaj c) {
        if (b.numCols != this.numCols) {
            throw new IllegalArgumentException("b must have " + this.numCols + " columns but has " + b.numCols);
        }
        c.reshape(this.numRows, b.numRows);
        for (int i = 0; i < this.blockSizes.length; ++i) {
            int blockStart = this.blockStarts[i];
            int blockSize = this.blockSizes[i];
            for (int crow = blockStart; crow < this.blockStarts[i + 1]; ++crow) {
                // Only columns [blockStart, blockStart + blockSize) of this row can be non-zero.
                int aIndex0 = this.getIndex(crow, blockStart);
                for (int ccol = 0; ccol < c.numCols; ++ccol) {
                    double val = 0.0;
                    int aIndex = aIndex0;
                    int bIndex = b.getIndex(ccol, blockStart);
                    int bEnd = bIndex + blockSize;
                    while (bIndex < bEnd) {
                        val += this.data[aIndex++] * b.data[bIndex++];
                    }
                    c.set(crow, ccol, val);
                }
            }
        }
    }

    /**
     * Computes {@code c = alpha * this * b} one diagonal block at a time.
     *
     * <p>{@code c} may be the same object as {@code b}. Each block's rows of {@code b} are copied to
     * scratch before the matching rows of {@code c} are written.
     *
     * @param alpha scale factor
     * @param b matrix with {@code numCols} rows
     * @param c output, reshaped to {@code numRows x b.numCols}
     * @throws IllegalArgumentException if {@code b.numRows != numCols}
     */
    public void mult(double alpha, DMatrixRMaj b, DMatrixRMaj c) {
        if (b.numRows != this.numCols) {
            throw new IllegalArgumentException("b must have " + this.numCols + " rows but has " + b.numRows);
        }
        c.reshape(this.numRows, b.numCols);
        for (int i = 0; i < this.blockSizes.length; ++i) {
            int size = this.blockSizes[i];
            int blockStart = this.blockStarts[i];
            DMatrixRMaj block = this.tmpMatrix[i];
            block.reshape(size, size);
            this.packBlock(block, i, 0, 0);
            this.multTempB.reshape(size, b.numCols);
            this.multTempC.reshape(size, b.numCols);
            CommonOps_DDRM.extract(b, blockStart, this.blockStarts[i + 1], 0, b.numCols, this.multTempB, 0, 0);
            CommonOps_DDRM.mult(alpha, block, this.multTempB, this.multTempC);
            CommonOps_DDRM.insert(this.multTempC, c, blockStart, 0);
        }
    }

    /**
     * Small demo: builds a 3x3 matrix from a 1x1 and a 2x2 block, inverts it in place, and prints
     * the blocks and the result.
     */
    public static void main(String[] arg) {
        BlockDiagSquareMatrix m = new BlockDiagSquareMatrix(1, 2);
        DMatrixRMaj b1 = new DMatrixRMaj(1, 1, true, new double[]{1.0});
        DMatrixRMaj b2 = new DMatrixRMaj(2, 2, true, new double[]{2.0, 3.0, 4.0, 5.0});
        m.setBlock(b1, 0);
        m.setBlock(b2, 1);
        System.out.println(m);
        LinearSolverDense<DMatrixRMaj> solver = LinearSolverFactory_DDRM.general(m.numRows, m.numCols);
        m.packInverse(solver, m);
        b1.zero();
        b2.zero();
        m.packBlock(b1, 0, 0, 0);
        m.packBlock(b2, 1, 0, 0);
        System.out.println(b1);
        System.out.println(b2);
        System.out.println("m=\n" + m);
    }
}

