/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor;

import java.util.Objects;

import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.AbstractNormalizer;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStrategy;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;

/**
 * Base class for {@link DataNormalization} implementations that operate on {@link DataSet}s.
 *
 * <p>Statistics are gathered with the {@link NormalizerStats.Builder} returned by
 * {@link #newBuilder()}. The actual normalization and its inverse are delegated to the
 * {@link NormalizerStrategy} supplied at construction time.
 *
 * <p>Feature statistics are always fitted. Label statistics are only fitted, applied and
 * reverted when {@link #fitLabel(boolean)} has been set to {@code true}; otherwise the label
 * operations are no-ops.
 *
 * @param <S> the type of statistics consumed by the strategy
 */
public abstract class AbstractDataSetNormalizer<S extends NormalizerStats>
extends AbstractNormalizer
implements DataNormalization {
    /** Strategy that applies and reverts the normalization using the fitted statistics. */
    protected NormalizerStrategy<S> strategy;
    /** Feature statistics; {@code null} until fitted (see {@link #isFit()}). */
    private S featureStats;
    /** Label statistics; set by {@code fit} when label fitting is enabled, or via {@link #setLabelStats}. */
    private S labelStats;
    /** Whether labels are fitted, transformed and reverted in addition to features. */
    private boolean fitLabels = false;

    /**
     * @param strategy the strategy used to apply and revert the normalization
     */
    protected AbstractDataSetNormalizer(NormalizerStrategy<S> strategy) {
        this.strategy = strategy;
    }

    /**
     * Enables or disables fitting/transforming of labels.
     *
     * @param fitLabels {@code true} to also normalize labels
     */
    @Override
    public void fitLabel(boolean fitLabels) {
        this.fitLabels = fitLabels;
    }

    /**
     * @return whether labels are fitted and transformed in addition to features
     */
    @Override
    public boolean isFitLabel() {
        return this.fitLabels;
    }

    /**
     * Fits feature statistics (and label statistics, if enabled) on a single data set.
     * When label fitting is disabled, previously stored label statistics are left untouched.
     *
     * @param dataSet the data to compute statistics from
     */
    @Override
    @SuppressWarnings("unchecked") // newBuilder() returns a raw builder; its product is an S by contract
    public void fit(DataSet dataSet) {
        this.featureStats = (S) this.newBuilder().addFeatures(dataSet).build();
        if (this.isFitLabel()) {
            this.labelStats = (S) this.newBuilder().addLabels(dataSet).build();
        }
    }

    /**
     * @return the fitted feature statistics
     * @throws RuntimeException whatever {@code assertIsFit()} throws when this normalizer is not fit
     */
    protected S getFeatureStats() {
        this.assertIsFit();
        return this.featureStats;
    }

    /**
     * @return the label statistics (may be {@code null} if labels were never fitted)
     * @throws RuntimeException whatever {@code assertIsFit()} throws when this normalizer is not fit
     */
    protected S getLabelStats() {
        this.assertIsFit();
        return this.labelStats;
    }

    /**
     * @return {@code true} once feature statistics are available
     */
    @Override
    protected boolean isFit() {
        return this.featureStats != null;
    }

    /**
     * Fits feature statistics (and label statistics, if enabled) over every data set produced
     * by the iterator. The iterator is reset before and after the pass.
     * NOTE(review): {@code reset()} is called unconditionally; confirm callers only pass
     * iterators that support resetting.
     *
     * @param iterator source of the data sets to fit on
     */
    @Override
    @SuppressWarnings("unchecked") // newBuilder() returns a raw builder; its product is an S by contract
    public void fit(DataSetIterator iterator) {
        NormalizerStats.Builder featureNormBuilder = this.newBuilder();
        NormalizerStats.Builder labelNormBuilder = this.newBuilder();
        iterator.reset();
        while (iterator.hasNext()) {
            DataSet next = (DataSet) iterator.next();
            featureNormBuilder.addFeatures(next);
            if (this.fitLabels) {
                labelNormBuilder.addLabels(next);
            }
        }
        this.featureStats = (S) featureNormBuilder.build();
        if (this.fitLabels) {
            this.labelStats = (S) labelNormBuilder.build();
        }
        // Rewind again so the iterator is back at the start after fitting.
        iterator.reset();
    }

    /**
     * @return a fresh builder used to accumulate statistics during {@code fit}
     */
    protected abstract NormalizerStats.Builder newBuilder();

    /**
     * Normalizes the features (and labels, if enabled) of the data set in place, honouring
     * the data set's mask arrays.
     *
     * @param toPreProcess the data set to normalize; must not be {@code null}
     * @throws NullPointerException if {@code toPreProcess} is {@code null}
     */
    @Override
    public void preProcess(@NonNull DataSet toPreProcess) {
        if (toPreProcess == null) {
            throw new NullPointerException("toPreProcess");
        }
        this.transform(toPreProcess.getFeatures(), toPreProcess.getFeaturesMaskArray());
        this.transformLabel(toPreProcess.getLabels(), toPreProcess.getLabelsMaskArray());
    }

    /** Same as {@link #preProcess(DataSet)}. */
    @Override
    public void transform(DataSet toPreProcess) {
        this.preProcess(toPreProcess);
    }

    /** Normalizes the given features with no mask. */
    @Override
    public void transform(INDArray features) {
        this.transform(features, null);
    }

    /**
     * Normalizes the given features using the fitted feature statistics.
     *
     * @param features     the features to normalize
     * @param featuresMask optional mask passed through to the strategy (may be {@code null})
     */
    @Override
    public void transform(INDArray features, INDArray featuresMask) {
        this.strategy.preProcess(features, featuresMask, this.getFeatureStats());
    }

    /** Normalizes the given labels with no mask (no-op unless label fitting is enabled). */
    @Override
    public void transformLabel(INDArray label) {
        this.transformLabel(label, null);
    }

    /**
     * Normalizes the given labels using the label statistics; does nothing when label
     * fitting is disabled.
     *
     * @param label      the labels to normalize
     * @param labelsMask optional mask passed through to the strategy (may be {@code null})
     */
    @Override
    public void transformLabel(INDArray label, INDArray labelsMask) {
        if (this.isFitLabel()) {
            this.strategy.preProcess(label, labelsMask, this.getLabelStats());
        }
    }

    /** Reverts normalization of the given features with no mask. */
    @Override
    public void revertFeatures(INDArray features) {
        this.revertFeatures(features, null);
    }

    /**
     * Reverts normalization of the given features using the fitted feature statistics.
     *
     * @param features     the normalized features
     * @param featuresMask optional mask passed through to the strategy (may be {@code null})
     */
    @Override
    public void revertFeatures(INDArray features, INDArray featuresMask) {
        this.strategy.revert(features, featuresMask, this.getFeatureStats());
    }

    /** Reverts normalization of the given labels with no mask (no-op unless label fitting is enabled). */
    @Override
    public void revertLabels(INDArray labels) {
        this.revertLabels(labels, null);
    }

    /**
     * Reverts normalization of the given labels; does nothing when label fitting is disabled.
     *
     * @param labels     the normalized labels
     * @param labelsMask optional mask passed through to the strategy (may be {@code null})
     */
    @Override
    public void revertLabels(INDArray labels, INDArray labelsMask) {
        if (this.isFitLabel()) {
            this.strategy.revert(labels, labelsMask, this.getLabelStats());
        }
    }

    /**
     * Reverts normalization of the features (and labels, if enabled) of the data set,
     * honouring its mask arrays.
     *
     * @param data the normalized data set
     */
    @Override
    public void revert(DataSet data) {
        this.revertFeatures(data.getFeatures(), data.getFeaturesMaskArray());
        this.revertLabels(data.getLabels(), data.getLabelsMaskArray());
    }

    /**
     * Compares strategy, fitted statistics and the label-fitting flag.
     *
     * <p>The fields are read directly rather than through {@link #getFeatureStats()} /
     * {@link #getLabelStats()}: those call {@code assertIsFit()}, which would make
     * {@code equals} fail on normalizers that have not been fitted yet.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AbstractDataSetNormalizer)) {
            return false;
        }
        AbstractDataSetNormalizer<?> other = (AbstractDataSetNormalizer<?>) o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.fitLabels == other.fitLabels
                && Objects.equals(this.strategy, other.strategy)
                && Objects.equals(this.featureStats, other.featureStats)
                && Objects.equals(this.labelStats, other.labelStats);
    }

    /**
     * Symmetry hook for {@link #equals(Object)}; subclasses adding state should override.
     */
    protected boolean canEqual(Object other) {
        return other instanceof AbstractDataSetNormalizer;
    }

    /**
     * Hash consistent with {@link #equals(Object)}. Like {@code equals}, it reads the fields
     * directly so it also works on normalizers that have not been fitted.
     */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + (this.strategy == null ? 43 : this.strategy.hashCode());
        result = result * prime + (this.featureStats == null ? 43 : this.featureStats.hashCode());
        result = result * prime + (this.labelStats == null ? 43 : this.labelStats.hashCode());
        result = result * prime + (this.fitLabels ? 79 : 97);
        return result;
    }

    /** Replaces the feature statistics (e.g. when restoring a saved normalizer). */
    protected void setFeatureStats(S featureStats) {
        this.featureStats = featureStats;
    }

    /** Replaces the label statistics (e.g. when restoring a saved normalizer). */
    protected void setLabelStats(S labelStats) {
        this.labelStats = labelStats;
    }
}

