package org.elasticsearch.xpack.core.ml.inference.trainedmodel.ensemble;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.Function;
import org.apache.lucene.util.RamUsageEstimator;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TargetType;
import org.elasticsearch.xpack.core.ml.inference.utils.Statistics;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;

/* loaded from: input_file:lib/x-pack-core-7.17.14.jar:org/elasticsearch/xpack/core/ml/inference/trainedmodel/ensemble/WeightedMode.class */
/**
 * Output aggregator for classification ensembles: the sub-model outputs are combined by an
 * (optionally weighted) vote and the resulting per-class totals are passed to
 * {@link Statistics#softMax(double[])}.
 *
 * <p>{@link #processValues(double[][])} accepts two layouts:
 * <ul>
 *   <li>every sub-model returns a vector with more than one entry: the vectors are summed
 *       element-wise, each scaled by that sub-model's weight;</li>
 *   <li>every sub-model returns exactly one entry: the entry is treated as a class index in
 *       {@code [0, numClasses)} and the sub-model's weight is added to that class's total.</li>
 * </ul>
 * When no weights are configured every sub-model has weight {@code 1.0}.
 */
public class WeightedMode implements StrictlyParsedOutputAggregator, LenientlyParsedOutputAggregator {
    private static final long SHALLOW_SIZE = RamUsageEstimator.shallowSizeOfInstance(WeightedMode.class);
    // Literal name instead of a reference to the REST client's WeightedMode.NAME, so this core class
    // does not depend on client classes.
    public static final ParseField NAME = new ParseField("weighted_mode");
    public static final ParseField WEIGHTS = new ParseField("weights");
    public static final ParseField NUM_CLASSES = new ParseField("num_classes");
    private static final ConstructingObjectParser<WeightedMode, Void> LENIENT_PARSER = createParser(true);
    private static final ConstructingObjectParser<WeightedMode, Void> STRICT_PARSER = createParser(false);

    /** Per-sub-model weights; {@code null} means every sub-model has weight 1.0. */
    private final double[] weights;
    /** Number of classes; the constructors taking {@code Integer} require it to be greater than 1. */
    private final int numClasses;

    /**
     * Builds a parser for this aggregator.
     *
     * @param lenient passed as the parser's {@code ignoreUnknownFields} flag
     */
    @SuppressWarnings("unchecked") // declareDoubleArray below produces a List<Double> for args[1]
    private static ConstructingObjectParser<WeightedMode, Void> createParser(boolean lenient) {
        ConstructingObjectParser<WeightedMode, Void> parser = new ConstructingObjectParser<>(
            NAME.getPreferredName(),
            lenient,
            args -> new WeightedMode((Integer) args[0], (List<Double>) args[1])
        );
        parser.declareInt(ConstructingObjectParser.constructorArg(), NUM_CLASSES);
        parser.declareDoubleArray(ConstructingObjectParser.optionalConstructorArg(), WEIGHTS);
        return parser;
    }

    /** Parses an instance, rejecting unknown fields. */
    public static WeightedMode fromXContentStrict(XContentParser parser) {
        return STRICT_PARSER.apply(parser, null);
    }

    /** Parses an instance, ignoring unknown fields. */
    public static WeightedMode fromXContentLenient(XContentParser parser) {
        return LENIENT_PARSER.apply(parser, null);
    }

    /** Creates an unweighted aggregator for {@code numClasses} classes. */
    WeightedMode(int numClasses) {
        this(Integer.valueOf(numClasses), (List<Double>) null);
    }

    /** Parser entry point: converts the optional boxed weight list into a primitive array. */
    private WeightedMode(Integer numClasses, List<Double> weights) {
        this(weights == null ? null : weights.stream().mapToDouble(Double::doubleValue).toArray(), numClasses);
    }

    /**
     * @param weights    per-sub-model weights, or {@code null} for uniform weight 1.0; stored as-is (not copied)
     * @param numClasses number of classes; must be non-null and greater than 1
     * @throws IllegalArgumentException if {@code numClasses <= 1}
     */
    public WeightedMode(double[] weights, Integer numClasses) {
        this.weights = weights;
        this.numClasses = ExceptionsHelper.requireNonNull(numClasses, NUM_CLASSES);
        if (this.numClasses <= 1) {
            throw new IllegalArgumentException("[" + NUM_CLASSES.getPreferredName() + "] must be greater than 1.");
        }
    }

    /** Reads an instance written by {@link #writeTo(StreamOutput)}; the class-count check is not re-applied. */
    public WeightedMode(StreamInput in) throws IOException {
        this.weights = in.readBoolean() ? in.readDoubleArray() : null;
        this.numClasses = in.readVInt();
    }

    /** Returns the number of weights, or {@code null} when no weights are configured. */
    @Override
    public Integer expectedValueSize() {
        if (weights == null) {
            return null;
        }
        return weights.length;
    }

