/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.iterator;

import java.io.File;
import java.io.IOException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Standardizes feature columns: {@code x' = (x - mean) / std}, with statistics
 * computed per column (dimension 0) of the feature matrix.
 *
 * <p>Statistics can be fit from a single {@link DataSet} or streamed from a
 * {@link DataSetIterator}. A small epsilon ({@code Nd4j.EPS_THRESHOLD}) is added
 * to the standard deviation so that constant columns do not produce NaNs.
 *
 * <p>Note: {@link #fit(DataSet)} stores the bias-corrected (sample) standard
 * deviation, while {@link #fit(DataSetIterator)} stores the population standard
 * deviation (sum of squared deviations divided by the total example count).
 * This mirrors the original behaviour of this class.
 *
 * <p>Not thread-safe: fitting mutates instance state.
 */
@Deprecated
public class StandardScaler {
    private static final Logger logger = LoggerFactory.getLogger(StandardScaler.class);
    /** Per-column mean; {@code null} until fit or load is called. */
    private INDArray mean;
    /**
     * Per-column standard deviation (plus epsilon) once fitting completes; during
     * {@link #fit(DataSetIterator)} it temporarily holds the running sum of squared
     * deviations (M2). {@code null} until fit or load is called.
     */
    private INDArray std;
    /** Examples merged so far by the most recent {@link #fit(DataSetIterator)} call. */
    private int runningTotal = 0;
    /** Row count of the batch most recently merged by {@link #fit(DataSetIterator)}. */
    private int batchCount = 0;

    /**
     * Computes the per-column mean and sample standard deviation of the
     * dataset's features.
     *
     * @param dataSet the data to fit on
     */
    public void fit(DataSet dataSet) {
        INDArray features = dataSet.getFeatureMatrix();
        this.mean = features.mean(0);
        // A single row has no sample variance (std(0) would be NaN); treat it as zero spread.
        this.std = features.size(0) <= 1 ? Nd4j.zeros(this.mean.shape()) : features.std(0);
        addEpsilonAndWarnIfZero();
    }

    /**
     * Streams over every batch of {@code iterator}, merging per-batch statistics
     * with the parallel variance algorithm of Chan et al., then resets the iterator.
     * Any previously fitted statistics are discarded.
     *
     * @param iterator source of batches; it is reset after fitting
     * @throws IllegalStateException if the iterator yields no examples
     */
    public void fit(DataSetIterator iterator) {
        // Start from a clean slate: a previously fitted std is a final standard
        // deviation, not an M2 sum, and must not be merged into this fit.
        this.mean = null;
        this.std = null;
        this.runningTotal = 0;
        this.batchCount = 0;

        while (iterator.hasNext()) {
            DataSet next = (DataSet) iterator.next();
            INDArray features = next.getFeatureMatrix();
            this.batchCount = features.size(0);
            if (this.batchCount == 0) {
                continue;
            }
            // Derive the running total from the same row count used below so the
            // merge weights nA, nB and n stay consistent with each other.
            this.runningTotal += this.batchCount;

            INDArray batchMean = features.mean(0);
            INDArray batchM2 = sumOfSquaredDeviations(features, batchMean, this.batchCount);
            if (this.mean == null) {
                this.mean = batchMean;
                this.std = batchM2;
                continue;
            }

            int previousTotal = this.runningTotal - this.batchCount;
            INDArray delta = batchMean.sub(this.mean);
            // M2 = M2_A + M2_B + delta^2 * nA * nB / n
            double weight = (double) previousTotal * this.batchCount / this.runningTotal;
            INDArray deltaSqScaled = Transforms.pow(delta, 2).muli(weight);
            this.std.addi(batchM2).addi(deltaSqScaled);
            // mean = mean_A + delta * nB / n
            this.mean = this.mean.add(delta.mul((double) this.batchCount / this.runningTotal));
        }
        iterator.reset();

        if (this.mean == null) {
            throw new IllegalStateException("Cannot fit StandardScaler: iterator returned no examples");
        }
        // Population variance from the accumulated M2, then standard deviation.
        this.std.divi(this.runningTotal);
        this.std = Transforms.sqrt(this.std);
        addEpsilonAndWarnIfZero();
    }

    /**
     * Returns the per-column sum of squared deviations from the column mean (M2).
     * ND4J's {@code std(dim)} applies Bessel's correction, so M2 = s^2 * (n - 1).
     * A single row contributes zero (its sample std would be NaN).
     */
    private static INDArray sumOfSquaredDeviations(INDArray features, INDArray columnMean, int rows) {
        if (rows <= 1) {
            return Nd4j.zeros(columnMean.shape());
        }
        return Transforms.pow(features.std(0), 2).muli(rows - 1);
    }

    /** Logs when any column's std is (effectively) zero, then adds epsilon to every column. */
    private void addEpsilonAndWarnIfZero() {
        if (this.std.minNumber().doubleValue() < Nd4j.EPS_THRESHOLD) {
            logger.info("API_INFO: Std deviation found to be zero. Transform will round upto epsilon to avoid nans.");
        }
        this.std.addi(Nd4j.EPS_THRESHOLD);
    }

    /**
     * Loads previously saved statistics.
     *
     * @param mean file written by {@link #save(File, File)} for the mean
     * @param std  file written by {@link #save(File, File)} for the std
     * @throws IOException if either file cannot be read
     */
    public void load(File mean, File std) throws IOException {
        this.mean = Nd4j.readBinary(mean);
        this.std = Nd4j.readBinary(std);
    }

    /**
     * Saves the fitted statistics in ND4J binary format.
     *
     * @param mean destination for the mean
     * @param std  destination for the std
     * @throws IOException if either file cannot be written
     * @throws IllegalStateException if the scaler has not been fit or loaded
     */
    public void save(File mean, File std) throws IOException {
        requireFitted();
        Nd4j.saveBinary(this.mean, mean);
        Nd4j.saveBinary(this.std, std);
    }

    /**
     * Replaces the dataset's features with {@code (features - mean) / std}.
     *
     * @param dataSet the data to standardize (modified via {@code setFeatures})
     * @throws IllegalStateException if the scaler has not been fit or loaded
     */
    public void transform(DataSet dataSet) {
        requireFitted();
        dataSet.setFeatures(dataSet.getFeatures().subRowVector(this.mean));
        dataSet.setFeatures(dataSet.getFeatures().divRowVector(this.std));
    }

    /** Fails fast with a clear message instead of an NPE deep inside ND4J. */
    private void requireFitted() {
        if (this.mean == null || this.std == null) {
            throw new IllegalStateException("StandardScaler has not been fit or loaded");
        }
    }

    /** @return the fitted per-column mean, or {@code null} if not yet fit/loaded */
    public INDArray getMean() {
        return this.mean;
    }

    /** @return the fitted per-column std (plus epsilon), or {@code null} if not yet fit/loaded */
    public INDArray getStd() {
        return this.std;
    }
}

