/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DistributionStats;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractNormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;

/**
 * {@link DataNormalization} that standardizes features (and, optionally, labels) using
 * statistics held in {@link DistributionStats} objects built from a mean and a standard
 * deviation array.
 *
 * <p>Statistics are either supplied through a constructor, computed by one of the {@code fit}
 * methods, or read with {@link #load(File...)}. The element-wise arithmetic is delegated to the
 * {@code preProcess(INDArray, DistributionStats)} / {@code revert(INDArray, DistributionStats)}
 * helpers inherited from {@link AbstractNormalizerStandardize}. NOTE(review): presumably
 * {@code (x - mean) / std} and its inverse; confirm against that class.
 *
 * <p>Label normalization is opt-in and only happens while {@link #isFitLabel()} is {@code true}.
 * Instances hold mutable state and perform no synchronization.
 */
public class NormalizerStandardize
extends AbstractNormalizerStandardize
implements DataNormalization {
    /** Feature statistics; {@code null} until fit, loaded, or supplied via a constructor. */
    private DistributionStats featureStats;
    /** Label statistics; {@code null} unless labels were fit, loaded, or supplied via a constructor. */
    private DistributionStats labelStats;
    /** Whether {@code fit}/{@code load}/{@code save} and the transforms also handle labels. */
    private boolean fitLabels = false;

    /** Creates an unfit normalizer; call {@code fit(...)} or {@link #load(File...)} before use. */
    public NormalizerStandardize() {
    }

    /**
     * Creates a normalizer from precomputed feature statistics. Labels are not normalized.
     *
     * @param featureMean per-feature mean
     * @param featureStd  per-feature standard deviation
     */
    public NormalizerStandardize(INDArray featureMean, INDArray featureStd) {
        this.featureStats = new DistributionStats(featureMean, featureStd);
        this.fitLabels = false;
    }

    /**
     * Creates a normalizer from precomputed feature and label statistics and enables label
     * normalization.
     *
     * @param featureMean per-feature mean
     * @param featureStd  per-feature standard deviation
     * @param labelMean   per-label mean
     * @param labelStd    per-label standard deviation
     */
    public NormalizerStandardize(INDArray featureMean, INDArray featureStd, INDArray labelMean, INDArray labelStd) {
        this.featureStats = new DistributionStats(featureMean, featureStd);
        this.labelStats = new DistributionStats(labelMean, labelStd);
        this.fitLabels = true;
    }

    /**
     * Enables or disables label handling for subsequent fit/transform/revert/save/load calls.
     *
     * @param fitLabels {@code true} to also compute and apply label statistics
     */
    @Override
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /** @return whether label statistics are computed and applied */
    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Computes feature statistics (and label statistics when {@link #isFitLabel()}) from one
     * data set, replacing any previously held statistics of the same kind.
     *
     * @param dataSet data to compute statistics from
     * @throws NullPointerException if {@code dataSet} is {@code null}
     */
    @Override
    public void fit(@NonNull DataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        this.featureStats = new DistributionStats.Builder().addFeatures(dataSet).build();
        if (this.fitLabels) {
            this.labelStats = new DistributionStats.Builder().addLabels(dataSet).build();
        }
    }

    /**
     * Computes feature statistics (and label statistics when {@link #isFitLabel()}) over every
     * data set produced by {@code iterator}. The iterator is reset before and after the pass so
     * that callers can consume it again.
     *
     * @param iterator source of data sets
     * @throws NullPointerException if {@code iterator} is {@code null}
     */
    @Override
    public void fit(@NonNull DataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        DistributionStats.Builder featureNormBuilder = new DistributionStats.Builder();
        // Only allocate a label accumulator when labels are actually being fit.
        DistributionStats.Builder labelNormBuilder = this.fitLabels ? new DistributionStats.Builder() : null;
        iterator.reset();
        while (iterator.hasNext()) {
            DataSet next = (DataSet) iterator.next();
            featureNormBuilder.addFeatures(next);
            if (labelNormBuilder != null) {
                labelNormBuilder.addLabels(next);
            }
        }
        this.featureStats = featureNormBuilder.build();
        if (labelNormBuilder != null) {
            this.labelStats = labelNormBuilder.build();
        }
        iterator.reset();
    }

    /**
     * Normalizes the features of {@code toPreProcess} (and its labels when
     * {@link #isFitLabel()}).
     *
     * @param toPreProcess data set to normalize
     * @throws NullPointerException  if {@code toPreProcess} is {@code null}
     * @throws IllegalStateException if labels are enabled but no label statistics are available
     */
    @Override
    public void preProcess(@NonNull DataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess");
        }
        this.assertIsFit();
        // Resolve label stats before touching the features so a missing-label-stats failure
        // cannot leave the data set half-transformed.
        DistributionStats labels = this.fitLabels ? this.requireLabelStats() : null;
        this.preProcess(toPreProcess.getFeatures(), this.featureStats);
        if (labels != null) {
            this.preProcess(toPreProcess.getLabels(), labels);
        }
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /**
     * Normalizes a features array with the feature statistics.
     *
     * @param theFeatures features to normalize
     */
    @Override
    public void transform(INDArray theFeatures) {
        // Match preProcess(DataSet): fail clearly instead of passing null stats downstream.
        this.assertIsFit();
        this.preProcess(theFeatures, this.featureStats);
    }

    /**
     * Normalizes a labels array with the label statistics. This is a no-op when
     * {@link #isFitLabel()} is {@code false}, which mirrors {@link #revertLabels(INDArray)}.
     *
     * @param label labels to normalize
     * @throws IllegalStateException if labels are enabled but no label statistics are available
     */
    @Override
    public void transformLabel(INDArray label) {
        if (!this.fitLabels) {
            return;
        }
        this.preProcess(label, this.requireLabelStats());
    }

    /**
     * Undoes the normalization of the features of {@code data} (and its labels when
     * {@link #isFitLabel()}).
     *
     * @param data data set to revert
     * @throws IllegalStateException if labels are enabled but no label statistics are available
     */
    @Override
    public void revert(DataSet data) {
        this.assertIsFit();
        // Resolve label stats first so a failure does not leave the data set half-reverted.
        DistributionStats labels = this.fitLabels ? this.requireLabelStats() : null;
        this.revert(data.getFeatures(), this.featureStats);
        if (labels != null) {
            this.revert(data.getLabels(), labels);
        }
    }

    /**
     * Undoes feature normalization on {@code features}.
     *
     * @param features normalized features
     */
    @Override
    public void revertFeatures(INDArray features) {
        this.assertIsFit();
        this.revert(features, this.featureStats);
    }

    /**
     * Undoes label normalization on {@code labels}. This is a no-op when {@link #isFitLabel()}
     * is {@code false}.
     *
     * @param labels normalized labels
     * @throws IllegalStateException if labels are enabled but no label statistics are available
     */
    @Override
    public void revertLabels(INDArray labels) {
        if (!this.fitLabels) {
            return;
        }
        this.revert(labels, this.requireLabelStats());
    }

    /** @return the feature mean */
    public INDArray getMean() {
        this.assertIsFit();
        return this.featureStats.getMean();
    }

    /**
     * @return the label mean
     * @throws IllegalStateException if no label statistics are available
     */
    public INDArray getLabelMean() {
        this.assertIsFit();
        return this.requireLabelStats().getMean();
    }

    /** @return the feature standard deviation */
    public INDArray getStd() {
        this.assertIsFit();
        return this.featureStats.getStd();
    }

    /**
     * @return the label standard deviation
     * @throws IllegalStateException if no label statistics are available
     */
    public INDArray getLabelStd() {
        this.assertIsFit();
        return this.requireLabelStats().getStd();
    }

    /** Fitness is determined by the presence of feature statistics only. */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * Loads statistics from files. {@code files[0]} and {@code files[1]} are passed to
     * {@code DistributionStats.load} for the features. When {@link #isFitLabel()} is set,
     * {@code files[2]} and {@code files[3]} are used the same way for the labels.
     *
     * @param files source files (at least 2, or 4 when labels are enabled)
     * @throws IOException              if reading fails
     * @throws IllegalArgumentException if too few files are given
     */
    @Override
    public void load(File ... files) throws IOException {
        this.requireFileCount(files);
        this.featureStats = DistributionStats.load(files[0], files[1]);
        if (this.fitLabels) {
            this.labelStats = DistributionStats.load(files[2], files[3]);
        }
    }

    /**
     * Saves statistics to files using the same layout as {@link #load(File...)}.
     *
     * @param files destination files (at least 2, or 4 when labels are enabled)
     * @throws IOException              if writing fails
     * @throws IllegalArgumentException if too few files are given
     * @throws IllegalStateException    if labels are enabled but no label statistics are available
     */
    @Override
    public void save(File ... files) throws IOException {
        this.assertIsFit();
        this.requireFileCount(files);
        // Validate label stats before writing anything, to avoid a partial save.
        DistributionStats labels = this.fitLabels ? this.requireLabelStats() : null;
        this.featureStats.save(files[0], files[1]);
        if (labels != null) {
            labels.save(files[2], files[3]);
        }
    }

    /** Returns the label statistics, or fails with a descriptive error if none are held. */
    private DistributionStats requireLabelStats() {
        if (this.labelStats == null) {
            throw new IllegalStateException("Label statistics are not available: enable fitLabel(true) "
                    + "before fit(...)/load(...), or supply them through the 4-argument constructor");
        }
        return this.labelStats;
    }

    /** Checks that enough files were passed for the statistics that save/load will touch. */
    private void requireFileCount(File[] files) {
        if (files == null) {
            throw new NullPointerException("files");
        }
        int expected = this.fitLabels ? 4 : 2;
        if (files.length < expected) {
            throw new IllegalArgumentException("Expected at least " + expected + " files (2 for feature statistics"
                    + (this.fitLabels ? " plus 2 for label statistics" : "") + "), got " + files.length);
        }
    }
}

