package io.improbable.keanu.util;

import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.BooleanIfVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.AndBinaryVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.BooleanBinaryOpVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.OrBinaryVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.EqualsVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.GreaterThanOrEqualVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.GreaterThanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.LessThanOrEqualVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.LessThanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.NotEqualsVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.DoubleIfVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.IntegerIfVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.binary.IntegerAdditionVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.binary.IntegerDifferenceVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.binary.IntegerMultiplicationVertex;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

/* loaded from: input_file:io/improbable/keanu/util/DescriptionCreator.class */
/**
 * Builds human-readable, expression-style descriptions of a {@link Vertex} in terms of its
 * parents, e.g. {@code "This Vertex = (A + B) * C"}.
 *
 * <p>Vertices whose class has a registered infix delimiter are rendered as their parents joined
 * by that delimiter. "If" vertices are rendered as {@code predicate ? thn : els}. Any other
 * vertex is rendered from its {@link SaveVertexParam}-annotated getters when possible, and
 * otherwise as its parents joined by {@code ", "}.
 */
public class DescriptionCreator {
    private static final String THIS_VERTEX = "This Vertex";
    private static final String NULL_STRING = "Null";
    private static final String ASSIGNMENT = " = ";
    private static final String DEFAULT_DELIMITER = ", ";

    /** Infix operator text (including surrounding spaces) keyed by the exact vertex class. */
    private final Map<Class<?>, String> delimiters = new HashMap<>();

    /** Creates a description creator with the built-in arithmetic, boolean and comparison operators. */
    public DescriptionCreator() {
        delimiters.put(AdditionVertex.class, " + ");
        delimiters.put(IntegerAdditionVertex.class, " + ");
        delimiters.put(DifferenceVertex.class, " - ");
        delimiters.put(IntegerDifferenceVertex.class, " - ");
        delimiters.put(MultiplicationVertex.class, " * ");
        delimiters.put(IntegerMultiplicationVertex.class, " * ");
        delimiters.put(DivisionVertex.class, " / ");
        delimiters.put(AndBinaryVertex.class, " && ");
        delimiters.put(EqualsVertex.class, " == ");
        delimiters.put(GreaterThanOrEqualVertex.class, " >= ");
        delimiters.put(GreaterThanVertex.class, " > ");
        delimiters.put(LessThanOrEqualVertex.class, " <= ");
        delimiters.put(LessThanVertex.class, " < ");
        delimiters.put(NotEqualsVertex.class, " != ");
        delimiters.put(OrBinaryVertex.class, " || ");
    }

    /**
     * Describes {@code vertex} as {@code "<name> = <expression>"}.
     *
     * <p>For a vertex with parents, {@code <name>} is the vertex's label (or {@code "This Vertex"}
     * when it has none) and {@code <expression>} is built from its parents. For a vertex without
     * parents, {@code <name>} is always {@code "This Vertex"} and {@code <expression>} is the leaf
     * description (label, scalar constant, or class name with shape).
     *
     * @param vertex the vertex to describe; may be {@code null}
     * @return the description; {@code "This Vertex = Null"} when {@code vertex} is {@code null}
     */
    public String createDescription(Vertex<?> vertex) {
        if (vertex == null) {
            return THIS_VERTEX + ASSIGNMENT + NULL_STRING;
        }
        if (vertex.getParents().isEmpty()) {
            return THIS_VERTEX + ASSIGNMENT + getLeafDescription(vertex);
        }
        String name = vertex.getLabel() != null ? vertex.getLabel().toString() : THIS_VERTEX;
        // Labels are disabled for the root so that a labelled vertex is expanded into its
        // expression instead of being described as "label = label".
        return name + ASSIGNMENT + generateDescription(vertex, false, false);
    }

