/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.templating;

import com.google.common.collect.ImmutableMap;
import io.improbable.keanu.templating.Sequence;
import io.improbable.keanu.templating.SequenceConstructionException;
import io.improbable.keanu.templating.SequenceItem;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexDictionary;
import io.improbable.keanu.vertices.VertexLabel;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;
import java.util.function.BiConsumer;
import java.util.function.Consumer;

public class SequenceBuilder<T> {
    /** Namespace segment that marks a label as the proxy of another label. */
    private static final String PROXY_LABEL_MARKER = "proxy_for";

    /** Vertices the first item's proxies are connected to; {@code null} until withInitialState is called. */
    private VertexDictionary initialState;
    /** Explicit proxy-label to parent-label mapping; empty by default. */
    private Map<VertexLabel, VertexLabel> transitionMapping = Collections.emptyMap();
    /** Optional name passed to every Sequence and SequenceItem this builder creates. */
    private String sequenceName;

    /**
     * Returns the label under which a proxy for {@code label} is expected to be registered.
     *
     * @param label the label of the vertex being proxied
     * @return {@code label} with the {@code "proxy_for"} namespace added via {@code withExtraNamespace}
     */
    public static VertexLabel proxyLabelFor(VertexLabel label) {
        return label.withExtraNamespace(PROXY_LABEL_MARKER);
    }

    /**
     * Uses a single vertex as the initial state, wrapped with {@code VertexDictionary.of}.
     *
     * @param vertex the vertex the first item's proxies may connect to
     * @return this builder
     */
    public SequenceBuilder<T> withInitialState(Vertex<?> vertex) {
        return this.withInitialState(VertexDictionary.of(vertex));
    }

    /**
     * Sets the name passed to the built Sequence and each of its SequenceItems.
     *
     * @param sequenceName the sequence name
     * @return this builder
     */
    public SequenceBuilder<T> named(String sequenceName) {
        this.sequenceName = sequenceName;
        return this;
    }

    /**
     * Uses a single vertex, registered under {@code label}, as the initial state.
     *
     * @param label  the label the vertex is looked up by
     * @param vertex the vertex the first item's proxies may connect to
     * @return this builder
     */
    public SequenceBuilder<T> withInitialState(VertexLabel label, Vertex<?> vertex) {
        // No (Object) cast: the map must stay keyed by VertexLabel.
        return this.withInitialState(VertexDictionary.backedBy(ImmutableMap.of(label, vertex)));
    }

    /**
     * Sets the vertices the proxies of the first item are connected to.
     *
     * @param initialState the initial-state dictionary
     * @return this builder
     */
    public SequenceBuilder<T> withInitialState(VertexDictionary initialState) {
        this.initialState = initialState;
        return this;
    }

    /**
     * Sets explicit proxy-label to parent-label mappings. Proxies without an entry fall back to their
     * own label with the {@code "proxy_for"} namespace removed.
     *
     * @param transitionMapping map from (unscoped) proxy label to parent label
     * @return this builder
     */
    public SequenceBuilder<T> withTransitionMapping(Map<VertexLabel, VertexLabel> transitionMapping) {
        this.transitionMapping = transitionMapping;
        return this;
    }

    /**
     * Starts building a sequence with a fixed number of items.
     * The current initial state is captured now.
     *
     * @param count number of items to create
     * @return the next stage of the fluent chain
     */
    public FromCount count(int count) {
        return new FromCount(count, this.initialState);
    }

    /**
     * Starts building a sequence with one item per element of {@code iterator}, using a size hint of 0.
     *
     * @param iterator the data source
     * @return the next stage of the fluent chain
     */
    public FromIterator fromIterator(Iterator<T> iterator) {
        return this.fromIterator(iterator, 0);
    }

    /**
     * Starts building a sequence with one item per element of {@code iterator}.
     * The current initial state and transition mapping are captured now.
     *
     * @param iterator the data source
     * @param sizeHint size passed to the Sequence constructor; the actual item count is the
     *                 number of elements the iterator yields
     * @return the next stage of the fluent chain
     */
    public FromIterator fromIterator(Iterator<T> iterator, int sizeHint) {
        return new FromIterator(iterator, sizeHint, this.initialState, this.transitionMapping);
    }

