/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.listeners;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import lombok.NonNull;
import org.nd4j.base.Preconditions;

/**
 * A named collection of scalar loss values: one {@code double} value per loss name.
 * <p>
 * The names and values are held as a parallel {@code List<String>} / {@code double[]} pair whose
 * lengths are checked to be equal at construction. The arithmetic helpers ({@code add}, {@code sub},
 * {@code div}, {@code sum}, {@code average}) return a new {@code Loss} backed by a freshly allocated
 * value array.
 */
public class Loss {
    /** Loss names, parallel to {@link #losses}. Stored by reference (not copied). */
    private final List<String> lossNames;
    /** Loss values, parallel to {@link #lossNames}. Stored by reference (not copied). */
    private final double[] losses;

    /**
     * Creates a loss from parallel names and values.
     * <p>
     * The list and array are stored by reference, not copied, so later changes made through the
     * caller's references are visible through this instance.
     *
     * @param lossNames names of the individual losses; must not be null
     * @param losses    value of each loss, in the same order as {@code lossNames}; must not be null
     * @throws NullPointerException if either argument is null
     */
    public Loss(@NonNull List<String> lossNames, @NonNull double[] losses) {
        if (lossNames == null) {
            throw new NullPointerException("lossNames is marked @NonNull but is null");
        }
        if (losses == null) {
            throw new NullPointerException("losses is marked @NonNull but is null");
        }
        // Names and values are parallel arrays; a mismatch would break getLoss/add/sub indexing.
        Preconditions.checkState(lossNames.size() == losses.length,
                "Expected equal number of loss names and loss values");
        this.lossNames = lossNames;
        this.losses = losses;
    }

    /**
     * Returns the number of individual losses.
     *
     * @return the number of loss names (equal to the number of loss values)
     */
    public int numLosses() {
        return lossNames.size();
    }

    /**
     * Returns the loss names.
     *
     * @return the internal name list (not a copy)
     */
    public List<String> lossNames() {
        return lossNames;
    }

    /**
     * Returns the loss values.
     *
     * @return the internal value array (not a copy); writes to it modify this instance
     */
    public double[] lossValues() {
        return losses;
    }

    /**
     * Looks up the value of a single named loss.
     *
     * @param lossName name of the loss to look up; must not be null
     * @return the value associated with the first occurrence of {@code lossName}
     * @throws NullPointerException if {@code lossName} is null
     */
    public double getLoss(@NonNull String lossName) {
        if (lossName == null) {
            throw new NullPointerException("lossName is marked @NonNull but is null");
        }
        int idx = lossNames.indexOf(lossName);
        Preconditions.checkState(idx >= 0, "No loss with name \"%s\" exists. All loss names: %s",
                lossName, lossNames);
        return losses[idx];
    }

    /**
     * Returns the sum of all individual loss values.
     *
     * @return the plain left-to-right sum of the values (0.0 when there are no losses)
     */
    public double totalLoss() {
        double sum = 0.0;
        for (double d : losses) {
            sum += d;
        }
        return sum;
    }

    /**
     * Returns an independent copy of this loss.
     * <p>
     * Both the name list and the value array are copied, so mutating the copy's
     * {@link #lossValues()} array or name list does not affect this instance, and vice versa.
     *
     * @return a new {@code Loss} equal to this one but sharing no mutable state with it
     */
    public Loss copy() {
        return new Loss(new ArrayList<>(lossNames), losses.clone());
    }

    /**
     * Element-wise sum of a list of losses that all share the same names (in the same order).
     *
     * @param losses losses to sum
     * @return a new loss whose values are the element-wise sums; an empty loss if the list is empty
     */
    public static Loss sum(List<Loss> losses) {
        if (losses.isEmpty()) {
            return new Loss(Collections.<String>emptyList(), new double[0]);
        }
        Loss first = losses.get(0);
        double[] lossValues = new double[first.losses.length];
        List<String> lossNames = new ArrayList<>(first.lossNames);
        for (int i = 0; i < losses.size(); i++) {
            Loss l = losses.get(i);
            Preconditions.checkState(l.losses.length == lossValues.length,
                    "Loss %s has %s losses, the others before it had %s.", i, l.losses.length, lossValues.length);
            Preconditions.checkState(l.lossNames.equals(lossNames),
                    "Loss %s has different loss names from the others before it.  Expected %s, got %s.",
                    i, lossNames, l.lossNames);
            for (int j = 0; j < lossValues.length; j++) {
                lossValues[j] += l.losses[j];
            }
        }
        return new Loss(lossNames, lossValues);
    }

