/*
 * Copyright (c) 2022, Peter Abeles. All Rights Reserved.
 *
 * This file is part of Efficient Java Matrix Library (EJML).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.ejml.sparse.csc.misc;

import javax.annotation.Generated;
import org.ejml.UtilEjml;
import org.ejml.data.FGrowArray;
import org.ejml.data.FMatrixSparseCSC;
import org.ejml.data.IGrowArray;
import org.ejml.ops.IPredicateBinary;
import org.ejml.sparse.csc.CommonOps_FSCC;
import org.jetbrains.annotations.Nullable;

import java.util.Arrays;

import static org.ejml.UtilEjml.adjust;
import static org.ejml.sparse.csc.mult.ImplMultiplication_FSCC.multAddColA;

/**
 * Implementation class. Not recommended for direct use. Use {@link CommonOps_FSCC} instead.
 *
 * @author Peter Abeles
 */
@Generated("org.ejml.sparse.csc.misc.ImplCommonOps_DSCC")
public class ImplCommonOps_FSCC {

    /**
     * Copies every element of A for which {@code selector.apply(row, col)} returns true into 'output', preserving
     * the relative order of the selected elements. Can be performed in place, i.e. output == A.
     *
     * <p>'output' is not reshaped here. When output != A it must already have the same shape as A, because its
     * col_idx array is written up to index output.numCols.</p>
     *
     * @param A (Input) Matrix elements are selected from. Only modified when output == A.
     * @param output (Output) Storage for the selected elements.
     * @param selector Decides which (row, col) elements are kept.
     */
    public static void select( FMatrixSparseCSC A, FMatrixSparseCSC output, IPredicateBinary selector ) {
        int selectCount = 0;

        // size estimation
        if (output != A) {
            output.growMaxLength(A.nz_length/2, false);
        }

        // selecting a subset doesn't change the order
        output.indicesSorted = A.indicesSorted;

        for (int col = 0; col < A.numCols; col++) {
            int start = A.col_idx[col];
            int end = A.col_idx[col + 1];

            // In-place safe: start/end were read above and selected elements are never written past the
            // position they are read from
            output.col_idx[col] = selectCount;

            // Worst case every element in this column is selected. output.nz_length is only updated after the
            // loop, so the number of elements written so far (selectCount) drives the growth.
            if (output.nz_rows.length < (selectCount + (end - start))) {
                int maxLength = Integer.max(selectCount*2 + 1, A.nz_length);
                output.growMaxLength(maxLength, true);
            }

            for (int i = start; i < end; i++) {
                int row = A.nz_rows[i];

                if (selector.apply(row, col)) {
                    output.nz_rows[selectCount] = row;
                    output.nz_values[selectCount] = A.nz_values[i];
                    selectCount++;
                }
            }
        }
        // writing last entry
        output.col_idx[output.numCols] = selectCount;

        output.nz_length = selectCount;
    }

    /**
     * Performs a matrix transpose.
     *
     * @param A Original matrix. Not modified.
     * @param C Storage for transposed 'A'. Reshaped.
     * @param gw (Optional) Storage for internal workspace. Can be null.
     */
    public static void transpose( FMatrixSparseCSC A, FMatrixSparseCSC C, @Nullable IGrowArray gw ) {
        int[] work = adjust(gw, A.numRows, A.numRows);
        C.reshape(A.numCols, A.numRows, A.nz_length);

        // compute the histogram for each row in 'a'
        for (int j = 0; j < A.nz_length; j++) {
            work[A.nz_rows[j]]++;
        }

        // construct col_idx in the transposed matrix
        C.histogramToStructure(work);
        // work[row] now holds the next free slot in column 'row' of C
        System.arraycopy(C.col_idx, 0, work, 0, C.numCols);

        // fill in the row indexes
        int idx0 = A.col_idx[0];
        for (int j = 1; j <= A.numCols; j++) {
            final int col = j - 1;
            final int idx1 = A.col_idx[j];
            for (int i = idx0; i < idx1; i++) {
                int row = A.nz_rows[i];
                int index = work[row]++;
                C.nz_rows[index] = col;
                C.nz_values[index] = A.nz_values[i];
            }
            idx0 = idx1;
        }
    }

