/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.dataset.api.preprocessor.stats;

import java.util.Arrays;
import lombok.NonNull;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.DataSetUtil;
import org.nd4j.linalg.dataset.api.preprocessor.stats.NormalizerStats;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class MinMaxStats
implements NormalizerStats {
    private static final Logger log = LoggerFactory.getLogger(MinMaxStats.class);
    private final INDArray lower;
    private final INDArray upper;
    private INDArray range;

    public MinMaxStats(@NonNull INDArray lower, @NonNull INDArray upper) {
        if (lower == null) {
            throw new NullPointerException("lower is marked @NonNull but is null");
        }
        if (upper == null) {
            throw new NullPointerException("upper is marked @NonNull but is null");
        }
        INDArray diff = upper.sub(lower);
        INDArray addedPadding = Transforms.max(diff, Nd4j.EPS_THRESHOLD).subi(diff);
        if (addedPadding.sumNumber().doubleValue() > 0.0) {
            log.info("NormalizerMinMaxScaler: max val minus min val found to be zero. Transform will round up to epsilon to avoid nans.");
            upper.addi(addedPadding);
        }
        this.lower = lower;
        this.upper = upper;
    }

    public INDArray getRange() {
        if (this.range == null) {
            try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                this.range = this.upper.sub(this.lower);
            }
        }
        return this.range;
    }

    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof MinMaxStats)) {
            return false;
        }
        MinMaxStats other = (MinMaxStats)o;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray this$lower = this.getLower();
        INDArray other$lower = other.getLower();
        if (this$lower == null ? other$lower != null : !this$lower.equals(other$lower)) {
            return false;
        }
        INDArray this$upper = this.getUpper();
        INDArray other$upper = other.getUpper();
        if (this$upper == null ? other$upper != null : !this$upper.equals(other$upper)) {
            return false;
        }
        INDArray this$range = this.getRange();
        INDArray other$range = other.getRange();
        return !(this$range == null ? other$range != null : !this$range.equals(other$range));
    }

    protected boolean canEqual(Object other) {
        return other instanceof MinMaxStats;
    }

    public int hashCode() {
        int PRIME = 59;
        int result = 1;
        INDArray $lower = this.getLower();
        result = result * 59 + ($lower == null ? 43 : $lower.hashCode());
        INDArray $upper = this.getUpper();
        result = result * 59 + ($upper == null ? 43 : $upper.hashCode());
        INDArray $range = this.getRange();
        result = result * 59 + ($range == null ? 43 : $range.hashCode());
        return result;
    }

    public INDArray getLower() {
        return this.lower;
    }

    public INDArray getUpper() {
        return this.upper;
    }

    public static class Builder
    implements NormalizerStats.Builder<MinMaxStats> {
        private INDArray runningLower;
        private INDArray runningUpper;

        public Builder addFeatures(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet is marked @NonNull but is null");
            }
            return this.add(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
        }

        public Builder addLabels(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet is marked @NonNull but is null");
            }
            return this.add(dataSet.getLabels(), dataSet.getLabelsMaskArray());
        }

        public Builder add(@NonNull INDArray data, INDArray mask) {
            if (data == null) {
                throw new NullPointerException("data is marked @NonNull but is null");
            }
            if ((data = DataSetUtil.tailor2d(data, mask)) == null) {
                return this;
            }
            INDArray batchMin = data.min(0).reshape(1L, data.size(1));
            INDArray batchMax = data.max(0).reshape(1L, data.size(1));
            if (!Arrays.equals(batchMin.shape(), batchMax.shape())) {
                throw new IllegalStateException("Data min and max must be same shape. Likely a bug in the operation changing the input?");
            }
            if (this.runningLower == null) {
                this.runningLower = batchMin.dup();
                this.runningUpper = batchMax.dup();
            } else {
                Transforms.min(this.runningLower, batchMin, false);
                Transforms.max(this.runningUpper, batchMax, false);
            }
            return this;
        }

        @Override
        public MinMaxStats build() {
            if (this.runningLower == null) {
                throw new RuntimeException("No data was added, statistics cannot be determined");
            }
            try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces();){
                MinMaxStats minMaxStats = new MinMaxStats(this.runningLower.dup(), this.runningUpper.dup());
                return minMaxStats;
            }
        }
    }
}

