package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.ml.linalg.BLAS$;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.Vectors$;
import org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator;
import scala.Array$;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;

/* compiled from: DifferentiableLossAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00154\u0011\"\u0001\u0002\u0011\u0002\u0007\u0005aA\u0004!\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;pe*\u00111\u0001B\u0001\u000bC\u001e<'/Z4bi>\u0014(BA\u0003\u0007\u0003\u0015y\u0007\u000f^5n\u0015\t9\u0001\"\u0001\u0002nY*\u0011\u0011BC\u0001\u0006gB\f'o\u001b\u0006\u0003\u00171\ta!\u00199bG\",'\"A\u0007\u0002\u0007=\u0014x-F\u0002\u0010\u0017\n\u001b2\u0001\u0001\t\u0017!\t\tB#D\u0001\u0013\u0015\u0005\u0019\u0012!B:dC2\f\u0017BA\u000b\u0013\u0005\u0019\te.\u001f*fMB\u0011\u0011cF\u0005\u00031I\u0011AbU3sS\u0006d\u0017N_1cY\u0016DQA\u0007\u0001\u0005\u0002q\ta\u0001J5oSR$3\u0001\u0001\u000b\u0002;A\u0011\u0011CH\u0005\u0003?I\u0011A!\u00168ji\"9\u0011\u0005\u0001a\u0001\n#\u0011\u0013!C<fS\u001eDGoU;n+\u0005\u0019\u0003CA\t%\u0013\t)#C\u0001\u0004E_V\u0014G.\u001a\u0005\bO\u0001\u0001\r\u0011\"\u0005)\u000359X-[4iiN+Xn\u0018\u0013fcR\u0011Q$\u000b\u0005\bU\u0019\n\t\u00111\u0001$\u0003\rAH%\r\u0005\bY\u0001\u0001\r\u0011\"\u0005#\u0003\u001dawn]:Tk6DqA\f\u0001A\u0002\u0013Eq&A\u0006m_N\u001c8+^7`I\u0015\fHCA\u000f1\u0011\u001dQS&!AA\u0002\rBqA\r\u0001C\u0002\u001bE1'A\u0002eS6,\u0012\u0001\u000e\t\u0003#UJ!A\u000e\n\u0003\u0007%sG\u000f\u0003\u00059\u0001!\u0015\r\u0011\"\u0005:\u0003A9'/\u00193jK:$8+^7BeJ\f\u00170F\u0001;!\r\t2hI\u0005\u0003yI\u0011Q!\u0011:sCfDQA\u0010\u0001\u0007\u0002}\n1!\u00193e)\t\u0001\u0015\u000b\u0005\u0002B\u00052\u0001A!B\"\u0001\u0005\u0004!%aA!hOF\u0011Q\t\u0013\t\u0003#\u0019K!a\u0012\n\u0003\u000f9{G\u000f[5oOB!\u0011\n\u0001&A\u001b\u0005\u0011\u0001CA!L\t\u0015a\u0005A1\u0001N\u0005\u0015!\u0015\r^;n#\t)e\n\u0005\u0002\u0012\u001f&\u0011\u0001K\u0005\u0002\u0004\u0003:L\b\"\u0002*>\u0001\u0004Q\u0015\u0001C5ogR\fgnY3\t\u000bQ\u0003A\u0011A+\u0002\u000b5,'oZ3\u0015\u0005\u00013\u0006\"B,T\u0001\u0004\u0001\u0015!B8uQ\u0016\u0014\b\"B-\u0001\t\u0003Q\u0016\u0001C4sC\u0012LWM\u001c;\u0016\u0003m\u0003\"\u0001X0\u000e\u0003uS!A\u0018\u0004\u0002\r1Lg.\u00197h\u0013\t\u0001WL\u00
01\u0004WK\u000e$xN\u001d\u0005\u0006E\u0002!\tAI\u0001\u0007o\u0016Lw\r\u001b;\t\u000b\u0011\u0004A\u0011\u0001\u0012\u0002\t1|7o\u001d")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/DifferentiableLossAggregator.class */
/*
 * Contract for an aggregator that accumulates a weight total, a loss total and a
 * per-coordinate gradient total over a stream of Datum values, and that can be merged
 * with another aggregator of the same concrete type (Agg).
 *
 * NOTE(review): this is decompiler output of a Scala trait; the mutable state
 * (weightSum, lossSum) is exposed through accessor/mutator pairs that the implementing
 * class must back with fields.
 */
