/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.optimization;

import gnu.trove.iterator.TIntIterator;
import gnu.trove.list.array.TIntArrayList;
import java.util.Random;
import java.util.function.Function;
import org.ejml.MatrixDimensionException;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commons.Conversions;
import us.ihmc.commons.RandomNumbers;
import us.ihmc.euclid.transform.RigidBodyTransform;
import us.ihmc.euclid.transform.interfaces.RigidBodyTransformReadOnly;
import us.ihmc.matrixlib.MatrixTools;
import us.ihmc.robotics.optimization.OutputCalculator;

/**
 * Iterative least-squares optimizer over a vector of input parameters, using a constant diagonal
 * (Levenberg-Marquardt style) damping term.
 * <p>
 * The {@link OutputCalculator} maps an input vector to a column vector of outputs. Outputs below
 * {@link #setCorrespondenceThreshold(double)} are treated as "corresponding". They can be randomly
 * sub-sampled down to {@link #setMaximumNumberOfCorrespondences(int)}. Each {@link #iterate()}
 * estimates the Jacobian {@code J} of the corresponding outputs {@code r} by forward differences
 * and applies {@code x <- x - (J^T J + D)^-1 J^T r}. Here {@code D} is a constant diagonal damping
 * matrix set in {@link #initialize()}.
 */
public class LevenbergMarquardtParameterOptimizer {
    private static final double DEFAULT_PERTURBATION = 1.0E-4;
    private static final double DEFAULT_DAMPING_COEFFICIENT = 0.001;

    private final int inputDimension;
    private final Function<DMatrixRMaj, RigidBodyTransform> inputFunction;
    private final OutputCalculator outputCalculator;

    private final DMatrixRMaj currentInput;
    private final OutputSpace currentOutputSpace;
    private final DMatrixRMaj perturbationVector;
    private final DMatrixRMaj perturbedInput;
    private final DMatrixRMaj jacobian;
    private final DMatrixRMaj squaredJacobian;
    private final DMatrixRMaj dampingMatrix;
    private final DMatrixRMaj invMultJacobianTranspose;
    private final DMatrixRMaj optimizeDirection;

    private double correspondenceThreshold = 1.0;
    private boolean useDamping = true;
    private int maximumNumberOfCorrespondences = Integer.MAX_VALUE;

    private int iteration;
    private int numberOfCorrespondences;
    private boolean optimized;

    /**
     * Creates an optimizer and pre-allocates all working matrices.
     *
     * @param inputFunction           maps an input parameter vector to a transform; used by
     *                                {@link #convertInputToTransform(DMatrixRMaj, RigidBodyTransform)}.
     * @param outputCalculator        computes the output vector for a given input vector.
     * @param inputParameterDimension number of input parameters (rows of the input column vector).
     * @param outputDimension         number of outputs produced by {@code outputCalculator}.
     */
    public LevenbergMarquardtParameterOptimizer(Function<DMatrixRMaj, RigidBodyTransform> inputFunction, OutputCalculator outputCalculator, int inputParameterDimension, int outputDimension) {
        this.inputFunction = inputFunction;
        this.inputDimension = inputParameterDimension;
        this.outputCalculator = outputCalculator;
        this.currentInput = new DMatrixRMaj(inputParameterDimension, 1);
        this.currentOutputSpace = new OutputSpace(outputDimension);
        this.perturbationVector = new DMatrixRMaj(inputParameterDimension, 1);
        CommonOps_DDRM.fill((DMatrixD1) this.perturbationVector, DEFAULT_PERTURBATION);
        this.perturbedInput = new DMatrixRMaj(inputParameterDimension, 1);
        this.jacobian = new DMatrixRMaj(outputDimension, inputParameterDimension);
        this.squaredJacobian = new DMatrixRMaj(inputParameterDimension, inputParameterDimension);
        this.dampingMatrix = new DMatrixRMaj(inputParameterDimension, inputParameterDimension);
        this.invMultJacobianTranspose = new DMatrixRMaj(inputParameterDimension, outputDimension);
        this.optimizeDirection = new DMatrixRMaj(inputParameterDimension, 1);
    }