    /**
     * Combines one output row per sub-model into soft-maxed class scores.
     *
     * @param values one row per sub-model; all rows must have the same width
     * @return the result of {@link Statistics#softMax(double[])} over the per-class totals
     * @throws IllegalArgumentException if the row count differs from the number of weights, if
     *         {@code values} is empty, if rows differ in width, or (single-entry rows) if an entry is
     *         not a non-negative whole number below {@code numClasses}
     */
    @Override
    public double[] processValues(double[][] values) {
        Objects.requireNonNull(values, "values must not be null");
        if (weights != null && values.length != weights.length) {
            throw new IllegalArgumentException("values must be the same length as weights.");
        }
        // Checked explicitly so callers get an IllegalArgumentException rather than an index error on values[0].
        if (values.length == 0) {
            throw new IllegalArgumentException("values must not be empty");
        }
        if (values[0].length > 1) {
            return processMultiValued(values);
        }
        return processSingleValued(values);
    }

    /** Returns the weight of the sub-model at {@code index}, or 1.0 when no weights are configured. */
    private double weightAt(int index) {
        return weights == null ? 1.0 : weights[index];
    }

    /** Sums the per-sub-model vectors element-wise, each scaled by its weight, then soft-maxes the sums. */
    private double[] processMultiValued(double[][] values) {
        double[] totals = new double[values[0].length];
        for (int model = 0; model < values.length; model++) {
            double[] row = values[model];
            // Reject both longer and shorter rows; a shorter row must not be silently zero-padded.
            if (row.length != totals.length) {
                throw new IllegalArgumentException("value entries must have the same dimensions");
            }
            double weight = weightAt(model);
            for (int i = 0; i < row.length; i++) {
                totals[i] += row[i] * weight;
            }
        }
        return Statistics.softMax(totals);
    }

    /** Treats each single-entry row as a class index and accumulates sub-model weights per class. */
    private double[] processSingleValued(double[][] values) {
        int[] classIndices = new int[values.length];
        int maxClass = 0;
        for (int model = 0; model < values.length; model++) {
            double[] row = values[model];
            if (row.length != 1) {
                throw new IllegalArgumentException("value entries must have the same dimensions");
            }
            if (isNonNegativeWholeNumber(row[0]) == false) {
                throw new IllegalArgumentException("values must be whole, non-infinite, and positive");
            }
            int classIndex = (int) row[0];
            classIndices[model] = classIndex;
            maxClass = Math.max(maxClass, classIndex);
        }
        if (maxClass >= numClasses) {
            throw new IllegalArgumentException("values contain entries larger than expected max of [" + (numClasses - 1) + "]");
        }
        // Classes that receive no vote stay at -Infinity (presumably mapped to zero by softMax — verify);
        // the first vote for a class replaces -Infinity, later votes add to it.
        double[] frequencies = new double[numClasses];
        Arrays.fill(frequencies, Double.NEGATIVE_INFINITY);
        for (int model = 0; model < classIndices.length; model++) {
            double weight = weightAt(model);
            int c = classIndices[model];
            frequencies[c] = frequencies[c] == Double.NEGATIVE_INFINITY ? weight : frequencies[c] + weight;
        }
        return Statistics.softMax(frequencies);
    }

    /** True when {@code value} is finite, not negative, and has no fractional part. */
    private static boolean isNonNegativeWholeNumber(double value) {
        return Double.isFinite(value) && value >= 0 && value == Math.rint(value);
    }

    /**
     * Returns the index of the largest entry, as a double. Ties resolve to the lowest index; if no
     * entry exceeds -Infinity (including an empty array or all-NaN input) the result is 0.
     */
    @Override
    public double aggregate(double[] values) {
        Objects.requireNonNull(values, "values must not be null");
        int bestIndex = 0;
        double best = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < values.length; i++) {
            if (values[i] > best) {
                best = values[i];
                bestIndex = i;
            }
        }
        return bestIndex;
    }

    @Override
    public String getName() {
        return NAME.getPreferredName();
    }

    /** Only classification targets are supported. */
    @Override
    public boolean compatibleWith(TargetType targetType) {
        return targetType.equals(TargetType.CLASSIFICATION);
    }

    @Override
    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    /** Writes a presence flag, the weights if present, then the class count as a VInt. */
    @Override
    public void writeTo(StreamOutput out) throws IOException {
        out.writeBoolean(weights != null);
        if (weights != null) {
            out.writeDoubleArray(weights);
        }
        out.writeVInt(numClasses);
    }

    /** Emits {@code weights} (only when set) and {@code num_classes}. */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
        builder.startObject();
        if (weights != null) {
            builder.field(WEIGHTS.getPreferredName(), weights);
        }
        builder.field(NUM_CLASSES.getPreferredName(), numClasses);
        builder.endObject();
        return builder;
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        WeightedMode other = (WeightedMode) obj;
        return Arrays.equals(weights, other.weights) && numClasses == other.numClasses;
    }

    @Override
    public int hashCode() {
        return Objects.hash(Arrays.hashCode(weights), numClasses);
    }

    /** Shallow instance size plus the weights array, if any. */
    @Override
    public long ramBytesUsed() {
        return SHALLOW_SIZE + (weights == null ? 0L : RamUsageEstimator.sizeOf(weights));
    }
}
