package io.improbable.keanu.util;

import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.vertices.SaveVertexParam;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.BooleanIfVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.AndBinaryVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.BooleanBinaryOpVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.OrBinaryVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.EqualsVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.GreaterThanOrEqualVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.GreaterThanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.LessThanOrEqualVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.LessThanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.operators.binary.compare.NotEqualsVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.DoubleIfVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.AdditionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DifferenceVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.DivisionVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.binary.MultiplicationVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.IntegerIfVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.binary.IntegerAdditionVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.binary.IntegerDifferenceVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.operators.binary.IntegerMultiplicationVertex;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

/**
 * Builds human-readable equations describing how a vertex is calculated from its ancestors,
 * e.g. {@code "This Vertex = that + (three * Const(4))"}.
 */
public class DescriptionCreator {

    /** Name used for the described vertex when it has no label of its own. */
    private static final String THIS_VERTEX = "This Vertex";
    /** Text emitted in place of a missing (null) vertex. */
    private static final String NULL_STRING = "Null";
    /** Separator used when a vertex has no dedicated infix operator. */
    private static final String DEFAULT_DELIMITER = ", ";

    /**
     * Infix operators for common binary vertices, keyed by exact vertex class.
     * Lookups use {@code vertex.getClass()}, so subclasses of these vertices are not matched.
     */
    private static final Map<Class<?>, String> DELIMITERS = createDelimiters();

    /**
     * Creates a description creator. All operator metadata is shared and read-only.
     */
    public DescriptionCreator() {
    }

    /** Builds the class-to-infix-operator table used for common binary vertices. */
    private static Map<Class<?>, String> createDelimiters() {
        Map<Class<?>, String> delimiters = new HashMap<>();
        delimiters.put(AdditionVertex.class, " + ");
        delimiters.put(IntegerAdditionVertex.class, " + ");
        delimiters.put(DifferenceVertex.class, " - ");
        delimiters.put(IntegerDifferenceVertex.class, " - ");
        delimiters.put(MultiplicationVertex.class, " * ");
        delimiters.put(IntegerMultiplicationVertex.class, " * ");
        delimiters.put(DivisionVertex.class, " / ");
        delimiters.put(AndBinaryVertex.class, " && ");
        delimiters.put(EqualsVertex.class, " == ");
        delimiters.put(GreaterThanOrEqualVertex.class, " >= ");
        delimiters.put(GreaterThanVertex.class, " > ");
        delimiters.put(LessThanOrEqualVertex.class, " <= ");
        delimiters.put(LessThanVertex.class, " < ");
        delimiters.put(NotEqualsVertex.class, " != ");
        delimiters.put(OrBinaryVertex.class, " || ");
        return delimiters;
    }

    /**
     * This method constructs an equation to describe how a vertex is calculated.
     * The description is generated by recursively stepping up through the BayesNet and generating descriptions.
     * Descriptions of common vertices will use infix operators.
     * Descriptions will not recurse any further than labelled vertices.
     * <p>
     * It is suggested that to use this feature, you label as many vertices as possible to avoid complex descriptions.
     *
     * @param vertex The vertex to create the description of (may be null)
     * @return An String equation describing how this vertex is calculated.<br>
     * E.g. "This Vertex = that + (three * Const(4))"
     */
    public String createDescription(Vertex<?> vertex) {
        if (vertex == null) {
            return THIS_VERTEX + " = " + NULL_STRING;
        }
        Collection<Vertex> parents = vertex.getParents();

        if (parents.isEmpty()) {
            // A root vertex is always introduced as "This Vertex"; its own label (if any) goes on the right.
            return THIS_VERTEX + " = " + getLeafDescription(vertex);
        }

        String thisLabel = vertex.getLabel() != null ? vertex.getLabel().toString() : THIS_VERTEX;

        // Labels are disallowed at the top level, otherwise the result would read "label = label".
        return thisLabel + " = " + generateDescription(vertex, false, false);
    }

