/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Max;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Min;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Min-max feature scaler.
 *
 * <p>{@code fit} records the per-column minimum and maximum of the feature matrix. The transform
 * then maps every column linearly so that the fitted minimum lands on {@code minRange} and the
 * fitted maximum on {@code maxRange}:
 *
 * <pre>
 *   x' = (x - min) / (max - min + eps) * (maxRange - minRange) + minRange
 * </pre>
 *
 * Here {@code eps} is {@link Nd4j#EPS_THRESHOLD}, added so that constant columns do not cause a
 * division by zero. Statistics are reduced along dimension 0 and applied with row-vector
 * broadcasting, so feature arrays are expected to be 2-D (examples x features).
 *
 * <p>All transforms modify the supplied arrays in place.
 */
public class NormalizerMinMaxScaler
implements DataNormalization {
    private static final Logger logger = LoggerFactory.getLogger(NormalizerMinMaxScaler.class);

    private static final String NOT_FITTED_MESSAGE =
            "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";
    private static final String INVALID_RANGE_MESSAGE =
            "API_USE_ERROR: The given max value minus min value has to be greater than 0";
    private static final String ZERO_RANGE_MESSAGE =
            "API_INFO: max val minus min val found to be zero. Transform will round upto epsilon to avoid nans.";

    /** Per-column minimum observed during fit (or loaded from disk). */
    private INDArray min;
    /** Per-column maximum observed during fit (or loaded from disk). */
    private INDArray max;
    /** Cached divisor {@code max - min + EPS_THRESHOLD}; always recomputed together with min/max. */
    private INDArray maxMinusMin;
    /** Lower bound of the target interval. */
    private double minRange;
    /** Upper bound of the target interval. */
    private double maxRange;

    /**
     * Creates a scaler that maps fitted data onto {@code [minRange, maxRange]}.
     *
     * @param minRange lower bound of the target interval
     * @param maxRange upper bound of the target interval; must exceed {@code minRange} by the time
     *     the scaler is used
     */
    public NormalizerMinMaxScaler(double minRange, double maxRange) {
        // Assign directly instead of calling the overridable setters from the constructor.
        this.minRange = minRange;
        this.maxRange = maxRange;
    }

    /** Creates a scaler that maps fitted data onto {@code [0, 1]}. */
    public NormalizerMinMaxScaler() {
        this(0.0, 1.0);
    }

    /** Sets the lower bound of the target interval. */
    public void setMinRange(double minRange) {
        this.minRange = minRange;
    }

    /** Sets the upper bound of the target interval. */
    public void setMaxRange(double maxRange) {
        this.maxRange = maxRange;
    }

    /**
     * Computes per-column min/max statistics from a single data set, replacing any previous fit.
     *
     * @param dataSet data whose feature matrix provides the statistics
     */
    @Override
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.min = features.min(0);
        this.max = features.max(0);
        this.updateMaxMinusMin();
    }

    /**
     * Computes per-column min/max statistics over every batch of {@code iterator}, replacing any
     * previous fit, and resets the iterator afterwards.
     *
     * @param iterator source of batches; must yield at least one batch
     * @throws IllegalArgumentException if the iterator yields no batches
     */
    @Override
    public void fit(DataSetIterator iterator) {
        // Accumulate into locals so stale statistics from an earlier fit are never merged in.
        INDArray runningMin = null;
        INDArray runningMax = null;
        while (iterator.hasNext()) {
            DataSet next = (DataSet)iterator.next();
            INDArray features = next.getFeatureMatrix();
            INDArray batchMin = features.min(0);
            INDArray batchMax = features.max(0);
            if (runningMin == null) {
                runningMin = batchMin;
                runningMax = batchMax;
                continue;
            }
            // Element-wise min/max of the batch statistic and the running one, written into the running array.
            runningMin = Nd4j.getExecutioner().execAndReturn(new Min(batchMin, runningMin, runningMin, runningMin.length()));
            runningMax = Nd4j.getExecutioner().execAndReturn(new Max(batchMax, runningMax, runningMax, runningMax.length()));
        }
        iterator.reset();
        if (runningMin == null) {
            throw new IllegalArgumentException("API_USE_ERROR: Cannot fit a normalizer on an iterator with no data");
        }
        this.min = runningMin;
        this.max = runningMax;
        this.updateMaxMinusMin();
    }

    /**
     * Scales the features of {@code toPreProcess} in place.
     *
     * @throws IllegalStateException if the scaler is not fitted or the target range is not positive
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        this.preProcess(toPreProcess.getFeatures());
    }

    /**
     * Scales {@code theFeatures} in place onto {@code [minRange, maxRange]}.
     *
     * @param theFeatures 2-D feature matrix with one column per fitted statistic
     * @throws IllegalStateException if the scaler is not fitted or the target range is not positive
     */
    public void preProcess(INDArray theFeatures) {
        this.assertFitted();
        this.assertValidRange();
        theFeatures.subiRowVector(this.min);
        theFeatures.diviRowVector(this.maxMinusMin);
        // Stretch the [0, 1] result to the target width, then shift it to start at minRange.
        theFeatures.muli(this.maxRange - this.minRange);
        theFeatures.addi(this.minRange);
    }

    /** Alias for {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /** Alias for {@link #preProcess(INDArray)}. */
    public void transform(INDArray theFeatures) {
        this.preProcess(theFeatures);
    }

    /**
     * Undoes {@link #preProcess(DataSet)} in place, mapping scaled features back to the original
     * scale.
     *
     * @throws IllegalStateException if the scaler is not fitted or the target range is not positive
     */
    public void revertPreProcess(DataSet toPreProcess) {
        this.assertFitted();
        this.assertValidRange();
        INDArray features = toPreProcess.getFeatures();
        // Exact inverse of preProcess, applied in reverse order.
        features.subi(this.minRange);
        features.divi(this.maxRange - this.minRange);
        features.muliRowVector(this.maxMinusMin);
        features.addiRowVector(this.min);
    }

    /** Alias for {@link #revertPreProcess(DataSet)}. */
    public void revert(DataSet toPreProcess) {
        this.revertPreProcess(toPreProcess);
    }

    /**
     * Reverts every data set produced by the iterator in place, then resets the iterator.
     * NOTE(review): the reverted DataSets are not retained here; this only has a lasting effect if
     * the iterator hands out shared instances. Confirm the intended use.
     */
    public void revert(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.revertPreProcess((DataSet)toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * Returns the fitted per-column minimum (the internal array, not a copy).
     *
     * @throws IllegalStateException if the scaler is not fitted
     */
    public INDArray getMin() {
        if (this.min == null) {
            throw new IllegalStateException(NOT_FITTED_MESSAGE);
        }
        return this.min;
    }

    /**
     * Returns the fitted per-column maximum (the internal array, not a copy).
     *
     * @throws IllegalStateException if the scaler is not fitted
     */
    public INDArray getMax() {
        if (this.max == null) {
            throw new IllegalStateException(NOT_FITTED_MESSAGE);
        }
        return this.max;
    }

    /**
     * Restores statistics written by {@link #save}: {@code statistics[0]} holds min and
     * {@code statistics[1]} holds max. The target range is not persisted and keeps its current
     * value.
     *
     * @throws IllegalArgumentException if fewer than two files are given
     * @throws IOException if reading either file fails
     */
    @Override
    public void load(File ... statistics) throws IOException {
        if (statistics.length < 2) {
            throw new IllegalArgumentException("Expected 2 files (min, max) but got " + statistics.length);
        }
        // Read both before assigning, so a failure on the second file leaves the old state intact.
        INDArray loadedMin = Nd4j.readBinary(statistics[0]);
        INDArray loadedMax = Nd4j.readBinary(statistics[1]);
        this.min = loadedMin;
        this.max = loadedMax;
        // Use the same epsilon-padded divisor as fit(), so a loaded scaler matches a fitted one.
        this.updateMaxMinusMin();
    }

    /**
     * Writes min to {@code files[0]} and max to {@code files[1]}. The target range is not saved.
     *
     * @throws IllegalStateException if the scaler is not fitted
     * @throws IllegalArgumentException if fewer than two files are given
     * @throws IOException if writing either file fails
     */
    @Override
    public void save(File ... files) throws IOException {
        this.assertFitted();
        if (files.length < 2) {
            throw new IllegalArgumentException("Expected 2 files (min, max) but got " + files.length);
        }
        Nd4j.saveBinary(this.min, files[0]);
        Nd4j.saveBinary(this.max, files[1]);
    }

    /**
     * Recomputes the cached divisor from min/max. Logs once if any column has zero width, which is
     * then padded by EPS_THRESHOLD.
     */
    private void updateMaxMinusMin() {
        INDArray range = this.max.sub(this.min);
        // Compare numerically on the raw range; the old check compared object references.
        if (range.minNumber().doubleValue() == 0.0) {
            logger.info(ZERO_RANGE_MESSAGE);
        }
        this.maxMinusMin = range.addi(Nd4j.EPS_THRESHOLD);
    }

    /** Throws if neither fit nor load has populated the statistics. */
    private void assertFitted() {
        if (this.min == null || this.max == null) {
            throw new IllegalStateException(NOT_FITTED_MESSAGE);
        }
    }

    /** Throws unless maxRange is strictly greater than minRange (this also rejects NaN bounds). */
    private void assertValidRange() {
        if (!(this.maxRange > this.minRange)) {
            throw new IllegalStateException(INVALID_RANGE_MESSAGE);
        }
    }
}

