package org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import java.util.stream.IntStream;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilder;
import org.elasticsearch.search.aggregations.AggregationBuilders;
import org.elasticsearch.search.aggregations.Aggregations;
import org.elasticsearch.search.aggregations.bucket.filter.Filter;
import org.elasticsearch.search.aggregations.metrics.Percentiles;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult;
import org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.security.support.MetadataUtils;

/* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc.class */
public class AucRoc implements SoftClassificationMetric {
    public static final ParseField NAME;
    public static final ParseField INCLUDE_CURVE;
    public static final ConstructingObjectParser<AucRoc, Void> PARSER;
    private static final String PERCENTILES = "percentiles";
    private final boolean includeCurve;
    static final /* synthetic */ boolean $assertionsDisabled;

    /* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc$AucRocPoint.class */
    public static final class AucRocPoint implements Comparable<AucRocPoint>, ToXContentObject, Writeable {
        double tpr;
        double fpr;
        double threshold;

        private AucRocPoint(double d, double d2, double d3) {
            this.tpr = d;
            this.fpr = d2;
            this.threshold = d3;
        }

        private AucRocPoint(StreamInput streamInput) throws IOException {
            this.tpr = streamInput.readDouble();
            this.fpr = streamInput.readDouble();
            this.threshold = streamInput.readDouble();
        }

        @Override // java.lang.Comparable
        public int compareTo(AucRocPoint aucRocPoint) {
            return Comparator.comparingDouble(aucRocPoint2 -> {
                return aucRocPoint2.threshold;
            }).reversed().thenComparing(aucRocPoint3 -> {
                return Double.valueOf(aucRocPoint3.fpr);
            }).thenComparing(aucRocPoint4 -> {
                return Double.valueOf(aucRocPoint4.tpr);
            }).compare(this, aucRocPoint);
        }

        public void writeTo(StreamOutput streamOutput) throws IOException {
            streamOutput.writeDouble(this.tpr);
            streamOutput.writeDouble(this.fpr);
            streamOutput.writeDouble(this.threshold);
        }

        public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
            xContentBuilder.startObject();
            xContentBuilder.field("tpr", this.tpr);
            xContentBuilder.field("fpr", this.fpr);
            xContentBuilder.field("threshold", this.threshold);
            xContentBuilder.endObject();
            return xContentBuilder;
        }

        public String toString() {
            return Strings.toString(this);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc$RateThresholdCurve.class */
    public static class RateThresholdCurve {
        private final double[] percentiles;
        private final boolean isTp;

        private RateThresholdCurve(double[] dArr, boolean z) {
            this.percentiles = dArr;
            this.isTp = z;
        }

        private double getRate(int i) {
            return 1.0d - (0.01d * (i + 1));
        }

        private double getThreshold(int i) {
            return this.percentiles[i];
        }

        private double interpolateRate(double d) {
            int binarySearch = Arrays.binarySearch(this.percentiles, d);
            if (binarySearch >= 0) {
                return getRate(binarySearch);
            }
            int i = (binarySearch * (-1)) - 1;
            int i2 = i - 1;
            if (i >= this.percentiles.length) {
                return 0.0d;
            }
            if (i2 < 0) {
                return 1.0d;
            }
            double rate = getRate(i);
            return AucRoc.interpolate(d, this.percentiles[i2], getRate(i2), this.percentiles[i], rate);
        }

        /* JADX INFO: Access modifiers changed from: private */
        public List<AucRocPoint> scanPoints(RateThresholdCurve rateThresholdCurve) {
            ArrayList arrayList = new ArrayList();
            for (int i = 0; i < this.percentiles.length; i++) {
                double rate = getRate(i);
                double threshold = getThreshold(i);
                double interpolateRate = rateThresholdCurve.interpolateRate(threshold);
                arrayList.add(this.isTp ? new AucRocPoint(rate, interpolateRate, threshold) : new AucRocPoint(interpolateRate, rate, threshold));
            }
            return arrayList;
        }
    }

    /* loaded from: input_file:org/elasticsearch/xpack/core/ml/dataframe/evaluation/softclassification/AucRoc$Result.class */
    public static class Result implements EvaluationMetricResult {
        private final double score;
        private final List<AucRocPoint> curve;

        public Result(double d, List<AucRocPoint> list) {
            this.score = d;
            this.curve = (List) Objects.requireNonNull(list);
        }

        public Result(StreamInput streamInput) throws IOException {
            this.score = streamInput.readDouble();
            this.curve = streamInput.readList(streamInput2 -> {
                return new AucRocPoint(streamInput2);
            });
        }

        public String getWriteableName() {
            return AucRoc.NAME.getPreferredName();
        }

        @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.EvaluationMetricResult
        public String getName() {
            return AucRoc.NAME.getPreferredName();
        }

        public void writeTo(StreamOutput streamOutput) throws IOException {
            streamOutput.writeDouble(this.score);
            streamOutput.writeList(this.curve);
        }

        public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
            xContentBuilder.startObject();
            xContentBuilder.field("score", this.score);
            if (!this.curve.isEmpty()) {
                xContentBuilder.field("curve", this.curve);
            }
            xContentBuilder.endObject();
            return xContentBuilder;
        }
    }

    public static AucRoc fromXContent(XContentParser xContentParser) {
        return (AucRoc) PARSER.apply(xContentParser, (Object) null);
    }

    public AucRoc(Boolean bool) {
        this.includeCurve = bool == null ? false : bool.booleanValue();
    }

    public AucRoc(StreamInput streamInput) throws IOException {
        this.includeCurve = streamInput.readBoolean();
    }