    /**
     * Element-wise mean of a list of losses that all share the same names (in the same order).
     *
     * @param losses losses to average
     * @return a new loss whose values are the element-wise means; an empty loss if the list is empty
     */
    public static Loss average(List<Loss> losses) {
        Loss sum = Loss.sum(losses);
        // sum(...) always allocates a fresh value array, so dividing in place is safe.
        int count = losses.size();
        for (int i = 0; i < sum.losses.length; i++) {
            sum.losses[i] /= count;
        }
        return sum;
    }

    /**
     * Element-wise sum of two losses with identical names.
     *
     * @param a first loss
     * @param b second loss; must have the same names, in the same order, as {@code a}
     * @return a new loss with values {@code a[i] + b[i]}, sharing {@code a}'s name list
     */
    public static Loss add(Loss a, Loss b) {
        Preconditions.checkState(a.lossNames.equals(b.lossNames),
                "Loss names differ.  First loss has names %s, second has names %s.", a.lossNames, b.lossNames);
        double[] lossValues = new double[a.losses.length];
        for (int i = 0; i < lossValues.length; i++) {
            lossValues[i] = a.losses[i] + b.losses[i];
        }
        return new Loss(a.lossNames, lossValues);
    }

    /**
     * Element-wise difference of two losses with identical names.
     *
     * @param a loss to subtract from
     * @param b loss to subtract; must have the same names, in the same order, as {@code a}
     * @return a new loss with values {@code a[i] - b[i]}, sharing {@code a}'s name list
     */
    public static Loss sub(Loss a, Loss b) {
        Preconditions.checkState(a.lossNames.equals(b.lossNames),
                "Loss names differ.  First loss has names %s, second has names %s.", a.lossNames, b.lossNames);
        double[] lossValues = new double[a.losses.length];
        for (int i = 0; i < lossValues.length; i++) {
            lossValues[i] = a.losses[i] - b.losses[i];
        }
        return new Loss(a.lossNames, lossValues);
    }

    /**
     * Divides every loss value by a scalar.
     *
     * @param a loss to divide
     * @param b divisor; IEEE-754 semantics apply (division by zero yields infinity/NaN)
     * @return a new loss with values {@code a[i] / b}, sharing {@code a}'s name list
     */
    public static Loss div(Loss a, Number b) {
        double[] lossValues = new double[a.losses.length];
        for (int i = 0; i < lossValues.length; i++) {
            lossValues[i] = a.losses[i] / b.doubleValue();
        }
        return new Loss(a.lossNames, lossValues);
    }

    /** Instance form of {@link #add(Loss, Loss)}: returns {@code this + other}. */
    public Loss add(Loss other) {
        return Loss.add(this, other);
    }

    /** Instance form of {@link #sub(Loss, Loss)}: returns {@code this - other}. */
    public Loss sub(Loss other) {
        return Loss.sub(this, other);
    }

    /** Alias of {@link #add(Loss)}. */
    public Loss plus(Loss other) {
        return Loss.add(this, other);
    }

    /** Alias of {@link #sub(Loss)}. */
    public Loss minus(Loss other) {
        return Loss.sub(this, other);
    }

    /** Instance form of {@link #div(Loss, Number)}: returns {@code this / other}. */
    public Loss div(Number other) {
        return Loss.div(this, other);
    }

    /**
     * Returns the loss names.
     *
     * @return the internal name list (not a copy)
     */
    public List<String> getLossNames() {
        return lossNames;
    }

    /**
     * Returns the loss values.
     *
     * @return the internal value array (not a copy); writes to it modify this instance
     */
    public double[] getLosses() {
        return losses;
    }

    /** Two losses are equal when their name lists are equal and their value arrays are element-wise equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Loss)) {
            return false;
        }
        Loss other = (Loss) o;
        if (!other.canEqual(this)) {
            return false;
        }
        List<String> thisLossNames = getLossNames();
        List<String> otherLossNames = other.getLossNames();
        if (thisLossNames == null ? otherLossNames != null : !thisLossNames.equals(otherLossNames)) {
            return false;
        }
        return Arrays.equals(getLosses(), other.getLosses());
    }

    /**
     * Symmetry hook for {@link #equals(Object)} so subclasses can opt out of equality with plain
     * {@code Loss} instances.
     */
    protected boolean canEqual(Object other) {
        return other instanceof Loss;
    }

    /** Hash consistent with {@link #equals(Object)}: combines the name list and value-array hashes. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        List<String> names = getLossNames();
        result = result * PRIME + (names == null ? 43 : names.hashCode());
        result = result * PRIME + Arrays.hashCode(getLosses());
        return result;
    }

    @Override
    public String toString() {
        return "Loss(lossNames=" + getLossNames() + ", losses=" + Arrays.toString(getLosses()) + ")";
    }
}