    /**
     * Sets the finite-difference step used for each input parameter when estimating the Jacobian.
     *
     * @param perturbationVector column vector with one step per input parameter. Entries are used as
     *                           divisors, so they should be non-zero.
     * @throws MatrixDimensionException if the vector is not {@code inputParameterDimension x 1}.
     */
    public void setPerturbationVector(DMatrixRMaj perturbationVector) {
        // Check both dimensions. The vector is a column vector, so comparing only the column count
        // (always 1) would accept a vector of the wrong length.
        if (perturbationVector.getNumRows() != this.inputDimension || perturbationVector.getNumCols() != 1) {
            throw new MatrixDimensionException("dimension is wrong. " + this.inputDimension + "x1 " + perturbationVector.getNumRows() + "x" + perturbationVector.getNumCols());
        }
        this.perturbationVector.set((DMatrixD1) perturbationVector);
    }

    /**
     * Sets the threshold below which an output is considered "corresponding" and takes part in the update.
     */
    public void setCorrespondenceThreshold(double correspondenceThreshold) {
        this.correspondenceThreshold = correspondenceThreshold;
    }

    /**
     * Caps how many corresponding outputs are used per iteration. Any excess is randomly discarded.
     * Negative values are treated as zero.
     */
    public void setMaximumNumberOfCorrespondences(int maximumNumberOfCorrespondences) {
        this.maximumNumberOfCorrespondences = maximumNumberOfCorrespondences;
    }

    /**
     * Copies the given vector into the current input (the starting point of the optimization).
     */
    public void setInitialOptimalGuess(DMatrixRMaj initialOptimalGuess) {
        this.currentInput.set((DMatrixD1) initialOptimalGuess);
    }

    /**
     * Resets the iteration counter and the damping matrix, then evaluates the outputs at the current input.
     *
     * @return {@code true} if at least one output is below the correspondence threshold.
     */
    public boolean initialize() {
        this.iteration = 0;
        this.optimized = false;
        MatrixTools.setDiagonal((DMatrix1Row) this.dampingMatrix, DEFAULT_DAMPING_COEFFICIENT);
        this.outputCalculator.resetIndicesToCompute();
        this.currentOutputSpace.updateOutputSpace((DMatrixRMaj) this.outputCalculator.apply(this.currentInput));
        boolean result = this.currentOutputSpace.computeCorrespondence();
        this.currentOutputSpace.computeQuality();
        return result;
    }

    /**
     * Performs one damped update of the input parameters.
     * <p>
     * The Jacobian is estimated by forward differences over the currently corresponding outputs only.
     * After the update, the outputs, the correspondence set and the quality are recomputed at the new
     * input. The iteration counter is incremented even when no update is possible.
     *
     * @return the sum of squared corresponding outputs after the update, or {@code -1.0} if there were
     *         no corresponding outputs (in which case the input is left unchanged).
     */
    public double iterate() {
        ++this.iteration;
        if (this.currentOutputSpace.getNumberOfCorrespondingPoints() < 1) {
            return -1.0;
        }

        // Restrict the calculator to the corresponding outputs while estimating the Jacobian.
        this.outputCalculator.setIndicesToCompute(this.currentOutputSpace.correspondingIndices);
        this.numberOfCorrespondences = this.currentOutputSpace.getNumberOfCorrespondingPoints();
        this.jacobian.reshape(this.numberOfCorrespondences, this.inputDimension);
        this.invMultJacobianTranspose.reshape(this.inputDimension, this.numberOfCorrespondences);
        computeJacobian();

        // direction = (J^T J + D)^-1 J^T r
        CommonOps_DDRM.multInner((DMatrix1Row) this.jacobian, (DMatrix1Row) this.squaredJacobian);
        if (this.useDamping) {
            CommonOps_DDRM.addEquals((DMatrixD1) this.squaredJacobian, (DMatrixD1) this.dampingMatrix);
        }
        // NOTE(review): invert()'s boolean result is ignored. With a positive damping diagonal,
        // J^T J + D is positive definite and should invert. Without damping, a rank-deficient Jacobian
        // could leave unusable values here.
        CommonOps_DDRM.invert((DMatrixRMaj) this.squaredJacobian);
        CommonOps_DDRM.multTransB((DMatrix1Row) this.squaredJacobian, (DMatrix1Row) this.jacobian, (DMatrix1Row) this.invMultJacobianTranspose);
        CommonOps_DDRM.mult((DMatrix1Row) this.invMultJacobianTranspose, (DMatrix1Row) this.currentOutputSpace.getCorrespondingOutput(), (DMatrix1Row) this.optimizeDirection);
        CommonOps_DDRM.subtractEquals((DMatrixD1) this.currentInput, (DMatrixD1) this.optimizeDirection);

        // Re-evaluate all outputs at the updated input.
        this.outputCalculator.resetIndicesToCompute();
        this.currentOutputSpace.updateOutputSpace((DMatrixRMaj) this.outputCalculator.apply(this.currentInput));
        this.currentOutputSpace.computeCorrespondence();
        this.currentOutputSpace.computeQuality();
        return this.currentOutputSpace.getCorrespondingQuality();
    }

