/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.finitedifferences;

import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Function;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Two-point (forward) finite-difference approximation of the derivative of a vector-valued
 * function, with optional clamping of the step size so that perturbed points stay within a
 * supplied {@code [lower, upper]} interval.
 *
 * <p>The method names and variables appear to mirror SciPy's {@code scipy.optimize._numdiff}
 * helpers ({@code _prepare_bounds}, {@code _adjust_scheme_to_bounds},
 * {@code _compute_absolute_step}, {@code _dense_difference}) — NOTE(review): confirm when
 * comparing behaviour against SciPy.
 */
public class TwoPointApproximation {

    /**
     * Expands a {@code [lower, upper]} bounds array into two arrays shaped like {@code x},
     * each filled with the corresponding bound value.
     *
     * @param bounds array whose elements 0 and 1 are the lower and upper bound; {@code null}
     *               means unbounded, i.e. (-infinity, +infinity)
     * @param x      the evaluation point; only its shape is used
     * @return {@code {lowerBounds, upperBounds}}, both with the shape of {@code x}
     * @throws ND4JIllegalArgumentException if {@code bounds} has fewer than two elements
     */
    public static INDArray[] prepareBounds(INDArray bounds, INDArray x) {
        double lower = Double.NEGATIVE_INFINITY;
        double upper = Double.POSITIVE_INFINITY;
        if (bounds != null) {
            if (bounds.length() < 2L) {
                throw new ND4JIllegalArgumentException(
                        "Bounds must contain a lower and an upper value, got " + bounds.length() + " element(s)");
            }
            lower = bounds.getDouble(0L);
            upper = bounds.getDouble(1L);
        }
        return new INDArray[]{Nd4j.valueArrayOf(x.shape(), lower), Nd4j.valueArrayOf(x.shape(), upper)};
    }

    /**
     * Shrinks (and, where necessary, flips the sign of) the per-element step {@code h} so that
     * {@code x + numSteps * h} stays inside {@code [lowerBound, upperBound]}.
     *
     * @param x          evaluation point
     * @param h          signed per-element step, same shape as {@code x}
     * @param numSteps   number of steps taken in one direction (1 for a two-point scheme)
     * @param lowerBound per-element lower bound, same shape as {@code x}
     * @param upperBound per-element upper bound, same shape as {@code x}
     * @return {@code {hAdjusted, oneSided}}: the adjusted step, and a 0/1 mask that is 0 only for
     *         elements switched to a symmetric step of size {@code min(distance)/numSteps}
     */
    public static INDArray[] adjustSchemeToBounds(INDArray x, INDArray h, int numSteps, INDArray lowerBound, INDArray upperBound) {
        INDArray oneSided = Nd4j.onesLike(h);
        // Fast path only when EVERY element is unbounded on both sides. Partially unbounded
        // inputs still need the adjustment for their bounded elements; unbounded elements get
        // infinite distances below and therefore land in "central" unchanged.
        INDArray unbounded = Transforms.and(lowerBound.eq(Double.NEGATIVE_INFINITY), upperBound.eq(Double.POSITIVE_INFINITY));
        if (unbounded.minNumber().doubleValue() > 0.0) {
            return new INDArray[]{h, oneSided};
        }
        // Compare distances against the step MAGNITUDE: with a signed h, a negative step would
        // make every element look "central" even when it crosses the lower bound.
        INDArray hTotal = Transforms.abs(h).muli(numSteps);
        INDArray hAdjusted = h.dup();
        INDArray lowerDist = x.sub(lowerBound);
        INDArray upperDist = upperBound.sub(x);
        INDArray notCentralHelper;

        // Enough room on both sides: keep the original (signed) step.
        INDArray central = Transforms.and(Transforms.greaterThanOrEqual(lowerDist, hTotal), Transforms.greaterThanOrEqual(upperDist, hTotal));
        notCentralHelper = Transforms.not(central);

        // More room above than below: step forward (positive), at most half the upper distance.
        INDArray forward = Transforms.and(Transforms.greaterThanOrEqual(upperDist, lowerDist), notCentralHelper);
        hAdjusted.put(forward, Transforms.min(Transforms.abs(h.get(forward)), upperDist.get(forward).mul(0.5).divi(numSteps)));
        oneSided.put(forward, Nd4j.scalar(1.0));

        // More room below than above: step backward (negative), at most half the lower distance.
        INDArray backward = Transforms.and(upperDist.lt(lowerDist), Transforms.not(central));
        hAdjusted.put(backward, Transforms.min(Transforms.abs(h.get(backward)), lowerDist.get(backward).mul(0.5).divi(numSteps)).negi());
        oneSided.put(backward, Nd4j.scalar(1.0));

        // If the shrunken step is no larger than the distance to the nearer bound, use that
        // distance as a symmetric step and mark the element as not one-sided.
        INDArray minDist = Transforms.min(upperDist, lowerDist).divi(numSteps);
        INDArray adjustedCentral = Transforms.and(Transforms.not(central), Transforms.lessThanOrEqual(Transforms.abs(hAdjusted), minDist));
        hAdjusted.put(adjustedCentral, minDist.get(adjustedCentral));
        oneSided.put(adjustedCentral, Nd4j.scalar(0.0));
        return new INDArray[]{hAdjusted, oneSided};
    }

