/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.templating.loop;

import com.google.common.collect.ImmutableMap;
import io.improbable.keanu.templating.Sequence;
import io.improbable.keanu.templating.SequenceBuilder;
import io.improbable.keanu.templating.SequenceItem;
import io.improbable.keanu.templating.loop.Loop;
import io.improbable.keanu.templating.loop.LoopConstructionException;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexDictionary;
import io.improbable.keanu.vertices.VertexLabel;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.nonprobabilistic.BooleanProxyVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.DoubleIfVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.DoubleProxyVertex;
import io.improbable.keanu.vertices.generic.nonprobabilistic.If;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;

/**
 * Fluent builder for a {@link Loop}.
 *
 * <p>Usage has two stages. First, optionally configure this builder: an iteration cap, whether hitting the
 * cap should throw, and custom proxy mappings. Then call one of the {@code iterateWhile} methods to obtain a
 * {@link LoopBuilder2}, whose {@code apply} methods build the {@link Loop}.
 */
public class LoopBuilder {
    /** Iteration cap used unless {@link #withMaxIterations(int)} overrides it. */
    private static final int DEFAULT_MAX_LOOP_COUNT = 100;

    private final VertexDictionary initialState;
    // Accumulates proxyLabel -> proxysParentLabel pairs; turned into an immutable snapshot by iterateWhile.
    private final ImmutableMap.Builder<VertexLabel, VertexLabel> customMappings = ImmutableMap.builder();
    private int maxLoopCount = DEFAULT_MAX_LOOP_COUNT;
    private boolean throwWhenMaxCountIsReached = true;

    /**
     * Creates a builder for a loop that starts from {@code initialState}.
     *
     * @param initialState the loop's starting vertices. It is expected to hold exactly one vertex labelled
     *                     {@link Loop#VALUE_OUT_LABEL}. That requirement is checked when {@code iterateWhile}
     *                     is called, not here.
     */
    LoopBuilder(VertexDictionary initialState) {
        this.initialState = initialState;
    }

    /**
     * Overrides the maximum number of iterations (default 100).
     *
     * @param maxCount the iteration cap; must be at least 1
     * @return this builder
     * @throws IllegalArgumentException if {@code maxCount} is less than 1
     */
    public LoopBuilder withMaxIterations(int maxCount) {
        // The cap is passed straight to SequenceBuilder.count. A loop that may not iterate even once is never
        // meaningful, so reject it here, where the offending value is still known.
        if (maxCount < 1) {
            throw new IllegalArgumentException("maxCount must be at least 1, but was " + maxCount);
        }
        this.maxLoopCount = maxCount;
        return this;
    }

    /**
     * Makes the resulting {@link Loop} be constructed with {@code throwWhenMaxCountIsReached == false}
     * (the default is {@code true}).
     *
     * <p>Judging by the flag's name, this stops the loop from throwing when it reaches the iteration cap.
     * See {@link Loop} for the exact behaviour.
     *
     * @return this builder
     */
    public LoopBuilder doNotThrowWhenMaxCountIsReached() {
        this.throwWhenMaxCountIsReached = false;
        return this;
    }

    /**
     * Adds a custom transition mapping, which is handed to the underlying {@link SequenceBuilder}.
     *
     * <p>Each {@code proxyLabel} may be mapped only once. Duplicates make {@code iterateWhile} fail with an
     * {@link IllegalArgumentException} from {@link ImmutableMap.Builder#build()}.
     *
     * @param proxyLabel        label of the proxy vertex being mapped
     * @param proxysParentLabel label that {@code proxyLabel} is mapped to (the proxy's parent, per its name)
     * @return this builder
     */
    public LoopBuilder mapping(VertexLabel proxyLabel, VertexLabel proxysParentLabel) {
        this.customMappings.put(proxyLabel, proxysParentLabel);
        return this;
    }