    /**
     * Fills {@link #jacobian} column by column with forward-difference partial derivatives of the
     * corresponding outputs with respect to each input parameter.
     */
    private void computeJacobian() {
        DMatrixRMaj currentOutput = this.currentOutputSpace.getCorrespondingOutput();
        this.perturbedInput.set((DMatrixD1) this.currentInput);
        for (int i = 0; i < this.inputDimension; ++i) {
            double step = this.perturbationVector.get(i);
            this.perturbedInput.add(i, 0, step);
            DMatrixRMaj perturbedOutput = (DMatrixRMaj) this.outputCalculator.apply(this.perturbedInput);
            for (int j = 0; j < this.numberOfCorrespondences; ++j) {
                this.jacobian.set(j, i, (perturbedOutput.get(j) - currentOutput.get(j)) / step);
            }
            // Restore the exact original value rather than subtracting the step, which could
            // accumulate floating-point drift in the perturbed input.
            this.perturbedInput.set(i, 0, this.currentInput.get(i));
        }
    }

    /**
     * Converts an input parameter vector into a transform using the function given at construction.
     *
     * @throws MatrixDimensionException if {@code input} does not have exactly {@code inputParameterDimension} elements.
     */
    public void convertInputToTransform(DMatrixRMaj input, RigidBodyTransform transformToPack) {
        // getNumElements() is the logical size. getData().length is the backing-array capacity and can
        // be larger after a reshape.
        if (input.getNumElements() != this.inputDimension) {
            throw new MatrixDimensionException("dimension is wrong. " + input.getNumElements() + " " + this.inputDimension);
        }
        transformToPack.set(this.inputFunction.apply(input));
    }

    /**
     * @return the number of corresponding outputs used to build the Jacobian in the most recent
     *         {@link #iterate()} that performed an update (0 before any such update).
     */
    public int getNumberOfCorrespondingPoints() {
        return this.numberOfCorrespondences;
    }

    /**
     * @return the current input parameter vector (the internal matrix, not a copy).
     */
    public DMatrixRMaj getOptimalParameter() {
        return this.currentInput;
    }

    /**
     * NOTE(review): {@code optimized} is only ever reset to {@code false} in this class, so this
     * currently always returns {@code false}. Confirm the intended convergence criterion.
     */
    public boolean isSolved() {
        return this.optimized;
    }

    /**
     * @return sum of squared outputs over all outputs below the correspondence threshold at the last
     *         evaluation. This includes outputs that were dropped by random sub-sampling.
     */
    public double getQuality() {
        return this.currentOutputSpace.getCorrespondingQuality();
    }

    /**
     * @return sum of squared outputs over all outputs at the last evaluation.
     */
    public double getPureQuality() {
        return this.currentOutputSpace.getQuality();
    }

    /**
     * @return number of {@link #iterate()} calls since the last {@link #initialize()}.
     */
    public int getIteration() {
        return this.iteration;
    }

    /**
     * Creates the input-to-transform mapping used for spatial optimization.
     * <p>
     * Layout: indices 0-2 are the x, y, z translation. With pitch and roll, indices 3, 4, 5 are
     * roll, pitch, yaw. Without them, index 3 is yaw.
     */
    public static Function<DMatrixRMaj, RigidBodyTransform> createSpatialInputFunction(boolean includePitchAndRoll) {
        return input -> {
            RigidBodyTransform transform = new RigidBodyTransform();
            if (includePitchAndRoll) {
                transform.setRotationYawPitchRollAndZeroTranslation(input.get(5), input.get(4), input.get(3));
            } else {
                transform.setRotationYawAndZeroTranslation(input.get(3));
            }
            transform.getTranslation().set(input.get(0), input.get(1), input.get(2));
            return transform;
        };
    }