    /**
     * Computes signed absolute steps for {@code x} with a relative step of
     * {@code sqrt(Nd4j.EPS_THRESHOLD)}.
     *
     * <p>NOTE(review): unlike {@code computeAbsoluteStep(null, x)}, this overload does not use
     * the machine epsilon of {@code x}'s data type; confirm the difference is intended.
     *
     * @param x evaluation point
     * @return per-element step, see {@link #computeAbsoluteStep(INDArray, INDArray)}
     */
    public static INDArray computeAbsoluteStep(INDArray x) {
        INDArray relStep = Transforms.pow(Nd4j.scalar(Nd4j.EPS_THRESHOLD), 0.5);
        return TwoPointApproximation.computeAbsoluteStep(relStep, x);
    }

    /**
     * Returns the machine epsilon (spacing between 1.0 and the next representable value) for
     * the data type backing {@code data}.
     *
     * @param data array whose buffer data type is inspected
     * @return 2^-10 for HALF, 2^-23 for FLOAT, otherwise 2^-52 (double precision)
     */
    public static double getEpsRelativeTo(INDArray data) {
        DataBuffer.Type type = data.data().dataType();
        if (type == DataBuffer.Type.FLOAT) {
            return 1.1920929E-7;
        }
        if (type == DataBuffer.Type.HALF) {
            // Without this, half-precision data would get a double-precision step that is far
            // below its resolution, producing dx == 0 in denseDifference.
            return 9.765625E-4;
        }
        return 2.220446049250313E-16;
    }

    /**
     * Computes the signed absolute step {@code sign(x) * relStep * max(|x|, 1)}, where
     * {@code sign(x)} is +1 for {@code x >= 0} and -1 otherwise.
     *
     * @param relStep relative step (used as a scalar); {@code null} selects
     *                {@code sqrt(getEpsRelativeTo(x))}
     * @param x       evaluation point
     * @return per-element step with the same shape as {@code x}
     */
    public static INDArray computeAbsoluteStep(INDArray relStep, INDArray x) {
        if (relStep == null) {
            relStep = Transforms.pow(Nd4j.scalar(TwoPointApproximation.getEpsRelativeTo(x)), 0.5);
        }
        // Map the 0/1 mask (x >= 0) to -1/+1.
        INDArray signX0 = x.gte(0).muli(2).subi(1);
        return signX0.mul(relStep).muli(Transforms.max(Transforms.abs(x), 1.0));
    }

    /**
     * Approximates the derivative of {@code f} at {@code x} with two-point finite differences,
     * shrinking the step where needed so that perturbed points respect {@code bounds}.
     *
     * @param f       function to differentiate
     * @param x       evaluation point (rank at most 2)
     * @param relStep relative step, or {@code null} for the data-type default
     * @param f0      {@code f(x)}, or {@code null} to compute it here
     * @param bounds  {@code [lower, upper]}, or {@code null} for unbounded
     * @return the result of {@link #denseDifference}
     * @throws ND4JIllegalArgumentException if {@code x} has rank greater than 2, or
     *                                      {@code bounds} has fewer than two elements
     */
    public static INDArray approximateDerivative(Function<INDArray, INDArray> f, INDArray x, INDArray relStep, INDArray f0, INDArray bounds) {
        if (x.rank() > 2) {
            throw new ND4JIllegalArgumentException("Argument must be a vector or scalar");
        }
        INDArray fx = (f0 != null) ? f0 : f.apply(x);
        INDArray h = TwoPointApproximation.computeAbsoluteStep(relStep, x);
        INDArray[] lowerAndUpper = TwoPointApproximation.prepareBounds(bounds, x);
        INDArray[] adjusted = TwoPointApproximation.adjustSchemeToBounds(x, h, 1, lowerAndUpper[0], lowerAndUpper[1]);
        // Element 0 is the bound-adjusted step; it must be used, otherwise x + h may leave bounds.
        return TwoPointApproximation.denseDifference(f, x, fx, adjusted[0], adjusted[1]);
    }

    /**
     * Evaluates forward differences along each coordinate of {@code x0}.
     *
     * <p>Row {@code i} of the result is {@code (func(x0 + h[i] * e_i) - f0) / dx_i}, so the
     * result has shape {@code (x0.length(), f0.length())} — the TRANSPOSE of the Jacobian. When
     * {@code f0} has a single element the result is flattened with {@code ravel()}.
     *
     * @param func     function to differentiate
     * @param x0       evaluation point
     * @param f0       {@code func(x0)}
     * @param h        per-element step, same length as {@code x0}
     * @param oneSided currently unused: every coordinate uses a one-sided forward difference
     * @return the transposed Jacobian estimate (see above)
     */
    public static INDArray denseDifference(Function<INDArray, INDArray> func, INDArray x0, INDArray f0, INDArray h, INDArray oneSided) {
        // Row i of this diagonal matrix is h[i] * e_i.
        INDArray hVecs = Nd4j.diag(h.reshape(1L, h.length()));
        INDArray jTransposed = Nd4j.create(x0.length(), f0.length());
        for (int i = 0; i < h.length(); i++) {
            INDArray x = x0.add(hVecs.slice(i));
            // Use the step actually realised in floating point rather than h[i] itself.
            INDArray dx = x.slice(i).sub(x0.slice(i));
            INDArray df = func.apply(x).sub(f0);
            jTransposed.putSlice(i, df.div(dx));
        }
        if (f0.length() == 1L) {
            jTransposed = jTransposed.ravel();
        }
        return jTransposed;
    }
}

