package io.improbable.keanu.backend.keanu.compiled;

import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.backend.ComputableGraph;
import io.improbable.keanu.backend.ComputableGraphBuilder;
import io.improbable.keanu.backend.StringVariableReference;
import io.improbable.keanu.backend.keanu.compiled.KeanuVertexToTensorOpMapper;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.util.csv.Writer;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.ConstantBooleanVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.ConstantDoubleVertex;
import io.improbable.keanu.vertices.generic.GenericTensorVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.nonprobabilistic.ConstantIntegerVertex;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.joor.Reflect;

/* loaded from: input_file:io/improbable/keanu/backend/keanu/compiled/KeanuCompiledGraphBuilder.class */
/**
 * Builds a {@link ComputableGraph} by generating Java source for the visited vertices and compiling it at
 * runtime with jOOR ({@link Reflect#compile}).
 *
 * <p>The generated class implements {@code java.util.function.Function<Map<String, ?>, Map<String, ?>>}:
 * <ul>
 *   <li>constants become {@code private final} fields assigned in the generated constructor from a map keyed
 *       by {@link VariableReference#toStringReference()};</li>
 *   <li>variables are read (and cast) from the {@code inputs} map passed to {@code apply}, using the same keys;</li>
 *   <li>every other vertex becomes a {@code final} local whose initializer is produced by
 *       {@link KeanuVertexToTensorOpMapper};</li>
 *   <li>registered outputs are returned from {@code apply} in a map keyed by their string reference.</li>
 * </ul>
 *
 * <p>NOTE(review): string references are spliced into the generated source unescaped (as {@code "v_" + ref}
 * identifiers and inside string literals), so this assumes {@code toStringReference()} always yields
 * identifier-safe text without quotes or backslashes — confirm.
 */
public class KeanuCompiledGraphBuilder implements ComputableGraphBuilder<ComputableGraph> {
    private static final String PACKAGE = "io.improbable.keanu.backend.keanu";
    private static final String CLASS_NAME_PREFIX = "CompiledKeanuGraph";

    /** Counter used to give unique names ({@code vv_0}, {@code vv_1}, ...) to intermediates created by {@link #add}. */
    private int internalOpCount = 0;

    /**
     * Simple name of the generated class. The identity hash is rendered unsigned so the name can never
     * contain a '-' and is always a valid Java identifier.
     */
    private final String className = CLASS_NAME_PREFIX + Integer.toUnsignedString(System.identityHashCode(this));

    /** Statements emitted into the body of the generated {@code apply} method. */
    private final StringBuilder computeSourceBuilder = new StringBuilder();
    /** Field declarations (one per constant) of the generated class. */
    private final StringBuilder instanceVariableBuilder = new StringBuilder();
    /** Statements emitted into the generated constructor (constant field assignments). */
    private final StringBuilder constructorBuilder = new StringBuilder();
    /** Maps each known reference to the name of the generated-source variable holding its value. */
    private final Map<VariableReference, KeanuCompiledVariable> lookup = new HashMap<>();
    /** Values of the vertices registered via {@link #createVariable}; its keys are the latent variables. */
    private final Map<VariableReference, Object> variableValues = new HashMap<>();
    /** Values of the vertices registered via {@link #createConstant}; passed to the generated constructor. */
    private final Map<VariableReference, Object> constantValues = new HashMap<>();
    /** References whose values are returned by the generated {@code apply}, in registration order. */
    private final List<VariableReference> outputs = new ArrayList<>();

    /** Emits the package declaration, imports and class header of the generated source. */
    private void startSource(StringBuilder sb) {
        append(sb, "package ", PACKAGE, ";\n");
        sb.append(importString(Collection.class));
        sb.append(importString(Collections.class));
        sb.append(importString(HashMap.class));
        sb.append(importString(Map.class));
        sb.append(importString(VariableReference.class));
        sb.append(importString(DoubleTensor.class));
        sb.append(importString(IntegerTensor.class));
        sb.append(importString(BooleanTensor.class));
        append(sb, "public final class ", className, " implements java.util.function.Function<Map<String, ?>, Map<String, ?>> {\n");
    }

