package org.allenai.ml.objective;

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import org.allenai.ml.linalg.DenseVector;
import org.allenai.ml.linalg.Vector;
import org.allenai.ml.optimize.GradientFn;
import org.allenai.ml.util.Parallel;

/**
 * A {@link GradientFn} over a fixed dataset. It sums per-example objective values and gradients,
 * computed in parallel with {@link Parallel#mapReduce}.
 *
 * <p>{@link #apply} returns the <em>negation</em> of the summed value and gradient.
 * NOTE(review): presumably this turns a maximization objective (e.g. a log-likelihood) into
 * something a minimizing optimizer can consume. Confirm against {@link ExampleObjectiveFn}.
 *
 * @param <T> the type of a single example in the dataset
 */
public class BatchObjectiveFn<T> implements GradientFn {
    private final List<T> data;
    private final long dimension;
    private final ExampleObjectiveFn<T> exampleObjectiveFn;
    private final Parallel.MROpts mapReduceOpts;

    /**
     * Per-partition accumulator for the map-reduce in {@link #apply}. It is static so it does not
     * retain a hidden reference to the enclosing objective.
     */
    private static final class ObjectiveStats {
        /** Running sum of per-example objective values. */
        double value;
        /** Running sum of per-example gradients; written by {@code evaluate} and by merging. */
        final Vector gradient;

        ObjectiveStats(Vector gradient) {
            this.gradient = gradient;
        }
    }

    /**
     * Creates a batch objective over a snapshot of {@code data}.
     *
     * @param data the examples; copied, so later changes to the caller's list are not seen
     * @param exampleObjectiveFn evaluates the objective value (and gradient) of a single example
     * @param dimension the value reported by {@link #dimension()}
     * @param mapReduceOpts options passed to {@link Parallel#mapReduce}; its executor is shut down
     *     by {@link #shutdown()}
     * @throws NullPointerException if {@code data} or {@code exampleObjectiveFn} is null
     */
    public BatchObjectiveFn(List<? extends T> data, ExampleObjectiveFn<T> exampleObjectiveFn,
                            long dimension, Parallel.MROpts mapReduceOpts) {
        this.data = new ArrayList<>(Objects.requireNonNull(data, "data"));
        // Fail fast here rather than with an NPE inside a parallel worker during apply().
        this.exampleObjectiveFn = Objects.requireNonNull(exampleObjectiveFn, "exampleObjectiveFn");
        this.dimension = dimension;
        this.mapReduceOpts = mapReduceOpts;
    }

    /**
     * Shuts down the executor held in the map-reduce options via
     * {@link Parallel#shutdownExecutor}, passing {@link Long#MAX_VALUE} as the second argument.
     * NOTE(review): that argument is presumably a wait timeout, making the wait effectively
     * unbounded. Confirm against {@code Parallel.shutdownExecutor}.
     */
    public void shutdown() {
        Parallel.shutdownExecutor(this.mapReduceOpts.executorService, Long.MAX_VALUE);
    }

    /**
     * Evaluates the batch objective at {@code weights}.
     *
     * <p>The weights are copied once, and that copy is shared by every worker, so the caller's
     * vector is not read during the parallel pass. Each map-reduce partition accumulates into its
     * own {@link ObjectiveStats}. Partitions are combined by summing values and adding gradients.
     *
     * @param weights the point at which to evaluate the objective
     * @return a result holding the negated sum of per-example values and the negated sum of
     *     per-example gradients
     */
    @Override
    public GradientFn.Result apply(Vector weights) {
        // Snapshot the weights so all workers evaluate against the same, unshared vector.
        final Vector weightsCopy = weights.copy();
        ObjectiveStats total = Parallel.mapReduce(this.data,
            new Parallel.MapReduceDriver<T, ObjectiveStats>() {
                @Override
                public ObjectiveStats newData() {
                    return new ObjectiveStats(DenseVector.of(weightsCopy.dimension()));
                }

                @Override
                public void update(ObjectiveStats stats, T example) {
                    // evaluate() returns the example's value. Its gradient contribution is
                    // expected to be added into stats.gradient. NOTE(review): confirm that
                    // evaluate accumulates rather than overwrites.
                    stats.value += exampleObjectiveFn.evaluate(example, weightsCopy, stats.gradient);
                }

                @Override
                public void merge(ObjectiveStats into, ObjectiveStats from) {
                    into.value += from.value;
                    into.gradient.addInPlace(1.0d, from.gradient);
                }
            }, this.mapReduceOpts);
        return GradientFn.Result.of(-total.value, total.gradient.scale(-1.0d));
    }

    /**
     * Returns the dimension supplied at construction. It is not checked against the dimension of
     * the weights passed to {@link #apply}.
     */
    @Override
    public long dimension() {
        return this.dimension;
    }
}