    /**
     * Performs matrix addition:<br>
     * C = &alpha;A + &beta;B
     *
     * @param alpha scalar value multiplied against A
     * @param A Matrix
     * @param beta scalar value multiplied against B
     * @param B Matrix
     * @param C Output matrix. Not reshaped here, so its col_idx must hold at least A.numCols + 1 entries.
     * Its previous structure is overwritten and it is marked as unsorted.
     * @param gw (Optional) Storage for internal workspace. Can be null.
     * @param gx (Optional) Storage for internal workspace. Can be null.
     */
    public static void add( float alpha, FMatrixSparseCSC A, float beta, FMatrixSparseCSC B, FMatrixSparseCSC C,
                            @Nullable IGrowArray gw, @Nullable FGrowArray gx ) {
        float[] x = adjust(gx, A.numRows);
        int[] w = adjust(gw, A.numRows, A.numRows);

        C.indicesSorted = false;
        C.nz_length = 0;

        for (int col = 0; col < A.numCols; col++) {
            // Rows of this column are appended to the end of C, so the column starts at the current length
            int idxC0 = C.nz_length;
            C.col_idx[col] = idxC0;

            multAddColA(A, col, alpha, C, col + 1, x, w);
            multAddColA(B, col, beta, C, col + 1, x, w);

            // The column ends wherever the appended rows stopped. Write it explicitly so the column is closed
            // correctly even when neither A nor B has an element in it (otherwise a stale entry could be read).
            int idxC1 = C.nz_length;
            C.col_idx[col + 1] = idxC1;

            // take the values in the dense vector 'x' and put them into 'C'
            for (int i = idxC0; i < idxC1; i++) {
                C.nz_values[i] = x[C.nz_rows[i]];
            }
        }
        C.col_idx[A.numCols] = C.nz_length;
    }

    /**
     * Adds the results of adding a column in A and B as a new column in C.<br>
     * C(:,end+1) = &alpha;*A(:,colA) + &beta;*B(:,colB)
     *
     * <p>The new column lists A's rows in A's order followed by rows that only appear in B. If that order
     * might not be sorted, C is marked as unsorted.</p>
     *
     * @param alpha scalar
     * @param A matrix
     * @param colA column in A
     * @param beta scalar
     * @param B matrix
     * @param colB column in B
     * @param C Matrix the new column is appended to. Must have the same number of rows as A and B.
     * @param gw workspace
     * @throws IllegalArgumentException if the number of rows in A, B, and C do not match
     */
    public static void addColAppend( float alpha, FMatrixSparseCSC A, int colA, float beta, FMatrixSparseCSC B, int colB,
                                     FMatrixSparseCSC C, @Nullable IGrowArray gw ) {
        if (A.numRows != B.numRows || A.numRows != C.numRows)
            throw new IllegalArgumentException("Number of rows in A, B, and C do not match");

        int idxA0 = A.col_idx[colA];
        int idxA1 = A.col_idx[colA + 1];
        int idxB0 = B.col_idx[colB];
        int idxB1 = B.col_idx[colB + 1];

        C.growMaxColumns(++C.numCols, true);
        C.growMaxLength(C.nz_length + idxA1 - idxA0 + idxB1 - idxB0, true);

        // w[row] = index in C of the element copied from A for that row, or -1 if A has none
        int[] w = adjust(gw, A.numRows);
        Arrays.fill(w, 0, A.numRows, -1);

        for (int i = idxA0; i < idxA1; i++) {
            int row = A.nz_rows[i];
            C.nz_rows[C.nz_length] = row;
            C.nz_values[C.nz_length] = alpha*A.nz_values[i];
            w[row] = C.nz_length++;
        }

        boolean appendedFromB = false;
        for (int i = idxB0; i < idxB1; i++) {
            int row = B.nz_rows[i];
            if (w[row] != -1) {
                C.nz_values[w[row]] += beta*B.nz_values[i];
            } else {
                C.nz_values[C.nz_length] = beta*B.nz_values[i];
                C.nz_rows[C.nz_length++] = row;
                appendedFromB = true;
            }
        }
        C.col_idx[C.numCols] = C.nz_length;

        // Rows only found in B are appended after A's rows, and A's rows keep A's order. Unless A is sorted and
        // B added nothing new, the appended column can't be assumed sorted.
        if (!A.indicesSorted || appendedFromB)
            C.indicesSorted = false;
    }

