/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval.curves;

import com.google.common.base.Preconditions;
import java.beans.ConstructorProperties;
import java.util.Arrays;
import org.deeplearning4j.eval.curves.BaseCurve;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Precision-recall curve built from a sequence of classification thresholds.
 *
 * <p>For each index {@code i} the curve stores the threshold, the precision and recall at that
 * threshold, and the true-positive / false-positive / false-negative counts. For plotting, recall is
 * the x axis ({@link #getX()}) and precision is the y axis ({@link #getY()}).
 *
 * <p>NOTE(review): {@link #getPointAtThreshold(double)} uses {@link Arrays#binarySearch(double[], double)},
 * so the {@code threshold} array is assumed to be sorted ascending, and all per-threshold arrays are
 * assumed to have the same length — confirm against the code that builds these curves.
 */
public class PrecisionRecallCurve extends BaseCurve {
    private double[] threshold;
    private double[] precision;
    private double[] recall;
    private int[] tpCount;
    private int[] fpCount;
    private int[] fnCount;
    private int totalCount;
    /** Cached result of {@link #calculateAUPRC()}; {@code null} until computed or explicitly set. */
    private Double area;

    /**
     * Creates a curve from parallel per-threshold arrays. The arrays are stored by reference, not copied.
     *
     * @param threshold  threshold value for each point
     * @param precision  precision at each threshold
     * @param recall     recall at each threshold
     * @param tpCount    true-positive count at each threshold
     * @param fpCount    false-positive count at each threshold
     * @param fnCount    false-negative count at each threshold
     * @param totalCount total number of examples; used to derive the true-negative count
     */
    public PrecisionRecallCurve(@JsonProperty(value="threshold") double[] threshold, @JsonProperty(value="precision") double[] precision, @JsonProperty(value="recall") double[] recall, @JsonProperty(value="tpCount") int[] tpCount, @JsonProperty(value="fpCount") int[] fpCount, @JsonProperty(value="fnCount") int[] fnCount, @JsonProperty(value="totalCount") int totalCount) {
        this.threshold = threshold;
        this.precision = precision;
        this.recall = recall;
        this.tpCount = tpCount;
        this.fpCount = fpCount;
        this.fnCount = fnCount;
        this.totalCount = totalCount;
    }

    /** @return the number of points on the curve (the length of the threshold array) */
    @Override
    public int numPoints() {
        return this.threshold.length;
    }

    /** @return the recall values, used as the x axis */
    @Override
    public double[] getX() {
        return this.recall;
    }

    /** @return the precision values, used as the y axis */
    @Override
    public double[] getY() {
        return this.precision;
    }

    /** @return a chart title that includes the area under the curve, formatted to 4 decimal places */
    @Override
    public String getTitle() {
        return "Precision-Recall Curve (Area=" + this.format(this.calculateAUPRC(), 4) + ")";
    }

    /**
     * @param i point index
     * @return the threshold at index {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public double getThreshold(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.threshold.length, "Invalid index: " + i);
        return this.threshold[i];
    }

    /**
     * @param i point index
     * @return the precision at index {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public double getPrecision(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.precision.length, "Invalid index: " + i);
        return this.precision[i];
    }

    /**
     * @param i point index
     * @return the recall at index {@code i}
     * @throws IllegalArgumentException if {@code i} is out of range
     */
    public double getRecall(int i) {
        Preconditions.checkArgument(i >= 0 && i < this.recall.length, "Invalid index: " + i);
        return this.recall[i];
    }

    /**
     * Returns the area under this curve. The area is computed with {@code calculateArea()} on the first
     * call and then cached. The cache is cleared when the threshold, precision or recall arrays are replaced.
     *
     * @return the area under the precision-recall curve
     */
    public double calculateAUPRC() {
        if (this.area == null) {
            this.area = this.calculateArea();
        }
        return this.area;
    }

    /**
     * Returns the point at the given threshold. If there is no exact match, returns the first point whose
     * stored threshold is greater than the requested one. If the requested threshold is above every stored
     * threshold, returns the last point.
     *
     * @param threshold threshold to look up
     * @return the matching (or nearest higher) point
     * @throws IllegalStateException if the curve has no points
     */
    public Point getPointAtThreshold(double threshold) {
        this.checkHasPoints();
        int idx = Arrays.binarySearch(this.threshold, threshold);
        if (idx < 0) {
            // Not found: binarySearch returned (-(insertionPoint) - 1). The insertion point is the index of
            // the first element greater than the key, or the array length if no element is greater.
            idx = -idx - 1;
            // Key lies above every stored threshold: clamp to the last point instead of reading past the end.
            if (idx >= this.threshold.length) {
                idx = this.threshold.length - 1;
            }
        }
        return this.pointAt(idx);
    }

    /**
     * Returns the lowest-index point whose precision is at least {@code precision}. If no point reaches
     * that precision, returns the last point.
     *
     * @param precision minimum required precision
     * @return the selected point
     * @throws IllegalStateException if the curve has no points
     */
    public Point getPointAtPrecision(double precision) {
        this.checkHasPoints();
        for (int i = 0; i < this.precision.length; ++i) {
            if (this.precision[i] >= precision) {
                return this.pointAt(i);
            }
        }
        return this.pointAt(this.threshold.length - 1);
    }

    /**
     * Returns the highest-index point whose recall is at least {@code recall}. If no point reaches that
     * recall, returns the first point.
     *
     * @param recall minimum required recall
     * @return the selected point
     * @throws IllegalStateException if the curve has no points
     */
    public Point getPointAtRecall(double recall) {
        this.checkHasPoints();
        for (int i = this.recall.length - 1; i >= 0; --i) {
            if (this.recall[i] >= recall) {
                return this.pointAt(i);
            }
        }
        return this.pointAt(0);
    }

    /**
     * Returns the confusion-matrix counts at the point selected by {@link #getPointAtThreshold(double)}.
     *
     * @param threshold threshold to look up
     * @return the counts at that point; TN is derived as {@code totalCount - (TP + FP + FN)}
     * @throws IllegalStateException if the curve has no points
     */
    public Confusion getConfusionMatrixAtThreshold(double threshold) {
        return this.confusionAt(this.getPointAtThreshold(threshold));
    }

    /**
     * Returns the confusion-matrix counts at the given point index. The index is used directly, so
     * duplicate thresholds cannot redirect the lookup to a different index.
     *
     * @param point point index
     * @return the counts at that point; TN is derived as {@code totalCount - (TP + FP + FN)}
     * @throws IllegalArgumentException if {@code point} is out of range
     */
    public Confusion getConfusionMatrixAtPoint(int point) {
        Preconditions.checkArgument(point >= 0 && point < this.threshold.length, "Invalid index: " + point);
        return this.confusionAt(this.pointAt(point));
    }

    /** Fails fast with a clear message instead of an array-index error when the curve is empty. */
    private void checkHasPoints() {
        if (this.threshold == null || this.threshold.length == 0) {
            throw new IllegalStateException("Precision-recall curve has no points");
        }
    }

    /** Builds the {@link Point} for an index the caller has already validated. */
    private Point pointAt(int idx) {
        return new Point(idx, this.threshold[idx], this.precision[idx], this.recall[idx]);
    }

    /** Builds the {@link Confusion} for a point, deriving TN from the total count. */
    private Confusion confusionAt(Point p) {
        int idx = p.getIdx();
        int tp = this.tpCount[idx];
        int fp = this.fpCount[idx];
        int fn = this.fnCount[idx];
        int tn = this.totalCount - (tp + fp + fn);
        return new Confusion(p, tp, fp, fn, tn);
    }

    /** Deserializes a curve from JSON. */
    public static PrecisionRecallCurve fromJson(String json) {
        return PrecisionRecallCurve.fromJson(json, PrecisionRecallCurve.class);
    }

    /** Deserializes a curve from YAML. */
    public static PrecisionRecallCurve fromYaml(String yaml) {
        return PrecisionRecallCurve.fromYaml(yaml, PrecisionRecallCurve.class);
    }

    public double[] getThreshold() {
        return this.threshold;
    }

    public double[] getPrecision() {
        return this.precision;
    }

    public double[] getRecall() {
        return this.recall;
    }

    public int[] getTpCount() {
        return this.tpCount;
    }

    public int[] getFpCount() {
        return this.fpCount;
    }

    public int[] getFnCount() {
        return this.fnCount;
    }

    public int getTotalCount() {
        return this.totalCount;
    }

    /** @return the cached area, or {@code null} if it has not been computed or set */
    public Double getArea() {
        return this.area;
    }

    /** Replaces the thresholds and clears the cached area. */
    public void setThreshold(double[] threshold) {
        this.threshold = threshold;
        this.area = null;
    }

    /** Replaces the precision values and clears the cached area. */
    public void setPrecision(double[] precision) {
        this.precision = precision;
        this.area = null;
    }

    /** Replaces the recall values and clears the cached area. */
    public void setRecall(double[] recall) {
        this.recall = recall;
        this.area = null;
    }

    public void setTpCount(int[] tpCount) {
        this.tpCount = tpCount;
    }

    public void setFpCount(int[] fpCount) {
        this.fpCount = fpCount;
    }

    public void setFnCount(int[] fnCount) {
        this.fnCount = fnCount;
    }

    public void setTotalCount(int totalCount) {
        this.totalCount = totalCount;
    }

    public void setArea(Double area) {
        this.area = area;
    }

    @Override
    public String toString() {
        return "PrecisionRecallCurve(threshold=" + Arrays.toString(this.getThreshold()) + ", precision=" + Arrays.toString(this.getPrecision()) + ", recall=" + Arrays.toString(this.getRecall()) + ", tpCount=" + Arrays.toString(this.getTpCount()) + ", fpCount=" + Arrays.toString(this.getFpCount()) + ", fnCount=" + Arrays.toString(this.getFnCount()) + ", totalCount=" + this.getTotalCount() + ", area=" + this.getArea() + ")";
    }

    /** Compares the curve data. {@code area} is excluded because it is a cached value. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof PrecisionRecallCurve)) {
            return false;
        }
        PrecisionRecallCurve other = (PrecisionRecallCurve) o;
        return other.canEqual(this)
                && Arrays.equals(this.getThreshold(), other.getThreshold())
                && Arrays.equals(this.getPrecision(), other.getPrecision())
                && Arrays.equals(this.getRecall(), other.getRecall())
                && Arrays.equals(this.getTpCount(), other.getTpCount())
                && Arrays.equals(this.getFpCount(), other.getFpCount())
                && Arrays.equals(this.getFnCount(), other.getFnCount())
                && this.getTotalCount() == other.getTotalCount();
    }

    /** Lets subclasses opt out of equality with this class, which keeps {@code equals} symmetric. */
    protected boolean canEqual(Object other) {
        return other instanceof PrecisionRecallCurve;
    }

    /** Hash over the same fields as {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + Arrays.hashCode(this.getThreshold());
        result = result * PRIME + Arrays.hashCode(this.getPrecision());
        result = result * PRIME + Arrays.hashCode(this.getRecall());
        result = result * PRIME + Arrays.hashCode(this.getTpCount());
        result = result * PRIME + Arrays.hashCode(this.getFpCount());
        result = result * PRIME + Arrays.hashCode(this.getFnCount());
        result = result * PRIME + this.getTotalCount();
        return result;
    }

    /** Immutable confusion-matrix counts at one point of the curve. */
    public static class Confusion {
        private final Point point;
        private final int tpCount;
        private final int fpCount;
        private final int fnCount;
        private final int tnCount;

        @ConstructorProperties(value={"point", "tpCount", "fpCount", "fnCount", "tnCount"})
        public Confusion(Point point, int tpCount, int fpCount, int fnCount, int tnCount) {
            this.point = point;
            this.tpCount = tpCount;
            this.fpCount = fpCount;
            this.fnCount = fnCount;
            this.tnCount = tnCount;
        }

        public Point getPoint() {
            return this.point;
        }

        public int getTpCount() {
            return this.tpCount;
        }

        public int getFpCount() {
            return this.fpCount;
        }

        public int getFnCount() {
            return this.fnCount;
        }

        public int getTnCount() {
            return this.tnCount;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Confusion)) {
                return false;
            }
            Confusion other = (Confusion) o;
            if (!other.canEqual(this)) {
                return false;
            }
            Point thisPoint = this.getPoint();
            Point otherPoint = other.getPoint();
            if (thisPoint == null ? otherPoint != null : !thisPoint.equals(otherPoint)) {
                return false;
            }
            return this.getTpCount() == other.getTpCount()
                    && this.getFpCount() == other.getFpCount()
                    && this.getFnCount() == other.getFnCount()
                    && this.getTnCount() == other.getTnCount();
        }

        protected boolean canEqual(Object other) {
            return other instanceof Confusion;
        }

        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            Point p = this.getPoint();
            // 43 is the placeholder hash used for a null point.
            result = result * PRIME + (p == null ? 43 : p.hashCode());
            result = result * PRIME + this.getTpCount();
            result = result * PRIME + this.getFpCount();
            result = result * PRIME + this.getFnCount();
            result = result * PRIME + this.getTnCount();
            return result;
        }

        @Override
        public String toString() {
            return "PrecisionRecallCurve.Confusion(point=" + this.getPoint() + ", tpCount=" + this.getTpCount() + ", fpCount=" + this.getFpCount() + ", fnCount=" + this.getFnCount() + ", tnCount=" + this.getTnCount() + ")";
        }
    }

    /** Immutable (index, threshold, precision, recall) tuple for one point of the curve. */
    public static class Point {
        private final int idx;
        private final double threshold;
        private final double precision;
        private final double recall;

        @ConstructorProperties(value={"idx", "threshold", "precision", "recall"})
        public Point(int idx, double threshold, double precision, double recall) {
            this.idx = idx;
            this.threshold = threshold;
            this.precision = precision;
            this.recall = recall;
        }

        public int getIdx() {
            return this.idx;
        }

        public double getThreshold() {
            return this.threshold;
        }

        public double getPrecision() {
            return this.precision;
        }

        public double getRecall() {
            return this.recall;
        }

        @Override
        public boolean equals(Object o) {
            if (o == this) {
                return true;
            }
            if (!(o instanceof Point)) {
                return false;
            }
            Point other = (Point) o;
            // Double.compare treats NaN as equal to itself, which keeps equals reflexive.
            return other.canEqual(this)
                    && this.getIdx() == other.getIdx()
                    && Double.compare(this.getThreshold(), other.getThreshold()) == 0
                    && Double.compare(this.getPrecision(), other.getPrecision()) == 0
                    && Double.compare(this.getRecall(), other.getRecall()) == 0;
        }

        protected boolean canEqual(Object other) {
            return other instanceof Point;
        }

        @Override
        public int hashCode() {
            final int PRIME = 59;
            int result = 1;
            result = result * PRIME + this.getIdx();
            long thresholdBits = Double.doubleToLongBits(this.getThreshold());
            result = result * PRIME + (int) ((thresholdBits >>> 32) ^ thresholdBits);
            long precisionBits = Double.doubleToLongBits(this.getPrecision());
            result = result * PRIME + (int) ((precisionBits >>> 32) ^ precisionBits);
            long recallBits = Double.doubleToLongBits(this.getRecall());
            result = result * PRIME + (int) ((recallBits >>> 32) ^ recallBits);
            return result;
        }

        @Override
        public String toString() {
            return "PrecisionRecallCurve.Point(idx=" + this.getIdx() + ", threshold=" + this.getThreshold() + ", precision=" + this.getPrecision() + ", recall=" + this.getRecall() + ")";
        }
    }
}