    /**
     * Connects every proxy vertex of {@code item} to its parent in {@code candidateVertices}.
     * <p>
     * The parent label is looked up in {@code transitionMapping} using the proxy's unscoped label.
     * If that label has no entry, the default parent label from {@link #getDefaultParentLabel} is used.
     *
     * @param candidateVertices the previous item, or the initial state for the first item (may be null)
     * @param item              the item whose proxies are being connected
     * @param transitionMapping explicit proxy-label to parent-label mapping
     * @throws SequenceConstructionException if no parent label can be determined for a proxy,
     *                                       or no vertex exists under the parent label
     * @throws IllegalArgumentException      if a proxy needs a parent but {@code candidateVertices}
     *                                       is {@code null}
     */
    private void connectTransitionVariables(VertexDictionary candidateVertices, SequenceItem item, Map<VertexLabel, VertexLabel> transitionMapping) throws SequenceConstructionException {
        boolean hasSequenceName = this.sequenceName != null;
        Collection<Vertex<?>> proxyVertices = item.getProxyVertices();
        for (Vertex<?> proxy : proxyVertices) {
            VertexLabel proxyLabel = SequenceBuilder.getUnscopedLabel(proxy.getLabel(), hasSequenceName);
            VertexLabel parentLabel = transitionMapping.getOrDefault(proxyLabel, this.getDefaultParentLabel(proxyLabel));
            if (parentLabel == null) {
                throw new SequenceConstructionException("Cannot find transition mapping for " + proxy.getLabel());
            }
            // Only the initial state can be null here; later calls pass the previous SequenceItem.
            if (candidateVertices == null) {
                throw new IllegalArgumentException("You must provide a base case for the Transition Vertices - use withInitialState()");
            }
            Vertex<?> parent = candidateVertices.get(parentLabel);
            if (parent == null) {
                throw new SequenceConstructionException("Cannot find VertexLabel " + parentLabel);
            }
            proxy.setParents(new Vertex<?>[]{parent});
        }
    }

    /**
     * Returns {@code proxyLabel} without its outer namespace if that namespace is the proxy marker;
     * otherwise {@code null}.
     */
    private VertexLabel getDefaultParentLabel(VertexLabel proxyLabel) {
        String outerNamespace = proxyLabel.getOuterNamespace().orElse(null);
        if (PROXY_LABEL_MARKER.equals(outerNamespace)) {
            return proxyLabel.withoutOuterNamespace();
        }
        return null;
    }

    /**
     * Strips the outer namespaces that item scoping adds to a label.
     * Two are always removed, plus one more when the sequence is named.
     *
     * @param proxyLabel      the scoped label
     * @param hasSequenceName whether the sequence was given a name
     * @return the label with the scoping namespaces removed
     */
    public static VertexLabel getUnscopedLabel(VertexLabel proxyLabel, boolean hasSequenceName) {
        if (hasSequenceName) {
            proxyLabel = proxyLabel.withoutOuterNamespace();
        }
        return proxyLabel.withoutOuterNamespace().withoutOuterNamespace();
    }

    /** Final stage of a count-based build: runs the factories once per item. */
    public class FromCountFactories
    implements SequenceFactory {
        private final Collection<Consumer<SequenceItem>> factories;
        private final ItemCount count;
        private final VertexDictionary initialState;
        private final Map<VertexLabel, VertexLabel> transitionMapping;

        private FromCountFactories(Collection<Consumer<SequenceItem>> factories, ItemCount count, VertexDictionary initialState, Map<VertexLabel, VertexLabel> transitionMapping) {
            this.factories = factories;
            this.count = count;
            this.initialState = initialState;
            this.transitionMapping = transitionMapping;
        }

        /**
         * Creates {@code count} items and runs every factory on each one.
         * Each item's proxies are connected to the previous item, or to the initial state for item 0.
         *
         * @return the built sequence
         * @throws SequenceConstructionException if a proxy cannot be connected
         */
        @Override
        public Sequence build() throws SequenceConstructionException {
            int uniqueSequenceIdentifier = this.factories.hashCode();
            int itemCount = this.count.getCount();
            Sequence sequence = new Sequence(itemCount, uniqueSequenceIdentifier, SequenceBuilder.this.sequenceName);
            VertexDictionary previousItem = this.initialState;
            for (int i = 0; i < itemCount; ++i) {
                SequenceItem item = new SequenceItem(i, uniqueSequenceIdentifier, SequenceBuilder.this.sequenceName);
                this.factories.forEach(factory -> factory.accept(item));
                SequenceBuilder.this.connectTransitionVariables(previousItem, item, this.transitionMapping);
                sequence.add(item);
                previousItem = item;
            }
            return sequence;
        }
    }