public interface DifferentiableLossAggregator<Datum, Agg extends DifferentiableLossAggregator<Datum, Agg>> extends Serializable {
    /** @return the weight total accumulated so far */
    double weightSum();

    /**
     * Overwrites the accumulated weight total.
     *
     * @param d the new weight total
     */
    void weightSum_$eq(double d);

    /** @return the loss total accumulated so far */
    double lossSum();

    /**
     * Overwrites the accumulated loss total.
     *
     * @param d the new loss total
     */
    void lossSum_$eq(double d);

    /** @return the number of gradient coordinates; used as the length of the gradient array */
    int dim();

    /**
     * Builds a double array whose length is {@link #dim()}.
     *
     * <p>As written here, every call allocates a new array. {@link #merge} adds into the
     * array returned by this method, so accumulation only persists when the implementing
     * class hands back the same array on each call.
     * NOTE(review): presumably implementations override this with a cached value — confirm.
     *
     * @return the array holding the per-coordinate gradient totals
     */
    default double[] gradientSumArray() {
        final int length = dim();
        final Object allocated = Array$.MODULE$.ofDim(length, ClassTag$.MODULE$.Double());
        return (double[]) allocated;
    }

    /**
     * Folds one datum into this aggregator.
     *
     * @param datum the value to accumulate
     * @return an aggregator of the concrete type
     */
    Agg add(Datum datum);

    /**
     * Adds the totals of another aggregator into this one.
     *
     * <p>The two aggregators must report the same {@link #dim()}. When the other
     * aggregator's weight total is exactly zero, nothing is added.
     *
     * @param agg the aggregator whose totals are added to this one
     * @return this aggregator
     */
    default Agg merge(Agg agg) {
        Predef$.MODULE$.require(dim() == agg.dim(), () ->
            "Dimensions mismatch when merging with another "
                + (this.getClass().getSimpleName() + ". Expecting " + this.dim() + " but got " + agg.dim() + "."));
        if (agg.weightSum() != 0.0d) {
            weightSum_$eq(weightSum() + agg.weightSum());
            lossSum_$eq(lossSum() + agg.lossSum());
            final double[] target = gradientSumArray();
            final double[] source = agg.gradientSumArray();
            // Element-wise accumulation of the other aggregator's gradient totals.
            int idx = 0;
            while (idx < dim()) {
                target[idx] += source[idx];
                idx++;
            }
        }
        return this;
    }

    /**
     * Computes the gradient totals divided by the weight total.
     *
     * <p>The scaling is applied to a copy of the gradient array, so the accumulated
     * totals are left untouched. The weight total must be strictly positive; otherwise
     * the {@code require} check rejects the call.
     *
     * @return a dense vector holding the scaled copy of the gradient totals
     */
    default Vector gradient() {
        Predef$.MODULE$.require(weightSum() > 0.0d, () ->
            "The effective number of instances should be "
                + ("greater than 0.0, but was " + this.weightSum() + "."));
        final double[] snapshot = gradientSumArray().clone();
        final Vector scaled = Vectors$.MODULE$.dense(snapshot);
        final double inverseWeight = 1.0d / weightSum();
        BLAS$.MODULE$.scal(inverseWeight, scaled);
        return scaled;
    }

    /** @return the accumulated weight total, exactly as reported by {@link #weightSum()} */
    default double weight() {
        final double total = weightSum();
        return total;
    }

    /**
     * Computes the loss total divided by the weight total.
     *
     * <p>The weight total must be strictly positive; otherwise the {@code require} check
     * rejects the call.
     *
     * @return {@code lossSum() / weightSum()}
     */
    default double loss() {
        Predef$.MODULE$.require(weightSum() > 0.0d, () ->
            "The effective number of instances should be "
                + ("greater than 0.0, but was " + this.weightSum() + "."));
        final double totalLoss = lossSum();
        final double totalWeight = weightSum();
        return totalLoss / totalWeight;
    }

    /**
     * Initializer for the state declared by this interface: resets both the weight
     * total and the loss total of the given aggregator to zero.
     *
     * @param differentiableLossAggregator the aggregator to initialize
     */
    static void $init$(DifferentiableLossAggregator differentiableLossAggregator) {
        final double zero = 0.0d;
        differentiableLossAggregator.weightSum_$eq(zero);
        differentiableLossAggregator.lossSum_$eq(zero);
    }
}
