package org.apache.spark.ml.optim.aggregator;

import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.ml.feature.Instance;
import org.apache.spark.ml.linalg.Vector;
import scala.MatchError;
import scala.Predef$;
import scala.collection.mutable.ArrayOps;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.runtime.DoubleRef;
import scala.runtime.java8.JFunction2$mcVID$sp;

/* compiled from: HuberAggregator.scala */
@ScalaSignature(bytes = "\u0006\u0001\u00054Q!\u0001\u0002\u0001\r9\u0011q\u0002S;cKJ\fum\u001a:fO\u0006$xN\u001d\u0006\u0003\u0007\u0011\t!\"Y4he\u0016<\u0017\r^8s\u0015\t)a!A\u0003paRLWN\u0003\u0002\b\u0011\u0005\u0011Q\u000e\u001c\u0006\u0003\u0013)\tQa\u001d9be.T!a\u0003\u0007\u0002\r\u0005\u0004\u0018m\u00195f\u0015\u0005i\u0011aA8sON\u0019\u0001aD\u000b\u0011\u0005A\u0019R\"A\t\u000b\u0003I\tQa]2bY\u0006L!\u0001F\t\u0003\r\u0005s\u0017PU3g!\u00111r#G\u0010\u000e\u0003\tI!\u0001\u0007\u0002\u00039\u0011KgMZ3sK:$\u0018.\u00192mK2{7o]!hOJ,w-\u0019;peB\u0011!$H\u0007\u00027)\u0011ADB\u0001\bM\u0016\fG/\u001e:f\u0013\tq2D\u0001\u0005J]N$\u0018M\\2f!\t1\u0002\u0001\u0003\u0005\"\u0001\t\u0005\t\u0015!\u0003$\u000311\u0017\u000e^%oi\u0016\u00148-\u001a9u\u0007\u0001\u0001\"\u0001\u0005\u0013\n\u0005\u0015\n\"a\u0002\"p_2,\u0017M\u001c\u0005\tO\u0001\u0011\t\u0011)A\u0005Q\u00059Q\r]:jY>t\u0007C\u0001\t*\u0013\tQ\u0013C\u0001\u0004E_V\u0014G.\u001a\u0005\tY\u0001\u0011\t\u0011)A\u0005[\u0005i!m\u0019$fCR,(/Z:Ti\u0012\u00042AL\u00194\u001b\u0005y#B\u0001\u0019\t\u0003%\u0011'o\\1eG\u0006\u001cH/\u0003\u00023_\tI!I]8bI\u000e\f7\u000f\u001e\t\u0004!QB\u0013BA\u001b\u0012\u0005\u0015\t%O]1z\u0011!9\u0004A!A!\u0002\u0013A\u0014\u0001\u00042d!\u0006\u0014\u0018-\\3uKJ\u001c\bc\u0001\u00182sA\u0011!(P\u0007\u0002w)\u0011AHB\u0001\u0007Y&t\u0017\r\\4\n\u0005yZ$A\u0002,fGR|'\u000fC\u0003A\u0001\u0011\u0005\u0011)\u0001\u0004=S:LGO\u0010\u000b\u0005\u0005\u0012+e\t\u0006\u0002 
\u0007\")qg\u0010a\u0001q!)\u0011e\u0010a\u0001G!)qe\u0010a\u0001Q!)Af\u0010a\u0001[!9\u0001\n\u0001b\u0001\n#J\u0015a\u00013j[V\t!\n\u0005\u0002\u0011\u0017&\u0011A*\u0005\u0002\u0004\u0013:$\bB\u0002(\u0001A\u0003%!*\u0001\u0003eS6\u0004\u0003b\u0002)\u0001\u0005\u0004%I!S\u0001\f]Vlg)Z1ukJ,7\u000f\u0003\u0004S\u0001\u0001\u0006IAS\u0001\r]Vlg)Z1ukJ,7\u000f\t\u0005\b)\u0002\u0011\r\u0011\"\u0003V\u0003\u0015\u0019\u0018nZ7b+\u0005A\u0003BB,\u0001A\u0003%\u0001&\u0001\u0004tS\u001el\u0017\r\t\u0005\b3\u0002\u0011\r\u0011\"\u0003V\u0003%Ig\u000e^3sG\u0016\u0004H\u000f\u0003\u0004\\\u0001\u0001\u0006I\u0001K\u0001\u000bS:$XM]2faR\u0004\u0003\"B/\u0001\t\u0003q\u0016aA1eIR\u0011qd\u0018\u0005\u0006Ar\u0003\r!G\u0001\tS:\u001cH/\u00198dK\u0002")
/* loaded from: input_file:org/apache/spark/ml/optim/aggregator/HuberAggregator.class */
public class HuberAggregator implements DifferentiableLossAggregator<Instance, HuberAggregator> {
    private final boolean fitIntercept;
    private final double epsilon;
    private final Broadcast<double[]> bcFeaturesStd;
    private final Broadcast<Vector> bcParameters;
    private final int dim;
    private final int numFeatures;
    private final double sigma;
    private final double intercept;
    private double weightSum;
    private double lossSum;
    private double[] gradientSumArray;
    private volatile boolean bitmap$0;