    /**
     * Sets the loop condition from a supplier that does not need the current {@link SequenceItem}.
     *
     * <p>The supplier is invoked once per iteration, so it can create a fresh vertex each time.
     *
     * @param conditionSupplier produces the vertex that decides whether the loop continues
     * @return the next builder stage
     * @throws LoopConstructionException if the {@link Loop#VALUE_OUT_LABEL} vertex is missing or ambiguous
     * @throws IllegalArgumentException  if {@link #mapping} was called twice with the same proxy label
     */
    public LoopBuilder2 iterateWhile(Supplier<BooleanVertex> conditionSupplier) {
        return this.iterateWhile(item -> conditionSupplier.get());
    }

    /**
     * Sets the loop condition and snapshots this builder's configuration into the next builder stage.
     *
     * <p>Changes made to this builder afterwards do not affect the returned stage.
     *
     * @param conditionFunction called once per iteration with that iteration's {@link SequenceItem}; returns
     *                          the vertex that decides whether the loop continues
     * @return the next builder stage
     * @throws LoopConstructionException if the {@link Loop#VALUE_OUT_LABEL} vertex is missing or ambiguous
     * @throws IllegalArgumentException  if {@link #mapping} was called twice with the same proxy label
     */
    public LoopBuilder2 iterateWhile(Function<SequenceItem, BooleanVertex> conditionFunction) {
        return new LoopBuilder2(
            this.initialState,
            conditionFunction,
            this.customMappings.build(),
            this.maxLoopCount,
            this.throwWhenMaxCountIsReached);
    }

    /**
     * Second builder stage. It holds the loop condition plus the settings copied from {@link LoopBuilder},
     * and turns an iteration function into a {@link Loop}.
     *
     * <p>The class is {@code static} because it never reads the enclosing builder's fields. Everything it
     * needs arrives through its constructor, so there is no reason to keep the outer builder reachable.
     */
    public static class LoopBuilder2 {
        // Bookkeeping labels stored in the initial state and re-added by every iteration.
        private static final VertexLabel VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL =
            new VertexLabel("loop_value_out_when_always_true");
        private static final VertexLabel LOOP_LABEL = new VertexLabel("loop");

        private final VertexDictionary initialState;
        private final Function<SequenceItem, BooleanVertex> conditionFunction;
        private final Map<VertexLabel, VertexLabel> customMappings;
        private final int maxLoopCount;
        private final boolean throwWhenMaxCountIsReached;

        /**
         * Captures the loop settings and augments {@code initialState} with the loop's bookkeeping vertices.
         *
         * @throws LoopConstructionException if the {@link Loop#VALUE_OUT_LABEL} vertex is missing or ambiguous
         */
        LoopBuilder2(VertexDictionary initialState,
                     Function<SequenceItem, BooleanVertex> conditionFunction,
                     Map<VertexLabel, VertexLabel> customMappings,
                     int maxLoopCount,
                     boolean throwWhenMaxCountIsReached) {
            this.initialState = augmentInitialState(initialState);
            this.conditionFunction = conditionFunction;
            this.customMappings = customMappings;
            this.maxLoopCount = maxLoopCount;
            this.throwWhenMaxCountIsReached = throwWhenMaxCountIsReached;
        }