    /**
     * Creates the inverse of {@link #createSpatialInputFunction(boolean)}. It packs a transform into an
     * input vector with the same layout (6 elements with pitch and roll, otherwise 4).
     */
    public static Function<RigidBodyTransformReadOnly, DMatrixRMaj> createInverseSpatialInputFunction(boolean includePitchAndRoll) {
        return transform -> {
            int size = includePitchAndRoll ? 6 : 4;
            DMatrixRMaj input = new DMatrixRMaj(size, 1);
            if (includePitchAndRoll) {
                input.set(3, transform.getRotation().getRoll());
                input.set(4, transform.getRotation().getPitch());
                input.set(5, transform.getRotation().getYaw());
            } else {
                input.set(3, transform.getRotation().getYaw());
            }
            transform.getTranslation().get((DMatrix) input);
            return input;
        };
    }

    /**
     * Holds the latest output vector, which outputs correspond (fall below the threshold), and the
     * derived quality measures.
     */
    private class OutputSpace {
        private final DMatrixRMaj output;
        // Reused across calls; reshaped to the number of sampled corresponding outputs.
        private final DMatrixRMaj correspondingOutput = new DMatrixRMaj(0, 1);
        private final boolean[] correspondence;
        private final TIntArrayList correspondingIndices = new TIntArrayList();
        private final Random random = new Random();
        private double correspondingQuality;
        private double quality;

        private OutputSpace(int dimension) {
            this.output = new DMatrixRMaj(dimension, 1);
            this.correspondence = new boolean[dimension];
        }

        /** Copies the given output vector into this output space. */
        void updateOutputSpace(DMatrixRMaj output) {
            this.output.set((DMatrixD1) output);
        }

        /**
         * Flags outputs below the correspondence threshold, randomly sub-samples them to the configured
         * maximum, and gathers the sampled outputs into {@link #correspondingOutput}.
         *
         * @return {@code true} if at least one corresponding output remains after sampling.
         */
        boolean computeCorrespondence() {
            this.correspondingIndices.clear();
            for (int i = 0; i < this.output.getNumRows(); ++i) {
                boolean corresponds = this.output.get(i, 0) < LevenbergMarquardtParameterOptimizer.this.correspondenceThreshold;
                this.correspondence[i] = corresponds;
                if (corresponds) {
                    this.correspondingIndices.add(i);
                }
            }
            this.randomlySampleCorrespondences(this.correspondingIndices, LevenbergMarquardtParameterOptimizer.this.maximumNumberOfCorrespondences);

            this.correspondingOutput.reshape(this.correspondingIndices.size(), 1);
            int index = 0;
            TIntIterator iterator = this.correspondingIndices.iterator();
            while (iterator.hasNext()) {
                this.correspondingOutput.set(index++, 0, this.output.get(iterator.next()));
            }
            return !this.correspondingIndices.isEmpty();
        }

        /**
         * Removes uniformly random entries until at most {@code maxNumberOfCorrespondences} remain.
         * The relative order of the surviving entries is preserved.
         */
        private void randomlySampleCorrespondences(TIntArrayList correspondencesToSample, int maxNumberOfCorrespondences) {
            int limit = Math.max(0, maxNumberOfCorrespondences);
            while (correspondencesToSample.size() > limit) {
                // removeAt() removes by position. TIntArrayList.remove(int) removes by *value*, which
                // could match nothing (looping forever) or the wrong entry.
                correspondencesToSample.removeAt(this.random.nextInt(correspondencesToSample.size()));
            }
        }

        /**
         * Computes the sum of squared outputs, both overall and over the outputs flagged as corresponding.
         * The flags are set before sub-sampling, so dropped-but-below-threshold outputs still count.
         */
        void computeQuality() {
            this.correspondingQuality = 0.0;
            this.quality = 0.0;
            for (int i = 0; i < this.output.getNumRows(); ++i) {
                double norm = this.output.get(i, 0) * this.output.get(i, 0);
                this.quality += norm;
                if (this.correspondence[i]) {
                    this.correspondingQuality += norm;
                }
            }
        }

        DMatrixRMaj getOutput() {
            return this.output;
        }

        DMatrixRMaj getCorrespondingOutput() {
            return this.correspondingOutput;
        }

        int getNumberOfCorrespondingPoints() {
            return this.correspondingIndices.size();
        }

        double getCorrespondingQuality() {
            return this.correspondingQuality;
        }

        double getQuality() {
            return this.quality;
        }
    }
}