    /* JADX WARN: Type inference failed for: r0v1, types: [org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator, org.apache.spark.ml.optim.aggregator.HuberAggregator] */
    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public HuberAggregator merge(HuberAggregator huberAggregator) {
        ?? merge;
        merge = merge(huberAggregator);
        return merge;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public Vector gradient() {
        Vector gradient;
        gradient = gradient();
        return gradient;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weight() {
        double weight;
        weight = weight();
        return weight;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double loss() {
        double loss;
        loss = loss();
        return loss;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double weightSum() {
        return this.weightSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void weightSum_$eq(double d) {
        this.weightSum = d;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double lossSum() {
        return this.lossSum;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public void lossSum_$eq(double d) {
        this.lossSum = d;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /* JADX WARN: Type inference failed for: r0v0 */
    /* JADX WARN: Type inference failed for: r0v1, types: [java.lang.Throwable] */
    /* JADX WARN: Type inference failed for: r0v8, types: [org.apache.spark.ml.optim.aggregator.HuberAggregator] */
    private double[] gradientSumArray$lzycompute() {
        double[] gradientSumArray;
        ?? r0 = this;
        synchronized (r0) {
            if (!this.bitmap$0) {
                gradientSumArray = gradientSumArray();
                this.gradientSumArray = gradientSumArray;
                r0 = this;
                r0.bitmap$0 = true;
            }
        }
        return this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public double[] gradientSumArray() {
        return !this.bitmap$0 ? gradientSumArray$lzycompute() : this.gradientSumArray;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public int dim() {
        return this.dim;
    }

    private int numFeatures() {
        return this.numFeatures;
    }

    private double sigma() {
        return this.sigma;
    }

    private double intercept() {
        return this.intercept;
    }

    @Override // org.apache.spark.ml.optim.aggregator.DifferentiableLossAggregator
    public HuberAggregator add(Instance instance) {
        if (instance == null) {
            throw new MatchError(instance);
        }
        double label = instance.label();
        double weight = instance.weight();
        Vector features = instance.features();
        Predef$.MODULE$.require(numFeatures() == features.size(), () -> {
            return new StringBuilder(43).append("Dimensions mismatch when adding new sample.").append(new StringBuilder(21).append(" Expecting ").append(this.numFeatures()).append(" but got ").append(features.size()).append(".").toString()).toString();
        });
        Predef$.MODULE$.require(weight >= 0.0d, () -> {
            return new StringBuilder(34).append("instance weight, ").append(weight).append(" has to be >= 0.0").toString();
        });
        if (weight == 0.0d) {
            return this;
        }
        double[] dArr = (double[]) this.bcFeaturesStd.value();
        double[] dArr2 = (double[]) new ArrayOps.ofDouble(Predef$.MODULE$.doubleArrayOps(((Vector) this.bcParameters.value()).toArray())).slice(0, numFeatures());
        double[] gradientSumArray = gradientSumArray();
        DoubleRef create = DoubleRef.create(0.0d);
        features.foreachActive((i, d) -> {
            if (dArr[i] == 0.0d || d == 0.0d) {
                return;
            }
            create.elem += dArr2[i] * (d / dArr[i]);
        });
        if (this.fitIntercept) {
            create.elem += intercept();
        }
        double d2 = label - create.elem;
        if (package$.MODULE$.abs(d2) <= sigma() * this.epsilon) {
            lossSum_$eq(lossSum() + (0.5d * weight * (sigma() + (package$.MODULE$.pow(d2, 2.0d) / sigma()))));
            double sigma = d2 / sigma();
            features.foreachActive((i2, d3) -> {
                if (dArr[i2] == 0.0d || d3 == 0.0d) {
                    return;
                }
                gradientSumArray[i2] = gradientSumArray[i2] + ((-1.0d) * weight * sigma * (d3 / dArr[i2]));
            });
            if (this.fitIntercept) {
                int dim = dim() - 2;
                gradientSumArray[dim] = gradientSumArray[dim] + ((-1.0d) * weight * sigma);
            }
            int dim2 = dim() - 1;
            gradientSumArray[dim2] = gradientSumArray[dim2] + (0.5d * weight * (1.0d - package$.MODULE$.pow(sigma, 2.0d)));
        } else {
            double d4 = d2 >= ((double) 0) ? -1.0d : 1.0d;
            lossSum_$eq(lossSum() + (0.5d * weight * ((sigma() + ((2.0d * this.epsilon) * package$.MODULE$.abs(d2))) - ((sigma() * this.epsilon) * this.epsilon))));
            features.foreachActive((i3, d5) -> {
                if (dArr[i3] == 0.0d || d5 == 0.0d) {
                    return;
                }
                gradientSumArray[i3] = gradientSumArray[i3] + (weight * d4 * this.epsilon * (d5 / dArr[i3]));
            });
            if (this.fitIntercept) {
                int dim3 = dim() - 2;
                gradientSumArray[dim3] = gradientSumArray[dim3] + (weight * d4 * this.epsilon);
            }
            int dim4 = dim() - 1;
            gradientSumArray[dim4] = gradientSumArray[dim4] + (0.5d * weight * (1.0d - (this.epsilon * this.epsilon)));
        }
        weightSum_$eq(weightSum() + weight);
        return this;
    }

    public HuberAggregator(boolean z, double d, Broadcast<double[]> broadcast, Broadcast<Vector> broadcast2) {
        this.fitIntercept = z;
        this.epsilon = d;
        this.bcFeaturesStd = broadcast;
        this.bcParameters = broadcast2;
        DifferentiableLossAggregator.$init$(this);
        this.dim = ((Vector) broadcast2.value()).size();
        this.numFeatures = z ? dim() - 2 : dim() - 1;
        this.sigma = ((Vector) broadcast2.value()).apply(dim() - 1);
        this.intercept = z ? ((Vector) broadcast2.value()).apply(dim() - 2) : 0.0d;
    }
}