    /**
     * Recursively describes {@code vertex}.
     *
     * @param vertex          vertex to describe; null yields {@code "Null"}
     * @param allowLabels     if true, a labelled vertex is described by its label alone (recursion stops)
     * @param includeBrackets if true, compound descriptions are wrapped in parentheses
     * @return the description text
     */
    private String generateDescription(Vertex<?> vertex, boolean allowLabels, boolean includeBrackets) {
        if (vertex == null) {
            return NULL_STRING;
        }
        if (allowLabels && vertex.getLabel() != null) {
            return vertex.getLabel().toString();
        }
        if (vertex.getParents().isEmpty()) {
            return getLeafDescription(vertex);
        }

        Optional<String> irregularDescription = checkForIrregularExpressions(vertex, includeBrackets);
        if (irregularDescription.isPresent()) {
            return irregularDescription.get();
        }

        String delimiter = DELIMITERS.get(vertex.getClass());
        if (delimiter != null) {
            return getDelimiterVertexDescription(vertex, delimiter, includeBrackets);
        }

        return tryCreateDescriptionFromSaveLoadAnnotations(vertex, includeBrackets)
            .orElseGet(() -> getDelimiterVertexDescription(vertex, DEFAULT_DELIMITER, includeBrackets));
    }

    /**
     * Joins the (always bracketed) descriptions of the vertex's parents with {@code delimiter}.
     */
    private String getDelimiterVertexDescription(Vertex<?> vertex, CharSequence delimiter, boolean includeBrackets) {
        Stream<String> parentDescriptions = vertex
            .getParents()
            .stream()
            .map(parent -> generateDescription(parent, true, true));

        String[] parentStrings = parentDescriptions.toArray(String[]::new);
        return wrapInBrackets(String.join(delimiter, parentStrings), includeBrackets);
    }

    /**
     * Describes a parentless vertex: its label if present, {@code Const(x)} for a scalar tensor value,
     * otherwise its class name and shape.
     */
    private static String getLeafDescription(Vertex<?> vertex) {
        if (vertex.getLabel() != null) {
            return vertex.getLabel().toString();
        }
        Optional<String> scalarValue = tryGetScalarValue(vertex);
        if (scalarValue.isPresent()) {
            return "Const(" + scalarValue.get() + ")";
        }
        return vertex.getClass().getSimpleName() + " with shape: " + Arrays.toString(vertex.getShape());
    }

    /**
     * Returns the string form of the vertex's value if it is a scalar {@link Tensor}, otherwise empty.
     */
    private static Optional<String> tryGetScalarValue(Vertex<?> vertex) {
        Object value = vertex.getValue();
        // Guard against a null or non-Tensor value instead of failing with NPE / ClassCastException.
        if (!(value instanceof Tensor)) {
            return Optional.empty();
        }
        Tensor tensor = (Tensor) value;
        if (tensor.isScalar()) {
            return Optional.of(String.valueOf(tensor.scalar()));
        }
        return Optional.empty();
    }

    /**
     * Handles vertices whose description is not a plain delimiter join: if-vertices ({@code p ? a : b})
     * and boolean binary operators.
     *
     * @return the description, or empty if {@code vertex} is not one of these kinds
     */
    private Optional<String> checkForIrregularExpressions(Vertex<?> vertex, boolean includeBrackets) {
        if (vertex instanceof BooleanIfVertex || vertex instanceof DoubleIfVertex || vertex instanceof IntegerIfVertex) {
            return createIfStringDescription(vertex, includeBrackets);
        }
        if (vertex instanceof BooleanBinaryOpVertex) {
            String operation = DELIMITERS.getOrDefault(vertex.getClass(), DEFAULT_DELIMITER);
            return Optional.of(createBooleanBinaryOpDescription((BooleanBinaryOpVertex) vertex, operation, includeBrackets));
        }
        return Optional.empty();
    }