    /**
     * Performs element-wise multiplication:<br>
     * C_ij = A_ij * B_ij
     *
     * @param A (Input) Matrix
     * @param B (Input) Matrix
     * @param C (Output) Matrix. Not reshaped here. Its structure is overwritten and it is marked as unsorted.
     * @param gw (Optional) Storage for internal workspace. Can be null.
     * @param gx (Optional) Storage for internal workspace. Can be null.
     */
    public static void elementMult( FMatrixSparseCSC A, FMatrixSparseCSC B, FMatrixSparseCSC C,
                                    @Nullable IGrowArray gw, @Nullable FGrowArray gx ) {
        float[] x = adjust(gx, A.numRows);
        int[] w = adjust(gw, A.numRows);
        Arrays.fill(w, 0, A.numRows, -1); // fill with -1. This will be a value less than column

        C.growMaxLength(Math.min(A.nz_length, B.nz_length), false);
        // Each column of C lists its rows in B's order, so C would be sorted whenever B is. Conservatively
        // marked as unsorted.
        C.indicesSorted = false;
        C.nz_length = 0;

        for (int col = 0; col < A.numCols; col++) {
            int idxA0 = A.col_idx[col];
            int idxA1 = A.col_idx[col + 1];
            int idxB0 = B.col_idx[col];
            int idxB1 = B.col_idx[col + 1];

            // compute the maximum number of elements that there can be in this column
            int maxInRow = Math.min(idxA1 - idxA0, idxB1 - idxB0);

            // make sure there are enough non-zero elements in C
            if (C.nz_length + maxInRow > C.nz_values.length)
                C.growMaxLength(C.nz_values.length + maxInRow, true);

            // update the structure of C
            C.col_idx[col] = C.nz_length;

            // mark the rows that appear in A and save their value
            for (int i = idxA0; i < idxA1; i++) {
                int row = A.nz_rows[i];
                w[row] = col;
                x[row] = A.nz_values[i];
            }

            // If a row appears in A and B, multiply and set as an element in C
            for (int i = idxB0; i < idxB1; i++) {
                int row = B.nz_rows[i];
                if (w[row] == col) {
                    C.nz_values[C.nz_length] = x[row]*B.nz_values[i];
                    C.nz_rows[C.nz_length++] = row;
                }
            }
        }
        C.col_idx[C.numCols] = C.nz_length;
    }

    /**
     * Copies 'input' into 'output', skipping every element whose absolute value is not greater than 'tol'.
     * NaN elements are skipped as well since the comparison with NaN is false. The relative order of the
     * remaining elements is preserved.
     *
     * @param input (Input) Matrix. Not modified unless it is the same instance as 'output'.
     * @param output (Output) Reshaped to match 'input'. May be the same instance as 'input'.
     * @param tol Elements with an absolute value less than or equal to this are removed.
     */
    public static void removeZeros( FMatrixSparseCSC input, FMatrixSparseCSC output, float tol ) {
        // The loop below writes output.col_idx[i] before reading input.col_idx[i], which would corrupt the
        // result if both are the same matrix. Use the in-place version instead.
        if (input == output) {
            removeZeros(input, tol);
            return;
        }

        output.reshape(input.numRows, input.numCols, input.nz_length);
        // removing elements doesn't change the order of those that remain
        output.indicesSorted = input.indicesSorted;
        output.nz_length = 0;

        for (int i = 0; i < input.numCols; i++) {
            output.col_idx[i] = output.nz_length;

            int idx0 = input.col_idx[i];
            int idx1 = input.col_idx[i + 1];

            for (int j = idx0; j < idx1; j++) {
                float val = input.nz_values[j];
                if (Math.abs(val) > tol) {
                    output.nz_rows[output.nz_length] = input.nz_rows[j];
                    output.nz_values[output.nz_length++] = val;
                }
            }
        }
        output.col_idx[output.numCols] = output.nz_length;
    }

    /**
     * Removes, in place, every element whose absolute value is not greater than 'tol'. NaN elements are removed
     * as well since the comparison with NaN is false. Remaining elements are shifted down and keep their order.
     *
     * @param A (Input/Output) Matrix which is modified.
     * @param tol Elements with an absolute value less than or equal to this are removed.
     */
    public static void removeZeros( FMatrixSparseCSC A, float tol ) {

        // number of elements removed so far, i.e. how far kept elements must be shifted down
        int offset = 0;
        for (int i = 0; i < A.numCols; i++) {
            // col_idx[i] was already shifted on the previous iteration; add offset back to get the original start
            int idx0 = A.col_idx[i] + offset;
            int idx1 = A.col_idx[i + 1];

            for (int j = idx0; j < idx1; j++) {
                float val = A.nz_values[j];
                if (Math.abs(val) > tol) {
                    A.nz_rows[j - offset] = A.nz_rows[j];
                    A.nz_values[j - offset] = val;
                } else {
                    offset++;
                }
            }
            A.col_idx[i + 1] -= offset;
        }
        A.nz_length -= offset;
    }