    /**
     * Recursively renders the expression for {@code vertex}.
     *
     * @param vertex the vertex to render; {@code null} renders as {@code "Null"}
     * @param allowLabels when {@code true}, a labelled vertex is rendered as just its label
     * @param includeBrackets when {@code true}, a compound expression is wrapped in parentheses
     * @return the rendered expression
     */
    private String generateDescription(Vertex<?> vertex, boolean allowLabels, boolean includeBrackets) {
        if (vertex == null) {
            return NULL_STRING;
        }
        if (allowLabels && vertex.getLabel() != null) {
            return vertex.getLabel().toString();
        }
        if (vertex.getParents().isEmpty()) {
            return getLeafDescription(vertex);
        }

        Optional<String> irregular = checkForIrregularExpressions(vertex, includeBrackets);
        if (irregular.isPresent()) {
            return irregular.get();
        }

        String delimiter = delimiters.get(vertex.getClass());
        if (delimiter != null) {
            return getDelimiterVertexDescription(vertex, delimiter, includeBrackets);
        }

        return tryCreateDescriptionFromSaveLoadAnnotations(vertex, includeBrackets)
            .orElseGet(() -> getDelimiterVertexDescription(vertex, DEFAULT_DELIMITER, includeBrackets));
    }

    /**
     * Renders {@code vertex} as the descriptions of its parents joined by {@code delimiter}.
     * Parents are always rendered with labels allowed and with brackets.
     *
     * @param vertex the vertex whose parents are rendered
     * @param delimiter the text placed between consecutive parents
     * @param includeBrackets whether to wrap the whole result in parentheses
     * @return the joined description
     */
    private String getDelimiterVertexDescription(Vertex<?> vertex, CharSequence delimiter, boolean includeBrackets) {
        StringBuilder description = new StringBuilder();
        if (includeBrackets) {
            description.append("(");
        }
        boolean first = true;
        for (Vertex<?> parent : vertex.getParents()) {
            if (!first) {
                description.append(delimiter);
            }
            first = false;
            description.append(generateDescription(parent, true, true));
        }
        if (includeBrackets) {
            description.append(")");
        }
        return description.toString();
    }

    /**
     * Describes a vertex that has no parents: its label if it has one, otherwise
     * {@code "Const(<value>)"} for a scalar tensor value, otherwise
     * {@code "<SimpleClassName> with shape: <shape>"}.
     */
    private static String getLeafDescription(Vertex<?> vertex) {
        if (vertex.getLabel() != null) {
            return vertex.getLabel().toString();
        }
        return tryGetScalarValue(vertex)
            .map(scalar -> "Const(" + scalar + ")")
            .orElseGet(() -> vertex.getClass().getSimpleName()
                + " with shape: " + Arrays.toString(vertex.getShape()));
    }

    /**
     * Returns the string form of the vertex's value when that value is a scalar {@link Tensor}.
     *
     * @return the scalar as a string, or empty when the value is {@code null}, is not a
     *     {@link Tensor}, or is not scalar
     */
    @SuppressWarnings("rawtypes") // Tensor's type arguments are irrelevant here; only toString() is used.
    private static Optional<String> tryGetScalarValue(Vertex<?> vertex) {
        Object value = vertex.getValue();
        // A missing or non-tensor value must not abort the description; the caller falls back
        // to the class-name-with-shape form.
        if (!(value instanceof Tensor)) {
            return Optional.empty();
        }
        Tensor tensor = (Tensor) value;
        return tensor.isScalar() ? Optional.of(tensor.scalar().toString()) : Optional.empty();
    }

    /**
     * Handles vertices that are not rendered by the plain "join the parents" rule:
     * if-vertices and boolean binary operators.
     *
     * @return the special-case description, or empty when {@code vertex} is neither
     */
    private Optional<String> checkForIrregularExpressions(Vertex<?> vertex, boolean includeBrackets) {
        if (vertex instanceof BooleanIfVertex || vertex instanceof DoubleIfVertex || vertex instanceof IntegerIfVertex) {
            return createIfStringDescription(vertex, includeBrackets);
        }
        if (vertex instanceof BooleanBinaryOpVertex) {
            String operator = delimiters.getOrDefault(vertex.getClass(), DEFAULT_DELIMITER);
            return Optional.of(
                createBooleanBinaryOpDescription((BooleanBinaryOpVertex<?, ?>) vertex, operator, includeBrackets));
        }
        return Optional.empty();
    }