        /**
         * Returns {@code initialState} plus two entries:
         * <ul>
         *   <li>under {@code VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL}, a double proxy whose parent is the caller's
         *       {@link Loop#VALUE_OUT_LABEL} vertex;</li>
         *   <li>under {@code LOOP_LABEL}, a constant {@code true}, so the loop starts in the "still looping"
         *       state.</li>
         * </ul>
         *
         * @throws LoopConstructionException if looking up {@link Loop#VALUE_OUT_LABEL} fails
         */
        private VertexDictionary augmentInitialState(VertexDictionary initialState) {
            // Only the lookup is guarded. The messages below describe lookup failures, and exceptions from
            // building the proxy must not be misreported as a missing or duplicated base case.
            Vertex<?> outputVertex;
            try {
                outputVertex = initialState.get(Loop.VALUE_OUT_LABEL);
            } catch (NoSuchElementException e) {
                throw new LoopConstructionException("You must pass in a base case, i.e. a vertex labeled as Loop.VALUE_OUT_LABEL", e);
            } catch (IllegalArgumentException e) {
                throw new LoopConstructionException("You must pass in only one vertex labeled as Loop.VALUE_OUT_LABEL", e);
            }

            // NOTE(review): the namespace uses the identity hash code, presumably to keep the labels of separate
            // loops apart. Identity hash codes are not guaranteed unique.
            DoubleProxyVertex valueOutWhenAlwaysTrue = new DoubleProxyVertex(
                VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL.withExtraNamespace("Loop_" + this.hashCode()));
            valueOutWhenAlwaysTrue.setParents(outputVertex);

            return initialState.withExtraEntries(ImmutableMap.<VertexLabel, Vertex<?>>of(
                VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL, valueOutWhenAlwaysTrue,
                LOOP_LABEL, ConstantVertex.of(true)));
        }

        /**
         * Builds the loop from an iteration function that does not need the current {@link SequenceItem}.
         *
         * @param iterationFunction maps one iteration's input value to its output value
         * @return the constructed loop
         */
        public Loop apply(Function<DoubleVertex, DoubleVertex> iterationFunction) {
            return this.apply((item, valueIn) -> iterationFunction.apply(valueIn));
        }

        /**
         * Builds the loop as a sequence of {@code maxLoopCount} iterations, each wired by
         * {@link #addIterationVertices}.
         *
         * @param iterationFunction given the current item and the iteration's input value, returns the output
         * @return the constructed loop
         */
        public Loop apply(BiFunction<SequenceItem, DoubleVertex, DoubleVertex> iterationFunction) {
            Consumer<SequenceItem> factory = item -> addIterationVertices(item, iterationFunction);
            Sequence sequence = new SequenceBuilder()
                .withInitialState(this.initialState)
                .withTransitionMapping(this.customMappings)
                .count(this.maxLoopCount)
                .withFactory(factory)
                .build();
            return new Loop(sequence, this.throwWhenMaxCountIsReached);
        }

        /** Adds one iteration's vertices (condition, iteration result, gated output) to {@code item}. */
        private void addIterationVertices(SequenceItem item,
                                          BiFunction<SequenceItem, DoubleVertex, DoubleVertex> iterationFunction) {
            // Proxies for the bookkeeping and output labels. These are presumably resolved by the Sequence to
            // the previous item's vertices, or to the initial state for the first item.
            DoubleProxyVertex valueInWhenAlwaysTrue = item.addDoubleProxyFor(VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL);
            BooleanProxyVertex stillLooping = item.addBooleanProxyFor(LOOP_LABEL);
            DoubleProxyVertex valueIn = item.addDoubleProxyFor(Loop.VALUE_OUT_LABEL);

            BooleanVertex condition = this.conditionFunction.apply(item);
            item.add(Loop.CONDITION_LABEL, condition);

            // The iteration function is fed the "always true" chain rather than valueIn. That chain holds the
            // value the loop would have if every condition so far had been true, so the function's result does
            // not depend on earlier iterations' If vertices.
            DoubleVertex iterationResult = iterationFunction.apply(item, valueInWhenAlwaysTrue);

            // LOOP_LABEL starts as true and is AND-ed with each condition. Once a condition is false the flag
            // stays false, and VALUE_OUT just passes valueIn through from then on.
            BooleanVertex loopAgain = stillLooping.and(condition);
            DoubleIfVertex result = If.isTrue(loopAgain).then(iterationResult).orElse(valueIn);

            item.add(VALUE_OUT_WHEN_ALWAYS_TRUE_LABEL, iterationResult);
            item.add(LOOP_LABEL, loopAgain);
            item.add(Loop.VALUE_OUT_LABEL, result);
        }
    }
}