    /**
     * Merges, in place, elements which share the same row and column by summing their values. The first
     * occurrence of each row in a column is kept at its relative position; later duplicates are removed.
     *
     * @param A (Input/Output) Matrix which is modified.
     * @param work (Optional) Workspace. Can be null.
     */
    public static void duplicatesAdd( FMatrixSparseCSC A, @Nullable IGrowArray work ) {
        // Look up table from row to nz index
        int[] table = UtilEjml.adjustFill(work, A.numRows, -1);

        // number of duplicates removed so far, i.e. how far kept elements must be shifted down
        int offset = 0;
        for (int i = 0; i < A.numCols; i++) {
            // col_idx[i] was already shifted on the previous iteration; add offset back to get the original start
            int idx0 = A.col_idx[i] + offset;
            int idx1 = A.col_idx[i + 1];

            // When a row is first encountered note the element it's at
            for (int j = idx0; j < idx1; j++) {
                int row = A.nz_rows[j];
                if (table[row] == -1)
                    table[row] = j;
            }

            // Set then add each element
            for (int j = idx0; j < idx1; j++) {
                int row = A.nz_rows[j];

                // First or only time it's encountered, copy the value
                if (table[row] == j) {
                    A.nz_rows[j - offset] = row;
                    A.nz_values[j - offset] = A.nz_values[j];
                    table[row] = j - offset; // Update the table to include the offset location
                } else {
                    // Each time it's encountered after this add the value and increase the offset
                    A.nz_values[table[row]] += A.nz_values[j];
                    offset++;
                }
            }
            A.col_idx[i + 1] -= offset;

            // Need to do a second pass to undo the markings in the lookup table
            idx1 -= offset;
            for (int j = A.col_idx[i]; j < idx1; j++) {
                table[A.nz_rows[j]] = -1;
            }
        }
        A.nz_length -= offset;
    }

    /**
     * Given a symmetric matrix which is represented by a lower triangular matrix convert it back into
     * a full symmetric matrix. A is assumed to only contain elements on or below the diagonal; elements
     * above the diagonal are copied into their column but not mirrored.
     *
     * @param A (Input) Lower triangular matrix
     * @param B (Output) Symmetric matrix. Reshaped and marked as unsorted.
     * @param gw (Optional) Workspace. Can be null.
     * @throws IllegalArgumentException if A is not square
     */
    public static void symmLowerToFull( FMatrixSparseCSC A, FMatrixSparseCSC B, @Nullable IGrowArray gw ) {
        if (A.numCols != A.numRows)
            throw new IllegalArgumentException("Must be a lower triangular square matrix");

        int N = A.numCols;
        int[] w = adjust(gw, N, N);
        B.reshape(N, N, A.nz_length*2);
        B.indicesSorted = false;

        //=== determine the row counts of the full matrix
        for (int col = 0; col < N; col++) {
            int idx0 = A.col_idx[col];
            int idx1 = A.col_idx[col + 1];

            // We know the length of the lower part of this column already
            w[col] += idx1 - idx0;

            // add elements to the top of the other columns along row with index 'col'
            for (int i = idx0; i < idx1; i++) {
                int row = A.nz_rows[i];
                if (row > col) {
                    w[row]++;
                }
            }
        }

        // Update the structure of B
        B.histogramToStructure(w);

        // Zero W again. It's being used to keep track of how many elements have been added to a column already
        Arrays.fill(w, 0, N, 0);
        // Fill in matrix
        for (int col = 0; col < N; col++) {

            int idx0 = A.col_idx[col];
            int idx1 = A.col_idx[col + 1];

            int lengthA = idx1 - idx0;
            int lengthB = B.col_idx[col + 1] - B.col_idx[col];

            // Copy the non-zero values from A into B along the columns while taking in account the upper
            // elements already copied
            System.arraycopy(A.nz_values, idx0, B.nz_values, B.col_idx[col] + lengthB - lengthA, lengthA);
            System.arraycopy(A.nz_rows, idx0, B.nz_rows, B.col_idx[col] + lengthB - lengthA, lengthA);

            // Copy this column into the upper portion of B
            for (int i = idx0; i < idx1; i++) {
                int row = A.nz_rows[i];
                if (row > col) {
                    int indexB = B.col_idx[row] + w[row]++;
                    B.nz_rows[indexB] = col;
                    B.nz_values[indexB] = A.nz_values[i];
                }
            }
        }
    }
}
