/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Max;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Min;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Pre-processor that linearly rescales the features of a {@link DataSet} into a
 * target interval {@code [minRange, maxRange]} (default {@code [0, 1]}).
 *
 * <p>The statistics are the minimum and maximum of the feature matrix taken along
 * dimension 0 (one value per feature column for a 2-d feature matrix). They must be
 * obtained with one of the {@code fit} methods or restored with
 * {@link #load(File, File)} before any transform/revert call.
 *
 * <p>Forward transform:
 * {@code x' = (x - min) / (max - min + eps) * (maxRange - minRange) + minRange},
 * where {@code eps} is {@code Nd4j.EPS_THRESHOLD}; it keeps the division defined for
 * features whose minimum equals their maximum.
 *
 * <p>Only the features are touched; this class never reads or writes labels.
 */
public class NormalizerMinMaxScaler
implements DataSetPreProcessor {
    /** Error text used whenever statistics are needed but have not been fit or loaded. */
    private static final String NOT_FIT_MESSAGE = "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";

    /** Minimum of the fitted features along dimension 0; {@code null} until fit/loaded. */
    private INDArray min;
    /** Maximum of the fitted features along dimension 0; {@code null} until fit/loaded. */
    private INDArray max;
    /** {@code max - min + eps}, the divisor of the forward transform; derived from min/max. */
    private INDArray maxMinusMin;
    /** Lower bound of the target interval. */
    private double minRange;
    /** Upper bound of the target interval. */
    private double maxRange;

    /**
     * Creates a scaler that maps features into {@code [minRange, maxRange]}.
     *
     * @param minRange lower bound of the target interval
     * @param maxRange upper bound of the target interval
     */
    public NormalizerMinMaxScaler(double minRange, double maxRange) {
        this.setMinRange(minRange);
        this.setMaxRange(maxRange);
    }

    /** Creates a scaler that maps features into {@code [0, 1]}. */
    public NormalizerMinMaxScaler() {
        this(0.0, 1.0);
    }

    /**
     * Sets the lower bound of the target interval.
     *
     * @param minRange new lower bound
     */
    public void setMinRange(double minRange) {
        this.minRange = minRange;
    }

    /**
     * Sets the upper bound of the target interval.
     *
     * @param maxRange new upper bound
     */
    public void setMaxRange(double maxRange) {
        this.maxRange = maxRange;
    }

    /**
     * Fits the statistics on a single data set, replacing any previous statistics.
     *
     * @param dataSet data set whose feature matrix supplies the min/max
     */
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.min = features.min(0);
        this.max = features.max(0);
        // Same epsilon-padded range as the iterator variant, so a constant feature
        // does not produce a division by zero in preProcess.
        this.updateRange();
    }

    /**
     * Fits the statistics over every data set produced by the iterator, then resets
     * the iterator.
     *
     * <p>If statistics are already present (from an earlier fit or load) they are
     * merged with the new data rather than discarded.
     *
     * @param iterator source of data sets; it is consumed and then {@code reset()}
     * @throws RuntimeException if no statistics exist after iterating, i.e. the
     *         iterator produced no data sets and nothing was fit before
     */
    public void fit(DataSetIterator iterator) {
        while (iterator.hasNext()) {
            DataSet next = (DataSet)iterator.next();
            if (this.min == null) {
                this.fit(next);
                continue;
            }
            // Element-wise running min/max; the result is written into this.min / this.max.
            INDArray nextMin = next.getFeatureMatrix().min(0);
            this.min = Nd4j.getExecutioner().execAndReturn(new Min(nextMin, this.min, this.min, this.min.length()));
            INDArray nextMax = next.getFeatureMatrix().max(0);
            this.max = Nd4j.getExecutioner().execAndReturn(new Max(nextMax, this.max, this.max, this.max.length()));
        }
        if (this.min == null || this.max == null) {
            // Previously this surfaced as a bare NullPointerException.
            throw new RuntimeException("API_USE_ERROR: Cannot fit on an iterator that provides no data sets");
        }
        this.updateRange();
        iterator.reset();
    }

    /**
     * Rescales the features of the given data set in place (the data set's feature
     * array is replaced by the scaled one).
     *
     * @param toPreProcess data set whose features are rescaled
     * @throws RuntimeException if the scaler has not been fit or loaded
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        this.requireFit();
        INDArray scaled = toPreProcess.getFeatures()
            .subRowVector(this.min)
            .divRowVector(this.maxMinusMin)
            // Stretch the unit interval to the target width (multiply, not divide).
            .mul(this.maxRange - this.minRange)
            .add(this.minRange);
        toPreProcess.setFeatures(scaled);
    }

    /**
     * Same as {@link #preProcess(DataSet)}.
     *
     * @param toPreProcess data set whose features are rescaled
     */
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /**
     * Applies {@link #preProcess(DataSet)} to every data set produced by the iterator,
     * then resets the iterator.
     *
     * @param toPreProcessIter source of data sets to rescale
     */
    public void transform(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.preProcess((DataSet)toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * Undoes {@link #preProcess(DataSet)}: maps features from the target interval back
     * to the original scale.
     *
     * <p>If {@code maxRange == minRange} the forward transform collapsed every value to
     * a single point and the division below is by zero, so the original values cannot
     * be recovered.
     *
     * @param toPreProcess data set whose features are mapped back
     * @throws RuntimeException if the scaler has not been fit or loaded
     */
    public void revertPreProcess(DataSet toPreProcess) {
        this.requireFit();
        INDArray restored = toPreProcess.getFeatures()
            .sub(this.minRange)
            // Exact inverse of the multiplication performed in preProcess.
            .div(this.maxRange - this.minRange)
            .mulRowVector(this.maxMinusMin)
            .addRowVector(this.min);
        toPreProcess.setFeatures(restored);
    }

    /**
     * Same as {@link #revertPreProcess(DataSet)}.
     *
     * @param toPreProcess data set whose features are mapped back
     */
    public void revert(DataSet toPreProcess) {
        this.revertPreProcess(toPreProcess);
    }

    /**
     * Applies {@link #revertPreProcess(DataSet)} to every data set produced by the
     * iterator, then resets the iterator.
     *
     * @param toPreProcessIter source of data sets to map back
     */
    public void revert(DataSetIterator toPreProcessIter) {
        while (toPreProcessIter.hasNext()) {
            this.revertPreProcess((DataSet)toPreProcessIter.next());
        }
        toPreProcessIter.reset();
    }

    /**
     * Returns the fitted minimum (the internal array, not a copy).
     *
     * @return the minimum statistics
     * @throws RuntimeException if the scaler has not been fit or loaded
     */
    public INDArray getMin() {
        if (this.min == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.min;
    }

    /**
     * Returns the fitted maximum (the internal array, not a copy).
     *
     * @return the maximum statistics
     * @throws RuntimeException if the scaler has not been fit or loaded
     */
    public INDArray getMax() {
        if (this.max == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
        return this.max;
    }

    /**
     * Restores the min/max statistics from two files written by
     * {@link #save(File, File)} and recomputes the derived range, so the scaler is
     * immediately usable. The target interval is not stored in the files; it keeps
     * whatever value this instance currently has.
     *
     * @param min file holding the minimum statistics
     * @param max file holding the maximum statistics
     * @throws IOException if reading either file fails
     */
    public void load(File min, File max) throws IOException {
        this.min = Nd4j.readBinary(min);
        this.max = Nd4j.readBinary(max);
        // Without this, preProcess/revert after load would hit a null maxMinusMin.
        this.updateRange();
    }

    /**
     * Writes the min/max statistics to two files. The target interval is not written.
     *
     * @param min destination file for the minimum statistics
     * @param max destination file for the maximum statistics
     * @throws IOException if writing either file fails
     * @throws RuntimeException if the scaler has not been fit or loaded
     */
    public void save(File min, File max) throws IOException {
        this.requireFit();
        Nd4j.saveBinary(this.min, min);
        Nd4j.saveBinary(this.max, max);
    }

    /** Throws the standard usage error unless both min and max are available. */
    private void requireFit() {
        if (this.min == null || this.max == null) {
            throw new RuntimeException(NOT_FIT_MESSAGE);
        }
    }

    /** Recomputes {@code maxMinusMin} as {@code max - min + eps} from the current min/max. */
    private void updateRange() {
        this.maxMinusMin = this.max.sub(this.min).add(Nd4j.scalar(Nd4j.EPS_THRESHOLD));
    }
}