    /**
     * Describes an if-vertex as {@code predicate ? thn : els}, each part bracketed if compound.
     *
     * @return the description, or empty if {@code vertex} is not a supported if-vertex type
     */
    private Optional<String> createIfStringDescription(Vertex<?> vertex, boolean includeBrackets) {
        BooleanVertex predicate;
        Vertex<?> thn;
        Vertex<?> els;

        if (vertex instanceof IntegerIfVertex) {
            predicate = ((IntegerIfVertex) vertex).getPredicate();
            thn = ((IntegerIfVertex) vertex).getThn();
            els = ((IntegerIfVertex) vertex).getEls();
        } else if (vertex instanceof BooleanIfVertex) {
            predicate = ((BooleanIfVertex) vertex).getPredicate();
            thn = ((BooleanIfVertex) vertex).getThn();
            els = ((BooleanIfVertex) vertex).getEls();
        } else if (vertex instanceof DoubleIfVertex) {
            predicate = ((DoubleIfVertex) vertex).getPredicate();
            thn = ((DoubleIfVertex) vertex).getThn();
            els = ((DoubleIfVertex) vertex).getEls();
        } else {
            return Optional.empty();
        }

        String description = generateDescription(predicate, true, true)
            + " ? " + generateDescription(thn, true, true)
            + " : " + generateDescription(els, true, true);
        return Optional.of(wrapInBrackets(description, includeBrackets));
    }

    /**
     * Describes a vertex as {@code ClassName(param=value, ...)} using its public zero-argument
     * {@link SaveVertexParam}-annotated getters.
     *
     * @return the description, or empty if there are no such getters or one cannot be invoked
     */
    private Optional<String> tryCreateDescriptionFromSaveLoadAnnotations(Vertex<?> vertex, boolean includeBrackets) {
        Method[] saveLoadMethods = Arrays.stream(vertex.getClass().getMethods())
            .filter(method -> method.isAnnotationPresent(SaveVertexParam.class))
            // invoke(vertex) passes no arguments, so only zero-argument methods can be called.
            .filter(method -> method.getParameterCount() == 0)
            .toArray(Method[]::new);

        if (saveLoadMethods.length == 0) {
            return Optional.empty();
        }

        String[] paramDescriptions = new String[saveLoadMethods.length];
        try {
            for (int i = 0; i < saveLoadMethods.length; i++) {
                Method method = saveLoadMethods[i];
                String paramName = method.getAnnotation(SaveVertexParam.class).value();
                paramDescriptions[i] = paramName + "=" + describeSaveLoadParam(method.invoke(vertex));
            }
        } catch (IllegalAccessException | InvocationTargetException ignored) {
            // Unreadable parameters: return empty so the caller falls back to listing the parents.
            return Optional.empty();
        }

        String call = vertex.getClass().getSimpleName() + "(" + String.join(", ", paramDescriptions) + ")";
        return Optional.of(wrapInBrackets(call, includeBrackets));
    }

    /**
     * Describes a value returned by a {@link SaveVertexParam}-annotated getter. Vertices (and null) are
     * described recursively; any other value is rendered with its string form, arrays element-wise.
     */
    private String describeSaveLoadParam(Object paramValue) {
        if (paramValue == null || paramValue instanceof Vertex) {
            return generateDescription((Vertex<?>) paramValue, true, true);
        }
        // Wrapping in a one-element array lets deepToString format primitive/nested arrays and plain
        // objects alike; the outer brackets it adds are then stripped.
        String wrapped = Arrays.deepToString(new Object[]{paramValue});
        return wrapped.substring(1, wrapped.length() - 1);
    }

    /**
     * Describes a boolean binary operator as {@code left<operation>right}.
     * Operands are always bracketed when compound, so grouping survives even at the top level,
     * e.g. {@code a && (b || c)} rather than the misleading {@code a && b || c}.
     */
    private String createBooleanBinaryOpDescription(BooleanBinaryOpVertex<?, ?> opVertex, String operation, boolean includeBrackets) {
        String description = generateDescription(opVertex.getLeft(), true, true)
            + operation
            + generateDescription(opVertex.getRight(), true, true);
        return wrapInBrackets(description, includeBrackets);
    }

    /** Surrounds {@code text} with parentheses when {@code includeBrackets} is set. */
    private static String wrapInBrackets(String text, boolean includeBrackets) {
        return includeBrackets ? "(" + text + ")" : text;
    }
}
