package io.improbable.keanu.templating.loop;

import com.google.common.collect.ImmutableMap;
import io.improbable.keanu.templating.SequenceBuilder;
import io.improbable.keanu.templating.SequenceItem;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexDictionary;
import io.improbable.keanu.vertices.VertexLabel;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.BooleanProxyVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.DoubleIfVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.DoubleProxyVertex;
import io.improbable.keanu.vertices.generic.nonprobabilistic.If;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Fluent builder for a {@link Loop}.
 *
 * <p>The loop is unrolled into {@code maxLoopCount} items of a {@link SequenceBuilder} sequence.
 * Configure it with {@link #withMaxIterations(int)}, {@link #doNotThrowWhenMaxCountIsReached()} and
 * {@link #mapping(VertexLabel, VertexLabel)}. Then supply the loop condition through one of the
 * {@code iterateWhile} overloads, and finally the loop body through {@code LoopBuilder2.apply}.
 */
public class LoopBuilder {
    /**
     * Process-wide source of ids used to namespace each loop's internal proxy vertex label.
     * An identity hash code is not guaranteed to be distinct between objects, so a counter is used
     * instead to make the namespace unique.
     */
    private static final AtomicLong LOOP_ID_COUNTER = new AtomicLong();

    private final VertexDictionary initialState;
    private final ImmutableMap.Builder<VertexLabel, VertexLabel> customMappings = ImmutableMap.builder();
    private int maxLoopCount = 100;
    private boolean throwWhenMaxCountIsReached = true;

    /**
     * Second stage of the builder. It holds the loop condition and the configuration captured from
     * the enclosing {@link LoopBuilder}, and creates the {@link Loop} once the body is supplied.
     */
    public class LoopBuilder2 {
        private final VertexDictionary initialState;
        private final Function<SequenceItem, BooleanVertex> conditionFunction;
        private final Map<VertexLabel, VertexLabel> customMappings;
        private final int maxLoopCount;
        private final boolean throwWhenMaxCountIsReached;
        // Label of the body's output from every iteration, whether or not the loop is still running.
        private final VertexLabel VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL = new VertexLabel("loop_value_out_when_always_true");
        // Label of the running "still looping" flag: the AND of every condition evaluated so far.
        private final VertexLabel LOOP_LABEL = new VertexLabel("loop");

        /**
         * @param initialState               base state; must contain a vertex labelled {@code Loop.VALUE_OUT_LABEL}
         * @param conditionFunction          produces the loop condition for each unrolled iteration
         * @param customMappings             passed to {@code SequenceBuilder.withTransitionMapping}
         * @param maxLoopCount               number of unrolled iterations
         * @param throwWhenMaxCountIsReached forwarded to the {@link Loop} constructor
         * @throws LoopConstructionException if the {@code Loop.VALUE_OUT_LABEL} lookup fails
         * @throws NullPointerException      if {@code conditionFunction} is null
         */
        LoopBuilder2(VertexDictionary initialState,
                     Function<SequenceItem, BooleanVertex> conditionFunction,
                     Map<VertexLabel, VertexLabel> customMappings,
                     int maxLoopCount,
                     boolean throwWhenMaxCountIsReached) {
            this.initialState = setInitialState(initialState);
            this.conditionFunction = Objects.requireNonNull(conditionFunction, "conditionFunction");
            this.customMappings = customMappings;
            this.maxLoopCount = maxLoopCount;
            this.throwWhenMaxCountIsReached = throwWhenMaxCountIsReached;
        }

        /**
         * Returns {@code state} extended with the two internal entries the unrolled loop needs.
         * The first is a proxy onto the {@code Loop.VALUE_OUT_LABEL} vertex, keyed by
         * {@code VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL}. The second is a constant {@code true} keyed by
         * {@code LOOP_LABEL}, so the loop starts in the "running" state.
         *
         * @throws LoopConstructionException wrapping the {@code IllegalArgumentException} or
         *                                   {@code NoSuchElementException} raised while building these entries
         */
        private VertexDictionary setInitialState(VertexDictionary state) {
            try {
                Vertex<?> baseValueOut = state.get(Loop.VALUE_OUT_LABEL);
                // The proxy's own label gets a unique namespace so several loops can coexist.
                VertexLabel uniqueLabel = VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL
                    .withExtraNamespace("Loop_" + LOOP_ID_COUNTER.getAndIncrement());
                DoubleProxyVertex valueOutWhenAlwaysTrue = new DoubleProxyVertex(uniqueLabel);
                valueOutWhenAlwaysTrue.setParents(baseValueOut);
                return state.withExtraEntries(ImmutableMap.of(
                    VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL, valueOutWhenAlwaysTrue,
                    LOOP_LABEL, ConstantVertex.of(true)));
            } catch (IllegalArgumentException e) {
                throw new LoopConstructionException("You must pass in only one vertex labeled as Loop.VALUE_OUT_LABEL", e);
            } catch (NoSuchElementException e) {
                throw new LoopConstructionException("You must pass in a base case, i.e. a vertex labeled as Loop.VALUE_OUT_LABEL", e);
            }
        }

        /**
         * Builds the loop using a body that only needs the previous iteration's value.
         *
         * @param iterationFunction maps the previous value to the next value
         * @throws NullPointerException if {@code iterationFunction} is null
         */
        public Loop apply(Function<DoubleVertex, DoubleVertex> iterationFunction) {
            Objects.requireNonNull(iterationFunction, "iterationFunction");
            return apply((item, previousValue) -> iterationFunction.apply(previousValue));
        }

        /**
         * Builds the loop using a body that may also read other vertices of the current sequence item.
         *
         * @param iterationFunction maps (current item, previous value) to the next value
         * @throws NullPointerException if {@code iterationFunction} is null
         */
        public Loop apply(BiFunction<SequenceItem, DoubleVertex, DoubleVertex> iterationFunction) {
            Objects.requireNonNull(iterationFunction, "iterationFunction");
            return new Loop(
                new SequenceBuilder()
                    .withInitialState(initialState)
                    .withTransitionMapping(customMappings)
                    .count(maxLoopCount)
                    .withFactory(item -> addIteration(item, iterationFunction))
                    .build(),
                throwWhenMaxCountIsReached);
        }

        /**
         * Wires one unrolled iteration into {@code item}.
         *
         * <p>The body is always evaluated on the previous "always true" value. The published
         * {@code Loop.VALUE_OUT_LABEL} takes the body's result only while every condition so far
         * has been true; otherwise it keeps the previous {@code Loop.VALUE_OUT_LABEL} value.
         * Statement order matters: the condition is computed and added before the body runs.
         */
        private void addIteration(SequenceItem item,
                                  BiFunction<SequenceItem, DoubleVertex, DoubleVertex> iterationFunction) {
            // Proxies for the previous step's outputs. Presumably SequenceBuilder connects them to the
            // previous item (or to the initial state for the first item) -- not visible here.
            DoubleProxyVertex previousAlwaysTrueValue = item.addDoubleProxyFor(VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL);
            BooleanProxyVertex previouslyLooping = item.addBooleanProxyFor(LOOP_LABEL);
            DoubleProxyVertex previousValueOut = item.addDoubleProxyFor(Loop.VALUE_OUT_LABEL);

            BooleanVertex condition = conditionFunction.apply(item);
            item.add(Loop.CONDITION_LABEL, condition);

            DoubleVertex iterationResult = iterationFunction.apply(item, previousAlwaysTrueValue);
            BooleanVertex stillLooping = previouslyLooping.and(condition);
            DoubleIfVertex valueOut = If.isTrue(stillLooping).then(iterationResult).orElse(previousValueOut);

            item.add(VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL, iterationResult);
            item.add(LOOP_LABEL, stillLooping);
            item.add(Loop.VALUE_OUT_LABEL, valueOut);
        }
    }

    /**
     * Creates a builder over the given initial state. The state is expected to contain exactly one
     * vertex labelled {@code Loop.VALUE_OUT_LABEL}. That requirement is checked when
     * {@code iterateWhile} is called.
     *
     * <p>The type parameter {@code V} is unused. It is retained only for source compatibility.
     */
    public <V extends Vertex<?>> LoopBuilder(VertexDictionary initialState) {
        this.initialState = initialState;
    }

    /**
     * Sets how many iterations the loop is unrolled into. The default is 100.
     */
    public LoopBuilder withMaxIterations(int maxLoopCount) {
        this.maxLoopCount = maxLoopCount;
        return this;
    }

    /**
     * Sets the flag forwarded to the {@link Loop} constructor to {@code false}. Presumably this stops
     * {@code Loop} from throwing when the iteration limit is reached -- see {@code Loop} to confirm.
     */
    public LoopBuilder doNotThrowWhenMaxCountIsReached() {
        this.throwWhenMaxCountIsReached = false;
        return this;
    }

    /**
     * Adds an entry ({@code key} to {@code value}) to the transition mapping later passed to
     * {@code SequenceBuilder.withTransitionMapping}.
     *
     * <p>Entries are collected in a Guava {@code ImmutableMap.Builder}. Adding the same key twice
     * therefore makes {@code iterateWhile} throw {@code IllegalArgumentException}, and a null
     * argument throws {@code NullPointerException} here.
     */
    public LoopBuilder mapping(VertexLabel key, VertexLabel value) {
        this.customMappings.put(key, value);
        return this;
    }

    /**
     * Supplies a loop condition that does not depend on the sequence item. The supplier is invoked
     * once for every unrolled iteration.
     *
     * @throws NullPointerException if {@code conditionSupplier} is null
     */
    public LoopBuilder2 iterateWhile(Supplier<BooleanVertex> conditionSupplier) {
        Objects.requireNonNull(conditionSupplier, "conditionSupplier");
        return iterateWhile(item -> conditionSupplier.get());
    }

    /**
     * Supplies the loop condition, evaluated once for every unrolled iteration against that
     * iteration's sequence item.
     *
     * @throws LoopConstructionException if the {@code Loop.VALUE_OUT_LABEL} lookup in the initial state fails
     * @throws NullPointerException      if {@code conditionFunction} is null
     */
    public LoopBuilder2 iterateWhile(Function<SequenceItem, BooleanVertex> conditionFunction) {
        return new LoopBuilder2(initialState, conditionFunction, customMappings.build(),
            maxLoopCount, throwWhenMaxCountIsReached);
    }
}