    /** Returns an {@code import} line for the given class. */
    private String importString(Class<?> cls) {
        return "import " + cls.getCanonicalName() + ";\n";
    }

    /** Emits the collection of outputs into a result map and closes the {@code apply} method and the class. */
    private void endSource(StringBuilder sb) {
        sb.append("Map<String, Object>  results = new HashMap<>();\n");
        sb.append("\n");
        for (VariableReference output : outputs) {
            append(sb, "results.put(\"", output.toStringReference(), "\", ", lookupVariable(output).getName(), ");\n");
        }
        sb.append("return results;\n");
        sb.append("}\n}\n");
    }

    /**
     * Declares the vertex as a final field of the generated class, assigned in the constructor from the
     * constants map, and records its current value to be passed in at {@link #build()} time.
     */
    @Override
    public void createConstant(Vertex vertex) {
        String assignmentType = getAssignmentType(vertex);
        String stringReference = vertex.getReference().toStringReference();
        String sourceVariableName = toSourceVariableName(vertex.getReference());
        append(instanceVariableBuilder, "private final ", assignmentType, " ", sourceVariableName, ";\n");
        append(constructorBuilder, sourceVariableName, " = ", "(", assignmentType, ")", "constants.get(\"", stringReference, "\");\n");
        lookup.put(vertex.getReference(), new KeanuCompiledVariable(sourceVariableName, false));
        constantValues.put(vertex.getReference(), vertex.getValue());
    }

    /**
     * Declares the vertex as an input read from the {@code inputs} map of the generated {@code apply} and
     * records it (with its current value) as a latent variable.
     */
    @Override
    public void createVariable(Vertex vertex) {
        String assignmentType = getAssignmentType(vertex);
        String sourceVariableName = toSourceVariableName(vertex.getReference());
        declareInput(assignmentType, sourceVariableName, vertex.getReference().toStringReference());
        lookup.put(vertex.getReference(), new KeanuCompiledVariable(sourceVariableName, false));
        variableValues.put(vertex.getReference(), vertex.getValue());
    }

    /**
     * Emits {@code final <type> <name> = (<type>) inputs.get("<key>");} into the {@code apply} body.
     *
     * @param type       canonical name of the declared (and cast-to) type
     * @param variableName generated-source variable name
     * @param inputKey   key looked up in the {@code inputs} map
     */
    private void declareInput(String type, String variableName, String inputKey) {
        append(computeSourceBuilder, "final ", type, " ", variableName, " = (", type, ") inputs.get(\"", inputKey, "\");\n");
    }

    /**
     * Emits the operation for a non-constant vertex; constant vertices are delegated to {@link #createConstant}.
     *
     * @throws IllegalArgumentException if no op mapper is registered for the vertex class
     */
    @Override
    public void create(Vertex vertex) {
        if (isConstant(vertex)) {
            createConstant(vertex);
            return;
        }
        KeanuVertexToTensorOpMapper.OpMapper opMapper = KeanuVertexToTensorOpMapper.getOpMapperFor(vertex.getClass());
        if (opMapper == null) {
            throw new IllegalArgumentException("No op mapper available for vertex type " + vertex.getClass().getName());
        }
        String assignmentType = getAssignmentType(vertex);
        String sourceVariableName = toSourceVariableName(vertex.getReference());
        append(computeSourceBuilder, "final ", assignmentType, " ", sourceVariableName, " = ", opMapper.apply(vertex, lookup), ";\n");
        lookup.put(vertex.getReference(), new KeanuCompiledVariable(sourceVariableName, true));
    }

    /** True for the constant double/integer/boolean vertex types, which are compiled as constructor-assigned fields. */
    private boolean isConstant(Vertex vertex) {
        return vertex instanceof ConstantDoubleVertex
            || vertex instanceof ConstantIntegerVertex
            || vertex instanceof ConstantBooleanVertex;
    }

    /** Canonical name of the generated-source type used to hold the value of the given vertex. */
    private String getAssignmentType(Object vertex) {
        if (vertex instanceof DoubleVertex) {
            return DoubleTensor.class.getCanonicalName();
        }
        if (vertex instanceof IntegerVertex) {
            return IntegerTensor.class.getCanonicalName();
        }
        if (vertex instanceof BooleanVertex) {
            return BooleanTensor.class.getCanonicalName();
        }
        if (vertex instanceof GenericTensorVertex) {
            return Tensor.class.getCanonicalName();
        }
        return Object.class.getCanonicalName();
    }

