/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.templating.loop;

import com.google.common.collect.ImmutableMap;
import io.improbable.keanu.templating.Sequence;
import io.improbable.keanu.templating.SequenceBuilder;
import io.improbable.keanu.templating.SequenceItem;
import io.improbable.keanu.templating.loop.LoopBuilder;
import io.improbable.keanu.templating.loop.LoopConstructionException;
import io.improbable.keanu.templating.loop.LoopDidNotTerminateException;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.vertices.SimpleVertexDictionary;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexDictionary;
import io.improbable.keanu.vertices.VertexLabel;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A loop whose iterations are stored as the items of a {@link Sequence}.
 *
 * <p>The loop's result is the vertex stored under {@link #VALUE_OUT_LABEL} in the last item of the
 * sequence (see {@link #getOutput()}). The constructor is package-private; instances are presumably
 * produced by the {@link LoopBuilder} returned from {@link #withInitialConditions} — confirm in
 * {@code LoopBuilder}.
 */
public class Loop {
    private static final Logger log = LoggerFactory.getLogger(Loop.class);

    /** Label of the loop's output value; the first base-case vertex is registered under this label. */
    public static final VertexLabel VALUE_OUT_LABEL = new VertexLabel("loop_value_out");

    /** Label for the loop condition vertex (not referenced in this class; presumably used by {@code LoopBuilder}). */
    public static final VertexLabel CONDITION_LABEL = new VertexLabel("loop_condition");

    /** Proxy label for {@link #VALUE_OUT_LABEL}, as produced by {@link SequenceBuilder#proxyLabelFor}. */
    public static final VertexLabel VALUE_IN_LABEL = SequenceBuilder.proxyLabelFor(VALUE_OUT_LABEL);

    /**
     * Proxy label of the boolean vertex that indicates whether the loop is still running. If that
     * vertex is not all-false in the final sequence item, the loop is treated as not terminated.
     */
    public static final VertexLabel STILL_LOOPING_LABEL = SequenceBuilder.proxyLabelFor(new VertexLabel("loop"));

    /** Default iteration cap (not referenced in this class; presumably the {@code LoopBuilder} default). */
    public static final int DEFAULT_MAX_COUNT = 100;

    private final Sequence sequence;
    private final boolean throwWhenMaxCountIsReached;

    /**
     * @param sequence                   the unrolled iterations of this loop
     * @param throwWhenMaxCountIsReached whether {@link #getOutput()} throws (rather than logs) when the
     *                                   loop did not terminate
     */
    Loop(Sequence sequence, boolean throwWhenMaxCountIsReached) {
        this.sequence = sequence;
        this.throwWhenMaxCountIsReached = throwWhenMaxCountIsReached;
    }

    /**
     * @return the sequence holding this loop's iterations
     */
    public Sequence getSequence() {
        return this.sequence;
    }

    /**
     * Starts building a loop from explicit base-case vertices.
     *
     * <p>{@code first} is registered under {@link #VALUE_OUT_LABEL}. Each vertex in {@code others} is
     * registered under its own label, or under a synthetic {@code base_case_vertex_<hashCode>} label
     * if it has none.
     *
     * @param first  the initial output value of the loop
     * @param others additional base-case vertices
     * @return a builder seeded with these initial conditions
     * @throws LoopConstructionException if two base-case vertices end up with the same label
     */
    @SafeVarargs // the varargs array is only iterated, never stored or exposed
    public static <V extends Vertex<?>> LoopBuilder withInitialConditions(V first, V... others) {
        Map<VertexLabel, Vertex<?>> baseCase = buildMapForBaseCase(first, others);
        return withInitialConditions(SimpleVertexDictionary.backedBy(baseCase));
    }

    /**
     * Starts building a loop from a pre-built dictionary of initial-state vertices.
     *
     * @param initialState the base-case vertices, keyed by label
     * @return a builder seeded with these initial conditions
     */
    public static LoopBuilder withInitialConditions(VertexDictionary initialState) {
        return new LoopBuilder(initialState);
    }

    /**
     * Builds the immutable label-to-vertex map used as the loop's base case.
     *
     * @throws LoopConstructionException if any label occurs more than once
     */
    private static <V extends Vertex<?>> Map<VertexLabel, Vertex<?>> buildMapForBaseCase(V first, V[] others) {
        ImmutableMap.Builder<VertexLabel, Vertex<?>> baseCaseMap = ImmutableMap.builder();
        baseCaseMap.put(VALUE_OUT_LABEL, first);
        for (V vertex : others) {
            VertexLabel label = vertex.getLabel();
            if (label == null) {
                // Unlabelled vertices still need a unique key in the dictionary.
                label = new VertexLabel(String.format("base_case_vertex_%d", vertex.hashCode()));
            }
            baseCaseMap.put(label, vertex);
        }
        try {
            // Guava's ImmutableMap.Builder rejects duplicate keys at build() time with IllegalArgumentException.
            return baseCaseMap.build();
        } catch (IllegalArgumentException e) {
            LoopConstructionException duplicate =
                new LoopConstructionException("Duplicate label found in base case");
            // Keep Guava's message (which names the colliding key) as the cause for debugging.
            duplicate.initCause(e);
            throw duplicate;
        }
    }

    /**
     * Returns the loop's output: the vertex labelled {@link #VALUE_OUT_LABEL} in the last sequence item.
     *
     * <p>If the loop was still running at its last item, this either throws or logs a warning,
     * depending on how the loop was configured.
     *
     * @return the output vertex of the final iteration
     * @throws LoopDidNotTerminateException if the loop did not terminate and throwing was requested
     */
    public <V extends Vertex<?>> V getOutput() throws LoopDidNotTerminateException {
        SequenceItem finalItem = this.sequence.getLastItem();
        this.checkIfMaxCountHasBeenReached(finalItem);
        return finalItem.get(VALUE_OUT_LABEL);
    }

    /**
     * Checks whether the loop was still running at {@code item}.
     *
     * @param item the sequence item to inspect (the final one, when called from {@link #getOutput()})
     * @throws LoopDidNotTerminateException if still running and {@code throwWhenMaxCountIsReached} is set
     */
    private void checkIfMaxCountHasBeenReached(SequenceItem item) throws LoopDidNotTerminateException {
        BooleanVertex stillLooping = item.get(STILL_LOOPING_LABEL);
        if (stillLooping.getValue().allFalse()) {
            return;
        }
        String errorMessage = "Loop has exceeded its max count " + this.sequence.size();
        if (this.throwWhenMaxCountIsReached) {
            throw new LoopDidNotTerminateException(errorMessage);
        }
        // The caller opted out of the exception, but the output may be wrong, so log at warn rather than info.
        log.warn(errorMessage);
    }
}