    public String getWriteableName() {
        return NAME.getPreferredName();
    }

    public void writeTo(StreamOutput streamOutput) throws IOException {
        streamOutput.writeBoolean(this.includeCurve);
    }

    public XContentBuilder toXContent(XContentBuilder xContentBuilder, ToXContent.Params params) throws IOException {
        xContentBuilder.startObject();
        xContentBuilder.field(INCLUDE_CURVE.getPreferredName(), this.includeCurve);
        xContentBuilder.endObject();
        return xContentBuilder;
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric
    public String getMetricName() {
        return NAME.getPreferredName();
    }

    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Objects.equals(Boolean.valueOf(this.includeCurve), Boolean.valueOf(((AucRoc) obj).includeCurve));
    }

    public int hashCode() {
        return Objects.hash(Boolean.valueOf(this.includeCurve));
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric
    public List<AggregationBuilder> aggs(String str, List<SoftClassificationMetric.ClassInfo> list) {
        double[] array = IntStream.range(1, 100).mapToDouble(i -> {
            return i;
        }).toArray();
        ArrayList arrayList = new ArrayList();
        for (SoftClassificationMetric.ClassInfo classInfo : list) {
            AbstractAggregationBuilder subAggregation = AggregationBuilders.filter(evaluatedLabelAggName(classInfo), classInfo.matchingQuery()).subAggregation(AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(array));
            AbstractAggregationBuilder subAggregation2 = AggregationBuilders.filter(restLabelsAggName(classInfo), QueryBuilders.boolQuery().mustNot(classInfo.matchingQuery())).subAggregation(AggregationBuilders.percentiles(PERCENTILES).field(classInfo.getProbabilityField()).percentiles(array));
            arrayList.add(subAggregation);
            arrayList.add(subAggregation2);
        }
        return arrayList;
    }

    private String evaluatedLabelAggName(SoftClassificationMetric.ClassInfo classInfo) {
        return getMetricName() + MetadataUtils.RESERVED_PREFIX + classInfo.getName();
    }

    private String restLabelsAggName(SoftClassificationMetric.ClassInfo classInfo) {
        return getMetricName() + "_non_" + classInfo.getName();
    }

    @Override // org.elasticsearch.xpack.core.ml.dataframe.evaluation.softclassification.SoftClassificationMetric
    public EvaluationMetricResult evaluate(SoftClassificationMetric.ClassInfo classInfo, Aggregations aggregations) {
        List<AucRocPoint> buildAucRocCurve = buildAucRocCurve(percentilesArray(aggregations.get(evaluatedLabelAggName(classInfo)).getAggregations().get(PERCENTILES), "[" + getMetricName() + "] requires at least one actual_field to have the value [" + classInfo.getName() + "]"), percentilesArray(aggregations.get(restLabelsAggName(classInfo)).getAggregations().get(PERCENTILES), "[" + getMetricName() + "] requires at least one actual_field to have a different value than [" + classInfo.getName() + "]"));
        return new Result(calculateAucScore(buildAucRocCurve), this.includeCurve ? buildAucRocCurve : Collections.emptyList());
    }

    private static double[] percentilesArray(Percentiles percentiles, String str) {
        double[] dArr = new double[99];
        percentiles.forEach(percentile -> {
            if (Double.isNaN(percentile.getValue())) {
                throw ExceptionsHelper.badRequestException(str, new Object[0]);
            }
            dArr[((int) percentile.getPercent()) - 1] = percentile.getValue();
        });
        return dArr;
    }

    static List<AucRocPoint> buildAucRocCurve(double[] dArr, double[] dArr2) {
        if (!$assertionsDisabled && dArr.length != dArr2.length) {
            throw new AssertionError();
        }
        if (!$assertionsDisabled && dArr.length != 99) {
            throw new AssertionError();
        }
        ArrayList arrayList = new ArrayList();
        arrayList.add(new AucRocPoint(0.0d, 0.0d, 1.0d));
        arrayList.add(new AucRocPoint(1.0d, 1.0d, 0.0d));
        RateThresholdCurve rateThresholdCurve = new RateThresholdCurve(dArr, true);
        RateThresholdCurve rateThresholdCurve2 = new RateThresholdCurve(dArr2, false);
        arrayList.addAll(rateThresholdCurve.scanPoints(rateThresholdCurve2));
        arrayList.addAll(rateThresholdCurve2.scanPoints(rateThresholdCurve));
        Collections.sort(arrayList);
        return arrayList;
    }

    static double calculateAucScore(List<AucRocPoint> list) {
        double d = 0.0d;
        for (int i = 1; i < list.size(); i++) {
            AucRocPoint aucRocPoint = list.get(i - 1);
            AucRocPoint aucRocPoint2 = list.get(i);
            d += ((aucRocPoint2.fpr - aucRocPoint.fpr) * (aucRocPoint2.tpr + aucRocPoint.tpr)) / 2.0d;
        }
        return d;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public static double interpolate(double d, double d2, double d3, double d4, double d5) {
        return d3 + (((d - d2) * (d5 - d3)) / (d4 - d2));
    }

    static {
        $assertionsDisabled = !AucRoc.class.desiredAssertionStatus();
        NAME = new ParseField("auc_roc", new String[0]);
        INCLUDE_CURVE = new ParseField("include_curve", new String[0]);
        PARSER = new ConstructingObjectParser<>(NAME.getPreferredName(), objArr -> {
            return new AucRoc((Boolean) objArr[0]);
        });
        PARSER.declareBoolean(ConstructingObjectParser.optionalConstructorArg(), INCLUDE_CURVE);
    }
}