    /** Generated-source variable name for a reference: {@code "v_"} followed by its string reference. */
    private String toSourceVariableName(VariableReference variableReference) {
        return "v_" + variableReference.toStringReference();
    }

    /**
     * Returns the generated variable registered for {@code reference}.
     *
     * @throws IllegalArgumentException if nothing has been registered for it yet
     */
    private KeanuCompiledVariable lookupVariable(VariableReference reference) {
        KeanuCompiledVariable variable = lookup.get(reference);
        if (variable == null) {
            throw new IllegalArgumentException("No compiled variable registered for reference " + reference.toStringReference());
        }
        return variable;
    }

    /**
     * Marks an already-created reference as an output of the compiled graph and flags its variable as not
     * mutable (NOTE(review): presumably so op mappers do not update it in place — confirm).
     *
     * @throws IllegalArgumentException if the reference has not been created yet
     */
    @Override
    public void registerOutput(VariableReference variableReference) {
        KeanuCompiledVariable variable = lookupVariable(variableReference);
        outputs.add(variableReference);
        variable.setMutable(false);
    }

    /** Read-only view of the references registered through {@link #createVariable}. */
    @Override
    public Collection<VariableReference> getLatentVariables() {
        return Collections.unmodifiableSet(variableValues.keySet());
    }

    /**
     * Emits {@code left.plus(right)} into a fresh {@code DoubleTensor} intermediate and returns its reference.
     *
     * @throws IllegalArgumentException if either operand has not been registered
     */
    @Override
    public VariableReference add(VariableReference left, VariableReference right) {
        String resultType = DoubleTensor.class.getCanonicalName();
        String leftName = lookupVariable(left).getName();
        String rightName = lookupVariable(right).getName();
        String resultName = "vv_" + internalOpCount++;
        append(computeSourceBuilder, "final ", resultType, " ", resultName, " = ", leftName, ".plus(", rightName, ");\n");
        StringVariableReference resultReference = new StringVariableReference(resultName);
        lookup.put(resultReference, new KeanuCompiledVariable(resultName, true));
        return resultReference;
    }

    /**
     * For each {@code (from, to)} pair, makes {@code to}'s reference resolve to the same generated variable
     * as {@code from}'s.
     */
    @Override
    public void connect(Map<? extends Vertex<?>, ? extends Vertex<?>> connections) {
        connections.forEach((from, to) -> lookup.put(to.getReference(), lookup.get(from.getReference())));
    }

    /** Assembles the complete Java source of the generated graph class. */
    public String getSource() {
        StringBuilder sb = new StringBuilder();
        startSource(sb);
        sb.append(instanceVariableBuilder);
        append(sb, "public ", className, "(final Map<String, ?> constants) {\n");
        sb.append(constructorBuilder);
        sb.append("}\n");
        sb.append("public Map<String, ?> apply(Map<String, ?> inputs) {\n");
        sb.append(computeSourceBuilder);
        endSource(sb);
        return sb.toString();
    }

    /** Appends each string to {@code sb} in order. */
    private void append(StringBuilder sb, String... parts) {
        for (String part : parts) {
            sb.append(part);
        }
    }

    /** Generates, compiles and instantiates the graph class. */
    @Override
    public ComputableGraph build() {
        return compile(getSource());
    }

    /**
     * Compiles {@code source} with jOOR, instantiates it with the constant values (keyed by string
     * reference) and wraps the resulting function.
     */
    @SuppressWarnings({"rawtypes", "unchecked"}) // the generated class is only known to be a Function at runtime
    private WrappedCompiledGraph compile(String source) {
        Map<String, Object> constantsByName = new HashMap<>();
        constantValues.forEach((reference, value) -> constantsByName.put(reference.toStringReference(), value));
        Function graphFunction = Reflect.compile(PACKAGE + "." + className, source)
            .create(new Object[]{constantsByName})
            .get();
        return new WrappedCompiledGraph(graphFunction, outputs);
    }
}
