/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Max;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.Min;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Min-max normalizer: linearly rescales every feature column (and, when label fitting is
 * enabled, every label column) from the [min, max] range observed by {@code fit} into the
 * target range [{@code minRange}, {@code maxRange}], which defaults to [0, 1].
 *
 * <p>Statistics are gathered column-wise on the 2d view produced by
 * {@link DataSetUtil#tailor2d}. When transforming, rank-2 arrays are processed with row-vector
 * ops and higher-rank arrays by broadcasting the statistics along dimension 1. Reverting uses
 * row-vector ops only, so it expects rank-2 input.
 *
 * <p>{@link Nd4j#EPS_THRESHOLD} is added to the observed range and to the target range before
 * either is used as a divisor, so constant columns (max == min) and a zero-width target range
 * do not yield NaN/Infinity. Transform and revert use the same epsilon and are exact inverses.
 *
 * <p>Instances hold mutable fitted state and perform no synchronization.
 */
public class NormalizerMinMaxScaler
implements DataNormalization {
    private static final Logger logger = LoggerFactory.getLogger(NormalizerMinMaxScaler.class);
    private static final String NOT_FIT_ERROR = "API_USE_ERROR: Preprocessors have to be explicitly fit before use. Usage: .fit(dataset) or .fit(datasetiterator)";
    // Rank of the most recently fitted feature array; written by fit but not read in this class.
    private int featureRank = 2;
    private INDArray featureMin;
    private INDArray featureMax;
    private INDArray labelMax;
    private INDArray labelMin;
    private boolean fitLabels = false;
    private double minRange;
    private double maxRange;

    /**
     * Creates a scaler that maps fitted data into [{@code minRange}, {@code maxRange}].
     *
     * @param minRange lower bound of the target range
     * @param maxRange upper bound of the target range; {@code maxRange < minRange} is rejected
     *                 when data is transformed
     */
    public NormalizerMinMaxScaler(double minRange, double maxRange) {
        this.setMinRange(minRange);
        this.setMaxRange(maxRange);
    }

    /** Creates a scaler that maps fitted data into [0, 1]. */
    public NormalizerMinMaxScaler() {
        this(0.0, 1.0);
    }

    /** @param minRange new lower bound of the target range */
    public void setMinRange(double minRange) {
        this.minRange = minRange;
    }

    /** @param maxRange new upper bound of the target range */
    public void setMaxRange(double maxRange) {
        this.maxRange = maxRange;
    }

    /**
     * Computes column-wise min/max statistics from a single DataSet. This replaces the feature
     * statistics, and also the label statistics when label fitting is enabled.
     */
    @Override
    public void fit(DataSet dataSet) {
        this.initFrom(dataSet);
        warnIfZeroRange(this.featureMax.sub(this.featureMin), "Feature");
        if (this.fitLabels) {
            warnIfZeroRange(this.labelMax.sub(this.labelMin), "Labels");
        }
    }

    /** Sets min/max statistics from one DataSet without logging (shared by both fit paths). */
    private void initFrom(DataSet dataSet) {
        this.featureRank = dataSet.getFeatures().rank();
        INDArray featureStats = this.fit(DataSetUtil.tailor2d(dataSet, true));
        this.featureMin = featureStats.getRow(0).dup();
        this.featureMax = featureStats.getRow(1).dup();
        if (this.fitLabels) {
            INDArray labelStats = this.fit(DataSetUtil.tailor2d(dataSet, false));
            this.labelMin = labelStats.getRow(0).dup();
            this.labelMax = labelStats.getRow(1).dup();
        }
    }

    /** Returns a 2 x columns matrix: row 0 holds the column minima, row 1 the column maxima. */
    private INDArray fit(INDArray theArray) {
        INDArray stats = Nd4j.zeros(2, theArray.size(1));
        stats.putRow(0, theArray.min(0));
        stats.putRow(1, theArray.max(0));
        return stats;
    }

    /**
     * Computes column-wise min/max statistics over every DataSet the iterator yields. Any
     * previous fit is discarded, and the iterator is reset afterwards.
     *
     * @throws RuntimeException if the iterator yields no data
     */
    @Override
    public void fit(DataSetIterator iterator) {
        // The first batch initialises the statistics and later batches are merged element-wise.
        // This is tracked locally rather than by testing featureMin == null, so that refitting
        // starts fresh instead of merging into the previous fit.
        boolean first = true;
        while (iterator.hasNext()) {
            DataSet next = (DataSet)iterator.next();
            if (first) {
                this.initFrom(next);
                first = false;
                continue;
            }
            this.featureRank = next.getFeatures().rank();
            INDArray theFeatures = DataSetUtil.tailor2d(next, true);
            INDArray nextMin = theFeatures.min(0);
            this.featureMin = Nd4j.getExecutioner().execAndReturn(new Min(nextMin, this.featureMin, this.featureMin, this.featureMin.length()));
            INDArray nextMax = theFeatures.max(0);
            this.featureMax = Nd4j.getExecutioner().execAndReturn(new Max(nextMax, this.featureMax, this.featureMax, this.featureMax.length()));
            if (this.fitLabels) {
                INDArray theLabels = DataSetUtil.tailor2d(next, false);
                nextMin = theLabels.min(0);
                this.labelMin = Nd4j.getExecutioner().execAndReturn(new Min(nextMin, this.labelMin, this.labelMin, this.labelMin.length()));
                nextMax = theLabels.max(0);
                this.labelMax = Nd4j.getExecutioner().execAndReturn(new Max(nextMax, this.labelMax, this.labelMax, this.labelMax.length()));
            }
        }
        iterator.reset();
        if (first) {
            throw new RuntimeException("API_USE_ERROR: Cannot fit on a DataSetIterator that returned no data");
        }
        warnIfZeroRange(this.featureMax.sub(this.featureMin), "Feature");
        if (this.fitLabels) {
            warnIfZeroRange(this.labelMax.sub(this.labelMin), "Labels");
        }
    }

    /** Logs an informational message when any column of {@code range} is (near) zero. */
    private static void warnIfZeroRange(INDArray range, String what) {
        if (range.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD) {
            logger.info("API_INFO: {} max val minus min val found to be zero. Transform will round upto epsilon to avoid nans.", what);
        }
    }

    @Override
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Normalizes the features in place, and the labels too when label fitting is enabled.
     *
     * @throws RuntimeException if the required statistics have not been fit, or if
     *                          {@code maxRange < minRange}
     */
    @Override
    public void preProcess(DataSet toPreProcess) {
        this.preProcess(toPreProcess.getFeatures(), true);
        if (this.fitLabels) {
            this.preProcess(toPreProcess.getLabels(), false);
        }
    }

    /** Scales {@code theArray} in place using the feature or label statistics. */
    private void preProcess(INDArray theArray, boolean isFeatures) {
        INDArray min = isFeatures ? this.featureMin : this.labelMin;
        INDArray max = isFeatures ? this.featureMax : this.labelMax;
        requireFit(min, max);
        if (this.maxRange - this.minRange < 0.0) {
            throw new RuntimeException("API_USE_ERROR: The given max value minus min value has to be greater than 0");
        }
        // The epsilon keeps constant columns (max == min) from dividing by zero.
        INDArray range = max.sub(min).add(Nd4j.EPS_THRESHOLD);
        if (theArray.rank() == 2) {
            theArray.subiRowVector(min);
            theArray.diviRowVector(range);
        } else {
            Nd4j.getExecutioner().execAndReturn(new BroadcastSubOp(theArray, min, theArray, 1));
            Nd4j.getExecutioner().execAndReturn(new BroadcastDivOp(theArray, range, theArray, 1));
        }
        theArray.muli(this.targetScale());
        theArray.addi(this.minRange);
    }

    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    @Override
    public void transform(INDArray theFeatures) {
        this.preProcess(theFeatures, true);
    }

    /** @throws RuntimeException if label statistics have not been fit */
    @Override
    public void transformLabel(INDArray labels) {
        this.preProcess(labels, false);
    }

    /**
     * Reverts the normalization of the features in place, and of the labels too when label
     * fitting is enabled.
     *
     * @throws RuntimeException if the statistics have not been fit
     */
    public void revertPreProcess(DataSet toPreProcess) {
        this.revertFeatures(toPreProcess.getFeatures());
        if (this.fitLabels) {
            this.revertLabels(toPreProcess.getLabels());
        }
    }

    @Override
    public void revert(DataSet toPreProcess) {
        this.revertPreProcess(toPreProcess);
    }

    @Override
    public void revertFeatures(INDArray features) {
        this.revertArray(features, true);
    }

    /** Reverts label normalization in place; does nothing when label fitting is disabled. */
    @Override
    public void revertLabels(INDArray labels) {
        if (!this.fitLabels) {
            return;
        }
        this.revertArray(labels, false);
    }

    /**
     * Exact inverse of {@link #preProcess(INDArray, boolean)} for rank-2 arrays, using the same
     * epsilon-padded ranges. Uses row-vector ops, so higher-rank input is not supported here.
     */
    private void revertArray(INDArray theArray, boolean isFeatures) {
        INDArray min = isFeatures ? this.featureMin : this.labelMin;
        INDArray max = isFeatures ? this.featureMax : this.labelMax;
        requireFit(min, max);
        INDArray range = max.sub(min).add(Nd4j.EPS_THRESHOLD);
        theArray.subi(this.minRange).divi(this.targetScale()).muliRowVector(range).addiRowVector(min);
    }

    /** Width of the target range, padded by epsilon so reverting never divides by zero. */
    private double targetScale() {
        return this.maxRange - this.minRange + Nd4j.EPS_THRESHOLD;
    }

    /** Throws the standard "not fit" error if either statistic is missing. */
    private static void requireFit(INDArray min, INDArray max) {
        if (min == null || max == null) {
            throw new RuntimeException(NOT_FIT_ERROR);
        }
    }

    public INDArray getMin() {
        if (this.featureMin == null) {
            throw new RuntimeException(NOT_FIT_ERROR);
        }
        return this.featureMin;
    }

    public INDArray getMax() {
        if (this.featureMax == null) {
            throw new RuntimeException(NOT_FIT_ERROR);
        }
        return this.featureMax;
    }

    public INDArray getLabelMin() {
        if (this.labelMin == null) {
            throw new RuntimeException(NOT_FIT_ERROR);
        }
        return this.labelMin;
    }

    public INDArray getLabelMax() {
        if (this.labelMax == null) {
            throw new RuntimeException(NOT_FIT_ERROR);
        }
        return this.labelMax;
    }

    /**
     * Loads statistics written by {@link #save}. The expected file order is featureMin,
     * featureMax and, when label fitting is enabled, labelMin, labelMax.
     */
    @Override
    public void load(File ... statistics) throws IOException {
        this.featureMin = Nd4j.readBinary(statistics[0]);
        this.featureMax = Nd4j.readBinary(statistics[1]);
        if (this.fitLabels) {
            // The label statistics live in files 2 and 3, matching save().
            this.labelMin = Nd4j.readBinary(statistics[2]);
            this.labelMax = Nd4j.readBinary(statistics[3]);
        }
    }

    /**
     * Saves the statistics as featureMin, featureMax and, when label fitting is enabled,
     * labelMin, labelMax (in {@code files[0..3]}).
     */
    @Override
    public void save(File ... files) throws IOException {
        Nd4j.saveBinary(this.featureMin, files[0]);
        Nd4j.saveBinary(this.featureMax, files[1]);
        if (this.fitLabels) {
            Nd4j.saveBinary(this.labelMin, files[2]);
            Nd4j.saveBinary(this.labelMax, files[3]);
        }
    }
}