    /** Final stage of a data-driven build: runs the factories once per data element. */
    public class FromDataFactories
    implements SequenceFactory {
        private final Collection<BiConsumer<SequenceItem, T>> factories;
        private final SequenceData<T> data;
        private final int size;
        private final VertexDictionary initialState;
        private final Map<VertexLabel, VertexLabel> transitionMapping;

        private FromDataFactories(Collection<BiConsumer<SequenceItem, T>> factories, SequenceData<T> data, int size, VertexDictionary initialState, Map<VertexLabel, VertexLabel> transitionMapping) {
            this.factories = factories;
            this.data = data;
            this.size = size;
            this.initialState = initialState;
            this.transitionMapping = transitionMapping;
        }

        /**
         * Creates one item per data element and passes that element to every factory.
         * Each item's proxies are connected to the previous item, or to the initial state for item 0.
         *
         * @return the built sequence
         * @throws SequenceConstructionException if a proxy cannot be connected
         */
        @Override
        public Sequence build() throws SequenceConstructionException {
            int uniqueSequenceIdentifier = this.factories.hashCode();
            Sequence sequence = new Sequence(this.size, uniqueSequenceIdentifier, SequenceBuilder.this.sequenceName);
            Iterator<T> iter = this.data.getIterator();
            VertexDictionary previousVertices = this.initialState;
            for (int i = 0; iter.hasNext(); ++i) {
                SequenceItem item = new SequenceItem(i, uniqueSequenceIdentifier, SequenceBuilder.this.sequenceName);
                // Read exactly one element per item and share it with all factories.
                // Calling next() inside the lambda would give each factory a different element.
                T datum = iter.next();
                this.factories.forEach(factory -> factory.accept(item, datum));
                SequenceBuilder.this.connectTransitionVariables(previousVertices, item, this.transitionMapping);
                sequence.add(item);
                previousVertices = item;
            }
            return sequence;
        }
    }

    /** Intermediate stage holding the data iterator for a data-driven build. */
    public class FromIterator
    implements SequenceData<T> {
        private final Iterator<T> iterator;
        private final int size;
        private final VertexDictionary initialState;
        private final Map<VertexLabel, VertexLabel> transitionMapping;

        private FromIterator(Iterator<T> iterator, int size, VertexDictionary initialState, Map<VertexLabel, VertexLabel> transitionMapping) {
            this.iterator = iterator;
            this.size = size;
            this.initialState = initialState;
            this.transitionMapping = transitionMapping;
        }

        @Override
        public Iterator<T> getIterator() {
            return this.iterator;
        }

        /** Uses a single factory; see {@link #withFactories(Collection)}. */
        public FromDataFactories withFactory(BiConsumer<SequenceItem, T> factory) {
            return this.withFactories(Collections.singleton(factory));
        }

        /**
         * Sets the factories run for each data element.
         *
         * @param factories each is called with the new item and the current data element
         * @return the final, buildable stage
         */
        public FromDataFactories withFactories(Collection<BiConsumer<SequenceItem, T>> factories) {
            return new FromDataFactories(factories, this, this.size, this.initialState, this.transitionMapping);
        }
    }

    /** Intermediate stage holding the item count for a count-based build. */
    public class FromCount
    implements ItemCount {
        private final int count;
        private final VertexDictionary initialState;

        public FromCount(int count, VertexDictionary initialState) {
            this.count = count;
            this.initialState = initialState;
        }

        @Override
        public int getCount() {
            return this.count;
        }

        /** Uses a single factory; see {@link #withFactories(Collection)}. */
        public FromCountFactories withFactory(Consumer<SequenceItem> factory) {
            return this.withFactories(Collections.singleton(factory));
        }

        /**
         * Sets the factories run for each item.
         * The builder's current transition mapping is captured now.
         *
         * @param factories each is called with every newly created item
         * @return the final, buildable stage
         */
        public FromCountFactories withFactories(Collection<Consumer<SequenceItem>> factories) {
            return new FromCountFactories(factories, this, this.initialState, SequenceBuilder.this.transitionMapping);
        }
    }

    private interface SequenceFactory {
        Sequence build();
    }

    private interface SequenceData<T> {
        Iterator<T> getIterator();
    }

    private interface ItemCount {
        int getCount();
    }
}