    /**
     * Renders an integer, boolean or double if-vertex as {@code predicate ? thn : els}.
     *
     * @return the ternary-style description, or empty when {@code vertex} is not an if-vertex
     */
    private Optional<String> createIfStringDescription(Vertex<?> vertex, boolean includeBrackets) {
        BooleanVertex predicate;
        Vertex<?> thn;
        Vertex<?> els;
        if (vertex instanceof IntegerIfVertex) {
            IntegerIfVertex ifVertex = (IntegerIfVertex) vertex;
            predicate = ifVertex.getPredicate();
            thn = ifVertex.getThn();
            els = ifVertex.getEls();
        } else if (vertex instanceof BooleanIfVertex) {
            BooleanIfVertex ifVertex = (BooleanIfVertex) vertex;
            predicate = ifVertex.getPredicate();
            thn = ifVertex.getThn();
            els = ifVertex.getEls();
        } else if (vertex instanceof DoubleIfVertex) {
            DoubleIfVertex ifVertex = (DoubleIfVertex) vertex;
            predicate = ifVertex.getPredicate();
            thn = ifVertex.getThn();
            els = ifVertex.getEls();
        } else {
            return Optional.empty();
        }

        StringBuilder description = new StringBuilder();
        if (includeBrackets) {
            description.append("(");
        }
        description.append(generateDescription(predicate, true, true));
        description.append(" ? ");
        description.append(generateDescription(thn, true, true));
        description.append(" : ");
        description.append(generateDescription(els, true, true));
        if (includeBrackets) {
            description.append(")");
        }
        return Optional.of(description.toString());
    }

    /**
     * Renders {@code vertex} as {@code SimpleClassName(name1=desc1, name2=desc2, ...)} using its
     * public {@link SaveVertexParam}-annotated getters, where each name is the annotation value.
     *
     * <p>This is best-effort: it returns empty (so the caller falls back to the generic
     * description) when the class has no annotated getters, when a getter cannot be invoked, or
     * when a getter returns a non-null object that is not a {@link Vertex}.
     *
     * @param vertex the vertex to describe
     * @param includeBrackets whether to wrap the whole result in parentheses
     * @return the annotation-driven description, or empty when it cannot be built
     */
    private Optional<String> tryCreateDescriptionFromSaveLoadAnnotations(Vertex<?> vertex, boolean includeBrackets) {
        // NOTE(review): Class.getMethods() documents no particular ordering, so the order of the
        // rendered parameters is whatever the JVM returns - confirm whether callers need it stable.
        Method[] paramGetters = Arrays.stream(vertex.getClass().getMethods())
            .filter(method -> method.isAnnotationPresent(SaveVertexParam.class))
            .toArray(Method[]::new);
        if (paramGetters.length == 0) {
            return Optional.empty();
        }

        StringBuilder description = new StringBuilder();
        if (includeBrackets) {
            description.append("(");
        }
        description.append(vertex.getClass().getSimpleName()).append("(");
        try {
            boolean first = true;
            for (Method getter : paramGetters) {
                Object param = getter.invoke(vertex);
                if (param != null && !(param instanceof Vertex)) {
                    // Only vertex-valued parameters can be described recursively; casting
                    // anything else would throw ClassCastException, so give up on this form.
                    return Optional.empty();
                }
                if (!first) {
                    description.append(", ");
                }
                first = false;
                description.append(getter.getAnnotation(SaveVertexParam.class).value()).append("=");
                description.append(generateDescription((Vertex<?>) param, true, true));
            }
        } catch (IllegalAccessException | InvocationTargetException ignored) {
            // Deliberately ignored: a getter that cannot be called just means this
            // representation is unavailable and the generic one is used instead.
            return Optional.empty();
        }
        description.append(")");
        if (includeBrackets) {
            description.append(")");
        }
        return Optional.of(description.toString());
    }

    /**
     * Renders a boolean binary operator as {@code left <operator> right}.
     *
     * @param opVertex the binary operator vertex
     * @param operator the infix text placed between the operands
     * @param includeBrackets whether to wrap the result in parentheses
     * @return the rendered expression
     */
    private String createBooleanBinaryOpDescription(BooleanBinaryOpVertex<?, ?> opVertex, String operator, boolean includeBrackets) {
        StringBuilder description = new StringBuilder();
        if (includeBrackets) {
            description.append("(");
        }
        // NOTE(review): operands inherit this vertex's bracket flag, unlike
        // getDelimiterVertexDescription which always brackets its operands. At the top level
        // (includeBrackets == false) nested compound operands are therefore rendered without
        // parentheses - confirm this is intended before changing, as it alters output strings.
        description.append(generateDescription(opVertex.getLeft(), true, includeBrackets));
        description.append(operator);
        description.append(generateDescription(opVertex.getRight(), true, includeBrackets));
        if (includeBrackets) {
            description.append(")");
        }
        return description.toString();
    }
}
