/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.evaluation;

import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.EvaluationUtils;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.EvaluationBinary;
import org.nd4j.evaluation.classification.EvaluationCalibration;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.evaluation.classification.ROCBinary;
import org.nd4j.evaluation.classification.ROCMultiClass;
import org.nd4j.evaluation.regression.RegressionEvaluation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastTo;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.primitives.Triple;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonDeserializerAtomicDouble;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicBoolean;
import org.nd4j.linalg.primitives.serde.JsonSerializerAtomicDouble;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.jackson.annotation.JsonAutoDetect;
import org.nd4j.shade.jackson.core.JsonFactory;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.JsonDeserializer;
import org.nd4j.shade.jackson.databind.JsonSerializer;
import org.nd4j.shade.jackson.databind.MapperFeature;
import org.nd4j.shade.jackson.databind.Module;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.nd4j.shade.jackson.databind.SerializationFeature;
import org.nd4j.shade.jackson.databind.exc.InvalidTypeIdException;
import org.nd4j.shade.jackson.databind.module.SimpleModule;
import org.nd4j.shade.jackson.dataformat.yaml.YAMLFactory;

public abstract class BaseEvaluation<T extends BaseEvaluation>
implements IEvaluation<T> {
    /**
     * Shared JSON mapper used by {@link #toJson()} and {@link #fromJson(String, Class)}.
     * Built once at class initialization via {@link #configureMapper(ObjectMapper)} and never reassigned,
     * hence {@code final}.
     */
    private static final ObjectMapper objectMapper = BaseEvaluation.configureMapper(new ObjectMapper());
    /**
     * Shared YAML mapper (an {@link ObjectMapper} backed by a {@link YAMLFactory}) used by
     * {@link #toYaml()} and {@link #fromYaml(String, Class)}. Never reassigned, hence {@code final}.
     */
    private static final ObjectMapper yamlMapper = BaseEvaluation.configureMapper(new ObjectMapper((JsonFactory)new YAMLFactory()));

    /**
     * Applies the serialization settings shared by the JSON and YAML mappers of this class.
     * The supplied mapper is modified in place and returned for convenience.
     *
     * @param ret mapper to configure
     * @return the same {@code ret} instance, after configuration
     */
    private static ObjectMapper configureMapper(ObjectMapper ret) {
        // Feature switches: ignore unknown properties, allow empty beans,
        // do not sort properties alphabetically, and indent the output.
        ret.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
                .configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false)
                .configure(MapperFeature.SORT_PROPERTIES_ALPHABETICALLY, false)
                .enable(SerializationFeature.INDENT_OUTPUT);

        // Custom (de)serializers for the ND4J atomic primitive wrappers.
        SimpleModule atomics = new SimpleModule();
        atomics.addSerializer(AtomicDouble.class, new JsonSerializerAtomicDouble());
        atomics.addSerializer(AtomicBoolean.class, new JsonSerializerAtomicBoolean());
        atomics.addDeserializer(AtomicDouble.class, new JsonDeserializerAtomicDouble());
        atomics.addDeserializer(AtomicBoolean.class, new JsonDeserializerAtomicBoolean());
        ret.registerModule(atomics);

        // Detect fields and creators of any visibility; never auto-detect getters or setters.
        ret.setVisibilityChecker(
                ret.getSerializationConfig()
                        .getDefaultVisibilityChecker()
                        .withFieldVisibility(JsonAutoDetect.Visibility.ANY)
                        .withGetterVisibility(JsonAutoDetect.Visibility.NONE)
                        .withSetterVisibility(JsonAutoDetect.Visibility.NONE)
                        .withCreatorVisibility(JsonAutoDetect.Visibility.ANY));
        return ret;
    }

    /**
     * Deserializes an evaluation instance of the requested type from a YAML string,
     * using the shared YAML mapper.
     *
     * @param yaml  YAML representation of the evaluation
     * @param clazz class of the evaluation to create
     * @param <T>   evaluation type
     * @return the deserialized evaluation
     * @throws RuntimeException wrapping any {@link IOException} raised while parsing
     */
    public static <T extends IEvaluation> T fromYaml(String yaml, Class<T> clazz) {
        final T parsed;
        try {
            parsed = yamlMapper.readValue(yaml, clazz);
        }
        catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
        return parsed;
    }

    /**
     * Deserializes an evaluation instance of the requested type from a JSON string,
     * using the shared JSON mapper.
     * <p>
     * When Jackson raises an {@link InvalidTypeIdException} whose message contains
     * "Could not resolve type id", the JSON is retried through
     * {@link #attempFromLegacyFromJson(String, InvalidTypeIdException)}.
     *
     * @param json  JSON representation of the evaluation
     * @param clazz class of the evaluation to create
     * @param <T>   evaluation type
     * @return the deserialized evaluation
     * @throws RuntimeException wrapping the underlying failure: the original {@link IOException},
     *                          or whatever the legacy retry threw (with the message
     *                          "Cannot deserialize from JSON - JSON is invalid?")
     */
    public static <T extends IEvaluation> T fromJson(String json, Class<T> clazz) {
        try {
            return objectMapper.readValue(json, clazz);
        }
        catch (InvalidTypeIdException typeIdEx) {
            final boolean unresolvedTypeId = typeIdEx.getMessage().contains("Could not resolve type id");
            if (!unresolvedTypeId) {
                throw new RuntimeException(typeIdEx);
            }
            // Unknown type id: the JSON may reference a legacy class name, so retry with those mapped.
            try {
                return attempFromLegacyFromJson(json, typeIdEx);
            }
            catch (Throwable legacyFailure) {
                throw new RuntimeException("Cannot deserialize from JSON - JSON is invalid?", legacyFailure);
            }
        }
        catch (IOException ioe) {
            throw new RuntimeException(ioe);
        }
    }

    /**
     * Attempts to deserialize JSON that refers to the legacy {@code org.deeplearning4j.eval.*} class names
     * by rewriting them to the current {@code org.nd4j.evaluation.*} names and calling
     * {@link #fromJson(String, Class)} with the matching class.
     * <p>
     * The evaluation type is picked by checking which legacy class name the JSON contains. Names that are
     * prefixes of other names ({@code Evaluation} of {@code EvaluationBinary}/{@code EvaluationCalibration},
     * {@code ROC} of {@code ROCBinary}/{@code ROCMultiClass}) are checked <em>after</em> the longer names,
     * otherwise the longer names could never be selected. Class names are replaced literally
     * ({@link String#replace(CharSequence, CharSequence)}), not as regular expressions, so the dots in the
     * package names only match dots.
     *
     * @param json              JSON that failed to deserialize because of an unresolved type id
     * @param originalException the exception from the first deserialization attempt
     * @param <T>               evaluation type expected by the caller (unchecked cast of the result)
     * @return the evaluation deserialized from the rewritten JSON
     * @throws InvalidTypeIdException {@code originalException}, rethrown when the JSON contains none of the
     *                                known legacy class names
     */
    protected static <T extends IEvaluation> T attempFromLegacyFromJson(String json, InvalidTypeIdException originalException) throws InvalidTypeIdException {
        // EvaluationBinary / EvaluationCalibration must be tested before plain Evaluation (their prefix).
        if (json.contains("org.deeplearning4j.eval.EvaluationBinary")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.EvaluationBinary", "org.nd4j.evaluation.classification.EvaluationBinary")
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T)BaseEvaluation.fromJson(newJson, EvaluationBinary.class);
        }
        if (json.contains("org.deeplearning4j.eval.EvaluationCalibration")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.EvaluationCalibration", "org.nd4j.evaluation.classification.EvaluationCalibration")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T)BaseEvaluation.fromJson(newJson, EvaluationCalibration.class);
        }
        if (json.contains("org.deeplearning4j.eval.Evaluation")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.Evaluation", "org.nd4j.evaluation.classification.Evaluation");
            return (T)BaseEvaluation.fromJson(newJson, Evaluation.class);
        }
        // ROCBinary / ROCMultiClass must be tested before plain ROC (their prefix).
        if (json.contains("org.deeplearning4j.eval.ROCBinary")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.ROCBinary", "org.nd4j.evaluation.classification.ROCBinary")
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T)BaseEvaluation.fromJson(newJson, ROCBinary.class);
        }
        if (json.contains("org.deeplearning4j.eval.ROCMultiClass")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.ROCMultiClass", "org.nd4j.evaluation.classification.ROCMultiClass")
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T)BaseEvaluation.fromJson(newJson, ROCMultiClass.class);
        }
        if (json.contains("org.deeplearning4j.eval.ROC")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.ROC", "org.nd4j.evaluation.classification.ROC")
                    .replace("org.deeplearning4j.eval.curves.", "org.nd4j.evaluation.curves.");
            return (T)BaseEvaluation.fromJson(newJson, ROC.class);
        }
        if (json.contains("org.deeplearning4j.eval.RegressionEvaluation")) {
            String newJson = json
                    .replace("org.deeplearning4j.eval.RegressionEvaluation", "org.nd4j.evaluation.regression.RegressionEvaluation");
            return (T)BaseEvaluation.fromJson(newJson, RegressionEvaluation.class);
        }
        // Not a recognized legacy format: surface the original failure.
        throw originalException;
    }

    /**
     * Prepares labels, predictions and an optional mask for evaluation, dispatching on the rank of
     * {@code labels}.
     * <ul>
     *   <li><b>Rank 2</b> ({@code axis} must be 1):
     *     <ul>
     *       <li>no mask: arrays are returned unchanged;</li>
     *       <li>rank-1 or column-vector mask: only the rows whose mask entry is non-zero are pulled out
     *           (all rows kept unchanged if none is masked), and the returned mask is {@code null};</li>
     *       <li>otherwise the mask must have the same shape as the labels and is returned as-is.</li>
     *     </ul>
     *   </li>
     *   <li><b>Rank 3, 4 or 5</b>:
     *     <ul>
     *       <li>no mask: arrays are passed to {@code reshapeSameShapeTo2d};</li>
     *       <li>rank 3 with a rank-2 mask: delegated to {@code EvaluationUtils.extractNonMaskedTimeSteps};
     *           the returned mask is {@code null};</li>
     *       <li>mask with the same shape as the labels: all three arrays are passed to
     *           {@code reshapeSameShapeTo2d};</li>
     *       <li>rank 4/5 with a rank-1 mask (length must equal {@code labels.size(0)}) or a
     *           broadcastable mask of the same rank: the mask is first broadcast to the labels shape.</li>
     *     </ul>
     *   </li>
     * </ul>
     *
     * @param labels      labels array, rank 2 to 5
     * @param predictions predictions array
     * @param mask        mask array, may be {@code null}
     * @param axis        dimension passed on to {@code reshapeSameShapeTo2d}; must be 1 for rank-2 labels
     * @return triple of (labels, predictions, mask); the mask component may be {@code null}. The method
     *         itself returns {@code null} when a rank-2 vector mask has no non-zero entries, or when
     *         {@code EvaluationUtils.extractNonMaskedTimeSteps} returns {@code null}
     * @throws UnsupportedOperationException for rank 4/5 labels with a mask shape that matches none of the
     *                                       supported cases
     * @throws IllegalStateException         if the labels rank is not 2, 3, 4 or 5
     */
    public static Triple<INDArray, INDArray, INDArray> reshapeAndExtractNotMasked(INDArray labels, INDArray predictions, INDArray mask, int axis) {
        // ---- Rank 2 labels ----
        if (labels.rank() == 2) {
            Preconditions.checkState((axis == 1 ? 1 : 0) != 0, (String)"Only axis=1 is supported 2d data - got axis=%s for labels array shape %ndShape", (Object)axis, (Object)labels);
            if (mask == null) {
                // No mask: nothing to filter.
                return new Triple((Object)labels, (Object)predictions, null);
            }
            if (mask.rank() == 1 || mask.isColumnVector()) {
                // Per-example mask: count entries that are not equal to 0.
                int notMaskedCount = mask.neq(0.0).castTo(DataType.INT).sumNumber().intValue();
                if (notMaskedCount == 0) {
                    // Everything is masked out - signalled to the caller with null.
                    return null;
                }
                if ((long)notMaskedCount == mask.length()) {
                    // Nothing is masked out - arrays can be used as-is.
                    return new Triple((Object)labels, (Object)predictions, null);
                }
                // Collect the indices of the non-masked rows.
                // NOTE(review): the count above uses neq(0.0) on the original values, while this loop tests
                // the int-converted values; if toIntVector() truncates fractional mask values to 0 the two
                // could disagree and leave trailing zeros in idxs - TODO confirm masks are always 0/1.
                int[] arr = mask.toIntVector();
                int[] idxs = new int[notMaskedCount];
                int pos = 0;
                for (int i = 0; i < arr.length; ++i) {
                    if (arr[i] == 0) continue;
                    idxs[pos++] = i;
                }
                INDArray retLabel = Nd4j.pullRows(labels, 1, idxs, 'c');
                INDArray retPredictions = Nd4j.pullRows(predictions, 1, idxs, 'c');
                return new Triple((Object)retLabel, (Object)retPredictions, null);
            }
            // Any other mask must match the labels shape exactly; it is handed back untouched.
            Preconditions.checkState((boolean)labels.equalShapes(mask), (String)"If a mask array is present for 2d data, it must either be a vector (column vector) or have shape equal to the labels (for per-output masking, when supported). Got labels shape %ndShape, mask shape %ndShape", (Object)labels, (Object)mask);
            return new Triple((Object)labels, (Object)predictions, (Object)mask);
        }
        // ---- Rank 3, 4 or 5 labels ----
        if (labels.rank() == 3 || labels.rank() == 4 || labels.rank() == 5) {
            if (mask == null) {
                return BaseEvaluation.reshapeSameShapeTo2d(axis, labels, predictions, mask);
            }
            if (labels.rank() == 3) {
                if (mask.rank() == 2) {
                    // Rank-2 mask for rank-3 data: let the utility extract the non-masked time steps.
                    Pair<INDArray, INDArray> p = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, mask);
                    if (p == null) {
                        return null;
                    }
                    return new Triple(p.getFirst(), p.getSecond(), null);
                }
                // Otherwise only a mask with exactly the labels shape is accepted.
                Preconditions.checkState((boolean)labels.equalShapes(mask), (String)"If a mask array is present for 3d data, it must either be 2d (shape [minibatch, sequenceLength]) or have shape equal to the labels (for per-output masking, when supported). Got labels shape %ndShape, mask shape %ndShape", (Object)labels, (Object)mask);
                return BaseEvaluation.reshapeSameShapeTo2d(axis, labels, predictions, mask);
            }
            // From here on labels are rank 4 or 5.
            if (labels.equalShapes(mask)) {
                return BaseEvaluation.reshapeSameShapeTo2d(axis, labels, predictions, mask);
            }
            if (mask.rank() == 1) {
                // NOTE(review): the message says "rank 4" but this branch is also reachable for rank-5 labels.
                Preconditions.checkState((mask.length() == labels.size(0) ? 1 : 0) != 0, (String)"For rank 4 labels with shape %ndShape and 1d mask of shape %ndShape, the mask array length must equal labels dimension 0 size", (Object)labels, (Object)mask);
                // Reshape the mask to [n, 1, 1, ...] (same rank as labels), then broadcast it to the labels shape.
                long[] reshape = ArrayUtil.nTimes((long)labels.rank(), (long)1L);
                reshape[0] = mask.size(0);
                INDArray mReshape = mask.reshape(reshape);
                INDArray bMask = Nd4j.createUninitialized(mask.dataType(), labels.shape());
                BroadcastTo b = new BroadcastTo(mReshape, labels.shape(), bMask);
                Nd4j.exec(b);
                return BaseEvaluation.reshapeSameShapeTo2d(axis, labels, predictions, bMask);
            }
            if (mask.rank() == labels.rank() && Shape.areShapesBroadcastable(mask.shape(), labels.shape())) {
                // Same-rank broadcastable mask: expand it to the full labels shape first.
                INDArray bMask = Nd4j.createUninitialized(mask.dataType(), labels.shape());
                BroadcastTo b = new BroadcastTo(mask, labels.shape(), bMask);
                Nd4j.exec(b);
                return BaseEvaluation.reshapeSameShapeTo2d(axis, labels, predictions, bMask);
            }
            throw new UnsupportedOperationException("Evaluation case not supported: labels shape " + Arrays.toString(labels.shape()) + " with mask shape " + Arrays.toString(mask.shape()));
        }
        // Any other rank is rejected.
        throw new IllegalStateException("Unknown array type passed to evaluation: labels array rank " + labels.rank() + " with shape " + labels.shapeInfoToString() + ". Labels and predictions must always be rank 2 or higher, with leading dimension being minibatch dimension");
    }

    /**
     * Flattens labels, predictions and (optionally) a mask of identical shape into 2d arrays of shape
     * {@code [product of all other dimensions, labels.size(axis)]}.
     * The arrays are permuted so that {@code axis} becomes the last dimension, copied in 'c' order and
     * then reshaped.
     *
     * @param axis        dimension that becomes the second dimension of the 2d result
     * @param labels      labels array; its rank and sizes define the permutation and output shape
     * @param predictions predictions array, permuted and reshaped with the same parameters as the labels
     * @param mask        mask array, or {@code null}; when present it is permuted and reshaped the same way
     * @return triple of (labels2d, predictions2d, mask2d), with a {@code null} mask if none was given
     */
    private static Triple<INDArray, INDArray, INDArray> reshapeSameShapeTo2d(int axis, INDArray labels, INDArray predictions, INDArray mask) {
        final int rank = labels.rank();

        // Dimension order with 'axis' moved to the end; the remaining dimensions keep their relative order.
        int[] order = new int[rank];
        int next = 0;
        for (int dim = 0; dim < rank; ++dim) {
            if (dim != axis) {
                order[next++] = dim;
            }
        }
        order[next] = axis;

        // Row count of the 2d view: product of every dimension except 'axis'.
        long rows = 1L;
        for (int k = 0; k < rank - 1; ++k) {
            rows *= labels.size(order[k]);
        }
        final long cols = labels.size(axis);

        INDArray labels2d = labels.permute(order).dup('c').reshape('c', rows, cols);
        INDArray predictions2d = predictions.permute(order).dup('c').reshape('c', rows, cols);
        INDArray mask2d;
        if (mask == null) {
            mask2d = null;
        } else {
            mask2d = mask.permute(order).dup('c').reshape('c', rows, cols);
        }
        return new Triple<>(labels2d, predictions2d, mask2d);
    }

    /**
     * Evaluates the predictions against the labels without a mask and without record metadata.
     * Delegates to the four-argument {@code eval} with {@code null} for both.
     *
     * @param labels             labels array
     * @param networkPredictions predictions array
     */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions) {
        this.eval(labels, networkPredictions, null, null);
    }

    /**
     * Evaluates the predictions against the labels with record metadata but no mask.
     * Delegates to the four-argument {@code eval} with a {@code null} mask.
     *
     * @param labels         labels array, must not be {@code null}
     * @param predictions    predictions array, must not be {@code null}
     * @param recordMetaData metadata for the evaluated records, passed through unchanged
     * @throws NullPointerException if {@code labels} or {@code predictions} is {@code null}
     */
    @Override
    public void eval(@NonNull INDArray labels, @NonNull INDArray predictions, List<? extends Serializable> recordMetaData) {
        // Explicit null checks corresponding to the @NonNull annotations on the parameters.
        if (labels == null) {
            throw new NullPointerException("labels is marked @NonNull but is null");
        }
        if (predictions == null) {
            throw new NullPointerException("predictions is marked @NonNull but is null");
        }
        this.eval(labels, predictions, null, recordMetaData);
    }

    /**
     * Evaluates the predictions against the labels with a mask but no record metadata.
     * Delegates to the four-argument {@code eval} with {@code null} metadata.
     *
     * @param labels             labels array
     * @param networkPredictions predictions array
     * @param maskArray          mask array, passed through unchanged
     */
    @Override
    public void eval(INDArray labels, INDArray networkPredictions, INDArray maskArray) {
        this.eval(labels, networkPredictions, maskArray, null);
    }

    /**
     * Evaluates time series data without a labels mask.
     * Delegates to {@link #evalTimeSeries(INDArray, INDArray, INDArray)} with a {@code null} mask.
     *
     * @param labels    labels array
     * @param predicted predictions array
     */
    @Override
    public void evalTimeSeries(INDArray labels, INDArray predicted) {
        this.evalTimeSeries(labels, predicted, null);
    }

    /**
     * Evaluates time series data with an optional labels mask.
     * The arrays are first passed through {@code EvaluationUtils.extractNonMaskedTimeSteps}; the resulting
     * pair of arrays is then handed to {@link #eval(INDArray, INDArray)}. If the utility returns
     * {@code null}, nothing is evaluated.
     *
     * @param labels      labels array
     * @param predictions predictions array
     * @param labelsMask  labels mask, passed to the extraction utility
     */
    @Override
    public void evalTimeSeries(INDArray labels, INDArray predictions, INDArray labelsMask) {
        Pair<INDArray, INDArray> unmasked = EvaluationUtils.extractNonMaskedTimeSteps(labels, predictions, labelsMask);
        if (unmasked != null) {
            this.eval(unmasked.getFirst(), unmasked.getSecond());
        }
    }

    /**
     * Serializes this evaluation to a JSON string using the shared JSON mapper.
     *
     * @return JSON representation of this instance
     * @throws RuntimeException wrapping any {@link JsonProcessingException} raised during serialization
     */
    @Override
    public String toJson() {
        try {
            return objectMapper.writeValueAsString((Object)this);
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns the string produced by {@code stats()}.
     *
     * @return result of {@code this.stats()}
     */
    public String toString() {
        return this.stats();
    }

    /**
     * Serializes this evaluation to a YAML string using the shared YAML mapper.
     *
     * @return YAML representation of this instance
     * @throws RuntimeException wrapping any {@link JsonProcessingException} raised during serialization
     */
    @Override
    public String toYaml() {
        try {
            return yamlMapper.writeValueAsString((Object)this);
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Compares this evaluation with another object. No fields of this base class take part in the
     * comparison: the result is {@code true} for the same reference, or for any {@code BaseEvaluation}
     * whose {@link #canEqual(Object)} accepts this instance.
     *
     * @param o object to compare with
     * @return {@code true} if the objects are considered equal at this level of the hierarchy
     */
    @Override
    public boolean equals(Object o) {
        return o == this
                || (o instanceof BaseEvaluation && ((BaseEvaluation) o).canEqual(this));
    }

    /**
     * Hook used by {@link #equals(Object)}: reports whether {@code other} is a type this class may be
     * equal to.
     *
     * @param other candidate object
     * @return {@code true} if {@code other} is a {@code BaseEvaluation}
     */
    protected boolean canEqual(Object other) {
        return other instanceof BaseEvaluation;
    }

    /**
     * Returns a constant hash code; this base class contributes no fields to the hash.
     *
     * @return always 1
     */
    @Override
    public int hashCode() {
        return 1;
    }

    /**
     * Returns the shared JSON mapper held by this class. The instance itself is returned (no copy), so
     * any change made to it affects the mapper used by {@code toJson} and {@code fromJson}.
     *
     * @return the shared JSON {@link ObjectMapper}
     */
    public static ObjectMapper getObjectMapper() {
        return objectMapper;
    }

    /**
     * Returns the shared YAML mapper held by this class. The instance itself is returned (no copy), so
     * any change made to it affects the mapper used by {@code toYaml} and {@code fromYaml}.
     *
     * @return the shared YAML {@link ObjectMapper}
     */
    public static ObjectMapper getYamlMapper() {
        return yamlMapper;
    }
}

