/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DistributionStats;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.MultiDataSetPreProcessor;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractNormalizerStandardize;

/**
 * Standardizing pre-processor for {@link MultiDataSet}s that may contain several feature (input)
 * and label (output) arrays.
 *
 * <p>One {@link DistributionStats} (mean / standard deviation) is kept per feature array and,
 * when enabled via {@link #fitLabel(boolean)}, per label array. The element-wise transform and
 * its inverse are delegated to {@link AbstractNormalizerStandardize}.
 *
 * <p>When label fitting is disabled, the label arrays of the processed data sets are never read,
 * so data sets without labels (e.g. at inference time) can be processed.
 */
public class MultiNormalizerStandardize
extends AbstractNormalizerStandardize
implements MultiDataSetPreProcessor {
    /** Statistics per feature array, in feature-array order; {@code null} until fit or loaded. */
    private List<DistributionStats> featureStats;
    /** Statistics per label array, in label-array order; only assigned when label fitting is on. */
    private List<DistributionStats> labelStats;
    /** Whether labels are fitted, normalized, reverted, saved and loaded alongside features. */
    private boolean fitLabels = false;

    /**
     * Enables or disables processing of labels in addition to features.
     *
     * @param fitLabels {@code true} to also fit / normalize / revert / save / load label arrays
     */
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /**
     * @return whether labels are processed in addition to features
     */
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Computes the statistics from a single data set, replacing any previously fitted feature
     * statistics (and label statistics, if label fitting is enabled).
     *
     * @param dataSet the data to fit on; must not be {@code null}
     */
    public void fit(@NonNull MultiDataSet dataSet) {
        if (dataSet == null) {
            throw new NullPointerException("dataSet");
        }
        List<DistributionStats.Builder> featureNormBuilders = new ArrayList<>();
        List<DistributionStats.Builder> labelNormBuilders = new ArrayList<>();
        this.fitPartial(dataSet, featureNormBuilders, labelNormBuilders);
        this.finishFit(featureNormBuilders, labelNormBuilders);
    }

    /**
     * Computes the statistics over every data set produced by the iterator. The iterator is
     * reset before iterating, so fitting always starts from its beginning.
     *
     * @param iterator the data to fit on; must not be {@code null}
     */
    public void fit(@NonNull MultiDataSetIterator iterator) {
        if (iterator == null) {
            throw new NullPointerException("iterator");
        }
        List<DistributionStats.Builder> featureNormBuilders = new ArrayList<>();
        List<DistributionStats.Builder> labelNormBuilders = new ArrayList<>();
        // NOTE(review): assumes the iterator supports reset() — confirm for custom iterators.
        iterator.reset();
        while (iterator.hasNext()) {
            MultiDataSet next = iterator.next();
            this.fitPartial(next, featureNormBuilders, labelNormBuilders);
        }
        this.finishFit(featureNormBuilders, labelNormBuilders);
    }

    /** Turns the accumulated builders into the stored statistics lists. */
    private void finishFit(List<DistributionStats.Builder> featureNormBuilders,
                    List<DistributionStats.Builder> labelNormBuilders) {
        this.featureStats = DistributionStats.Builder.buildList(featureNormBuilders);
        if (this.fitLabels) {
            this.labelStats = DistributionStats.Builder.buildList(labelNormBuilders);
        }
    }

    /**
     * Adds one data set's arrays (and their mask arrays) to the running statistics builders.
     * Builders are created lazily from the first data set that has arrays of the given kind.
     */
    private void fitPartial(MultiDataSet dataSet, List<DistributionStats.Builder> featureStatsBuilders,
                    List<DistributionStats.Builder> labelStatsBuilders) {
        int numInputs = dataSet.getFeatures().length;
        this.ensureStatsBuilders(featureStatsBuilders, numInputs);
        for (int i = 0; i < numInputs; ++i) {
            featureStatsBuilders.get(i).add(dataSet.getFeatures(i), dataSet.getFeaturesMaskArray(i));
        }
        if (!this.fitLabels) {
            // Labels are not needed: do not touch them, they may legitimately be absent.
            return;
        }
        int numOutputs = numArrays(dataSet.getLabels());
        this.ensureStatsBuilders(labelStatsBuilders, numOutputs);
        for (int i = 0; i < numOutputs; ++i) {
            labelStatsBuilders.get(i).add(dataSet.getLabels(i), dataSet.getLabelsMaskArray(i));
        }
    }

    /** Fills {@code builders} with {@code amount} fresh builders, but only if it is still empty. */
    private void ensureStatsBuilders(List<DistributionStats.Builder> builders, int amount) {
        if (builders.isEmpty()) {
            for (int i = 0; i < amount; ++i) {
                builders.add(new DistributionStats.Builder());
            }
        }
    }

    /** @return the number of arrays, treating a {@code null} array list as empty */
    private static int numArrays(INDArray[] arrays) {
        return arrays == null ? 0 : arrays.length;
    }

    /**
     * Standardizes the feature arrays (and, if label fitting is enabled, the label arrays) of the
     * data set using the fitted statistics. Array {@code i} is transformed with statistics
     * {@code i}. A data set without labels is accepted; its labels are simply skipped.
     *
     * @param toPreProcess the data set to transform; must not be {@code null}
     */
    @Override
    public void preProcess(@NonNull MultiDataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess");
        }
        this.assertIsFit();
        int numFeatures = toPreProcess.getFeatures().length;
        for (int i = 0; i < numFeatures; ++i) {
            this.preProcess(toPreProcess.getFeatures(i), this.featureStats.get(i));
        }
        if (this.fitLabels) {
            int numLabels = numArrays(toPreProcess.getLabels());
            for (int i = 0; i < numLabels; ++i) {
                this.preProcess(toPreProcess.getLabels(i), this.labelStats.get(i));
            }
        }
    }

    /**
     * Undoes {@link #preProcess(MultiDataSet)} on the feature arrays (and, if label fitting is
     * enabled, on the label arrays, when present).
     *
     * @param data the data set to revert; must not be {@code null}
     */
    public void revert(@NonNull MultiDataSet data) {
        if (data == null) {
            throw new NullPointerException("data");
        }
        this.assertIsFit();
        INDArray[] inputs = data.getFeatures();
        for (int i = 0; i < inputs.length; ++i) {
            this.revert(inputs[i], this.featureStats.get(i));
        }
        if (this.fitLabels) {
            INDArray[] outputs = data.getLabels();
            int numOutputs = numArrays(outputs);
            for (int i = 0; i < numOutputs; ++i) {
                this.revert(outputs[i], this.labelStats.get(i));
            }
        }
    }

    /** @return the fitted mean of feature array {@code input} */
    public INDArray getFeatureMean(int input) {
        this.assertIsFit();
        return this.featureStats.get(input).getMean();
    }

    /** @return the fitted mean of label array {@code output}; requires label fitting */
    public INDArray getLabelMean(int output) {
        this.assertIsFit();
        return this.labelStats.get(output).getMean();
    }

    /** @return the fitted standard deviation of feature array {@code input} */
    public INDArray getFeatureStd(int input) {
        this.assertIsFit();
        return this.featureStats.get(input).getStd();
    }

    /** @return the fitted standard deviation of label array {@code output}; requires label fitting */
    public INDArray getLabelStd(int output) {
        this.assertIsFit();
        return this.labelStats.get(output).getStd();
    }

    /** Fit (or loaded) means feature statistics are present; label statistics are not checked. */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * Loads statistics previously written by {@link #save(List, List)}. Each array uses two
     * consecutive files (index {@code 2i} and {@code 2i + 1}). Label files are only read when
     * label fitting is enabled.
     *
     * @param featureFiles files for the feature statistics; must not be {@code null}
     * @param labelFiles   files for the label statistics; must not be {@code null}
     * @throws IOException if reading fails
     * @throws IllegalArgumentException if a file list has an odd number of entries
     */
    public void load(@NonNull List<File> featureFiles, @NonNull List<File> labelFiles) throws IOException {
        if (featureFiles == null) {
            throw new NullPointerException("featureFiles");
        }
        if (labelFiles == null) {
            throw new NullPointerException("labelFiles");
        }
        this.featureStats = this.load(featureFiles);
        if (this.fitLabels) {
            this.labelStats = this.load(labelFiles);
        }
    }

    /** Loads one {@link DistributionStats} per consecutive pair of files. */
    private List<DistributionStats> load(List<File> files) throws IOException {
        // Mirror saveStats: an odd count would otherwise silently drop the last file.
        if (files.size() % 2 != 0) {
            throw new IllegalArgumentException(String.format(
                            "Need an even number of files (two per input / output), got %d", files.size()));
        }
        int count = files.size() / 2;
        List<DistributionStats> stats = new ArrayList<>(count);
        for (int i = 0; i < count; ++i) {
            stats.add(DistributionStats.load(files.get(i * 2), files.get(i * 2 + 1)));
        }
        return stats;
    }

    /**
     * Saves the fitted statistics, two files per array (index {@code 2i} and {@code 2i + 1}).
     * Label statistics are only written when label fitting is enabled.
     *
     * @param featureFiles exactly twice as many files as feature arrays; must not be {@code null}
     * @param labelFiles   exactly twice as many files as label arrays; must not be {@code null}
     * @throws IOException if writing fails
     * @throws IllegalArgumentException if a file list has the wrong size
     */
    public void save(@NonNull List<File> featureFiles, @NonNull List<File> labelFiles) throws IOException {
        if (featureFiles == null) {
            throw new NullPointerException("featureFiles");
        }
        if (labelFiles == null) {
            throw new NullPointerException("labelFiles");
        }
        // Consistent with the other stats-consuming methods: fail clearly instead of with an NPE.
        this.assertIsFit();
        this.saveStats(this.featureStats, featureFiles);
        if (this.fitLabels) {
            this.saveStats(this.labelStats, labelFiles);
        }
    }

    /** Writes statistics {@code i} to files {@code 2i} and {@code 2i + 1}. */
    private void saveStats(List<DistributionStats> stats, List<File> files) throws IOException {
        int requiredFiles = stats.size() * 2;
        if (requiredFiles != files.size()) {
            throw new IllegalArgumentException(String.format(
                            "Need twice as many files as inputs / outputs (%d files required), got %d",
                            requiredFiles, files.size()));
        }
        for (int i = 0; i < stats.size(); ++i) {
            stats.get(i).save(files.get(i * 2), files.get(i * 2 + 1));
        }
    }
}

