/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices;

import com.google.common.collect.ImmutableSet;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.graphtraversal.DiscoverGraph;
import io.improbable.keanu.algorithms.graphtraversal.VertexValuePropagation;
import io.improbable.keanu.network.NetworkLoader;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.vertices.Observable;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.VertexLabel;
import io.improbable.keanu.vertices.VertexState;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.generic.nonprobabilistic.PrintVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

/**
 * Abstract base class for every vertex in a graph.
 *
 * <p>A vertex carries a unique {@link VertexId}, an optional {@link VertexLabel}, its current
 * {@link VertexState} (value plus observed flag), a fallback shape used while it does not hold a
 * {@link Tensor} value, and links to its parent and child vertices.
 *
 * @param <T> the type of value held by this vertex
 */
public abstract class Vertex<T> implements Observable<T>, Variable<T, VertexState<T>> {

    /** Identity of this vertex; it alone drives equality, hashing and child ordering. */
    private final VertexId id = new VertexId();

    /** Shape reported by {@link #getShape()} while the current value is not a {@link Tensor}. */
    private final long[] initialShape;

    /** Child vertices, held in a TreeSet ordered by {@link #getId()}. */
    private final Set<Vertex> children = new TreeSet<Vertex>(Comparator.comparing(Vertex::getId));

    /** Parent vertices; always an unmodifiable set that is replaced wholesale on change. */
    private Set<Vertex> parents = Collections.emptySet();

    /** Current value together with its observed flag. */
    private VertexState<T> state;

    /** Optional label; {@code null} when the vertex is unlabelled. */
    private VertexLabel label;

    /** Creates a vertex whose fallback shape is {@link Tensor#SCALAR_SHAPE}. */
    public Vertex() {
        this(Tensor.SCALAR_SHAPE);
    }

    /**
     * Creates a vertex starting from {@link VertexState#nullState()}.
     *
     * @param initialShape shape reported until the vertex holds a tensor value; the array is
     *     stored by reference, not copied
     */
    public Vertex(long[] initialShape) {
        this.initialShape = initialShape;
        this.state = VertexState.nullState();
    }

    /**
     * Sets the label of this vertex.
     *
     * @param label the new label, or {@code null} to clear it
     * @param <V> the vertex type the caller wants back
     * @return this vertex, cast to {@code V}
     */
    @SuppressWarnings("unchecked") // fluent self-return: the caller picks V, so the cast is unchecked
    public <V extends Vertex<T>> V setLabel(VertexLabel label) {
        this.label = label;
        return (V) this;
    }

    /**
     * Sets the label of this vertex from a plain string.
     *
     * @param label text wrapped in a new {@link VertexLabel}
     * @param <V> the vertex type the caller wants back
     * @return this vertex, cast to {@code V}
     */
    public <V extends Vertex<T>> V setLabel(String label) {
        return this.setLabel(new VertexLabel(label));
    }

    /**
     * Returns the label of this vertex.
     *
     * @return the label, or {@code null} if none has been set
     */
    public VertexLabel getLabel() {
        return this.label;
    }

    /**
     * Clears the label of this vertex.
     *
     * @param <V> the vertex type the caller wants back
     * @return this vertex, cast to {@code V}
     */
    @SuppressWarnings("unchecked") // fluent self-return: the caller picks V, so the cast is unchecked
    public <V extends Vertex<T>> V removeLabel() {
        this.label = null;
        return (V) this;
    }

    /**
     * Runs {@link VertexValuePropagation#lazyEval} on this vertex and returns the resulting value.
     *
     * <p>NOTE(review): the value is read back through {@link #getValue()}, which calls this method
     * again while the value is still {@code null}. If lazy evaluation can ever leave the value
     * {@code null}, the two methods recurse until the stack overflows — confirm that
     * {@code VertexValuePropagation.lazyEval} always assigns a non-null value.
     *
     * @return the value after lazy evaluation
     */
    public final T lazyEval() {
        VertexValuePropagation.lazyEval(this);
        return this.getValue();
    }

    /**
     * Runs {@link VertexValuePropagation#eval} on this vertex and returns the resulting value.
     *
     * @return the value after evaluation
     */
    public final T eval() {
        VertexValuePropagation.eval(this);
        return this.getValue();
    }

    /** @return {@code true} if this vertex implements {@link Probabilistic} */
    public final boolean isProbabilistic() {
        return this instanceof Probabilistic;
    }

    /** @return {@code true} if this vertex implements {@link Differentiable} */
    public final boolean isDifferentiable() {
        return this instanceof Differentiable;
    }

    /**
     * Sets the value of this vertex as an unobserved value.
     *
     * <p>The call is silently ignored when the vertex is currently observed, so observations are
     * never overwritten by this method.
     *
     * @param value the new value
     */
    public void setValue(T value) {
        if (!this.state.isObserved()) {
            this.state = new VertexState<T>(value, false);
        }
    }

    /**
     * Returns the current value, triggering {@link #lazyEval()} when no value is held yet.
     *
     * @return the current (possibly freshly evaluated) value
     */
    @Override
    public T getValue() {
        return this.hasValue() ? this.state.getValue() : this.lazyEval();
    }

    /** @return the current state (value plus observed flag) of this vertex */
    @Override
    public VertexState<T> getState() {
        return this.state;
    }

    /**
     * Replaces the state of this vertex unconditionally.
     *
     * <p>Unlike {@link #setValue}, this also replaces the state of an observed vertex.
     *
     * @param newState the state to install
     */
    public void setState(VertexState<T> newState) {
        this.state = newState;
    }

    /**
     * Reports whether the current state holds a non-null value.
     *
     * @return {@code true} when the current state's value is non-null
     */
    public boolean hasValue() {
        return this.state.getValue() != null;
    }

    /**
     * Returns the shape of the current value if it is a {@link Tensor}, otherwise the initial
     * shape given at construction (this includes the case where there is no value).
     *
     * @return the shape of this vertex
     */
    @Override
    public long[] getShape() {
        T value = this.state.getValue();
        if (value instanceof Tensor) {
            return ((Tensor) value).getShape();
        }
        return this.initialShape;
    }

    /** @return the number of dimensions reported by {@link #getShape()} */
    public int getRank() {
        return this.getShape().length;
    }

    /**
     * Constructs a {@link PrintVertex} wrapping this vertex and returns this vertex.
     *
     * <p>The PrintVertex is not kept here; presumably its constructor wires it into the graph
     * (e.g. as a child of this vertex) — confirm in PrintVertex.
     *
     * @param <V> the vertex type the caller wants back
     * @return this vertex, cast to {@code V}
     */
    @SuppressWarnings("unchecked") // fluent self-return: the caller picks V, so the cast is unchecked
    public <V extends Vertex<T>> V print() {
        new PrintVertex(this);
        return (V) this;
    }

    /**
     * Constructs a {@link PrintVertex} wrapping this vertex with a message and returns this vertex.
     *
     * @param message message handed to the PrintVertex
     * @param printData flag handed to the PrintVertex
     * @param <V> the vertex type the caller wants back
     * @return this vertex, cast to {@code V}
     */
    @SuppressWarnings("unchecked") // fluent self-return: the caller picks V, so the cast is unchecked
    public <V extends Vertex<T>> V print(String message, boolean printData) {
        new PrintVertex(this, message, printData);
        return (V) this;
    }

    /**
     * Sets the value via {@link #setValue} and then runs
     * {@link VertexValuePropagation#cascadeUpdate} from this vertex.
     *
     * <p>The cascade runs even if this vertex is observed and the set was therefore ignored.
     *
     * @param value the new value
     */
    public void setAndCascade(T value) {
        this.setValue(value);
        VertexValuePropagation.cascadeUpdate(this);
    }

    /**
     * Marks this vertex as observed with the given value.
     *
     * @param value the observed value
     * @throws UnsupportedOperationException if this vertex type cannot be observed: a
     *     non-probabilistic {@link DoubleVertex} or {@link IntegerVertex}
     */
    @Override
    public void observe(T value) {
        if (!Vertex.isObservable(this.getClass())) {
            throw new UnsupportedOperationException("This type of vertex does not support being observed");
        }
        this.state = new VertexState<T>(value, true);
    }

    /**
     * Decides whether vertices of the given class may be observed.
     *
     * @param vertexClass runtime class of the vertex
     * @return {@code true} if the class is {@link Probabilistic}, or is neither an
     *     {@link IntegerVertex} nor a {@link DoubleVertex}
     */
    private static boolean isObservable(Class<? extends Vertex> vertexClass) {
        boolean isProbabilistic = Probabilistic.class.isAssignableFrom(vertexClass);
        boolean isNotDoubleOrIntegerVertex = !IntegerVertex.class.isAssignableFrom(vertexClass)
                && !DoubleVertex.class.isAssignableFrom(vertexClass);
        return isProbabilistic || isNotDoubleOrIntegerVertex;
    }

    /** Observes the current value; this may trigger lazy evaluation through {@link #getValue()}. */
    public void observeOwnValue() {
        this.observe(this.getValue());
    }

    /** Clears the observed flag while keeping the current value. */
    @Override
    public void unobserve() {
        this.state = new VertexState<T>(this.state.getValue(), false);
    }

    /** @return {@code true} if the current state is flagged as observed */
    @Override
    public boolean isObserved() {
        return this.state.isObserved();
    }

    /** @return the observed value as reported by the current state */
    @Override
    public Optional<T> getObservedValue() {
        return this.state.getObservedValue();
    }

    /** @return the id of this vertex, which doubles as its variable reference */
    @Override
    public VariableReference getReference() {
        return this.getId();
    }

    /** @return the unique id of this vertex */
    public VertexId getId() {
        return this.id;
    }

    /** @return the value of {@link VertexId#getIndentation()} for this vertex's id */
    public int getIndentation() {
        return this.id.getIndentation();
    }

    /** @return an unmodifiable, id-ordered view of this vertex's children */
    public Set<Vertex> getChildren() {
        return Collections.unmodifiableSet(this.children);
    }

    /**
     * Registers a child of this vertex. This does not update the child's parent set.
     *
     * @param v the child vertex
     */
    public void addChild(Vertex<?> v) {
        this.children.add(v);
    }

    /**
     * Replaces the parent set of this vertex and registers this vertex as a child of each parent.
     *
     * <p>NOTE(review): this vertex is not removed from the children of its previous parents, so
     * stale child links can remain — confirm that callers only set parents once.
     *
     * @param parents the new parents
     */
    public void setParents(Collection<? extends Vertex> parents) {
        this.parents = Collections.emptySet();
        this.addParents(parents);
    }

    /**
     * Varargs form of {@link #setParents(Collection)}.
     *
     * @param parents the new parents
     */
    public void setParents(Vertex<?>... parents) {
        this.setParents(Arrays.asList(parents));
    }

    /**
     * Adds parents to this vertex and registers this vertex as a child of each of them.
     *
     * @param parents the parents to add
     */
    public void addParents(Collection<? extends Vertex> parents) {
        this.parents = ImmutableSet.<Vertex>builder()
                .addAll(this.getParents())
                .addAll(parents)
                .build();
        parents.forEach(p -> p.addChild(this));
    }

    /**
     * Adds a single parent; see {@link #addParents(Collection)}.
     *
     * @param parent the parent to add
     */
    public void addParent(Vertex<?> parent) {
        this.addParents(ImmutableSet.of(parent));
    }

    /** @return the parents of this vertex as an unmodifiable set */
    public Set<Vertex> getParents() {
        return this.parents;
    }

    /** @return the number of children plus the number of parents */
    public int getDegree() {
        return this.children.size() + this.parents.size();
    }

    /**
     * Two vertices are equal when they have the same runtime class and the same id.
     *
     * @param o the object to compare with
     * @return {@code true} if {@code o} is an equal vertex
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        Vertex<?> vertex = (Vertex<?>) o;
        return this.id.equals(vertex.id);
    }

    /** @return the hash of this vertex's id, consistent with {@link #equals(Object)} */
    @Override
    public int hashCode() {
        return this.id.hashCode();
    }

    /** @return the result of {@link DiscoverGraph#getEntireGraph} started from this vertex */
    public Set<Vertex> getConnectedGraph() {
        return DiscoverGraph.getEntireGraph(this);
    }

    /**
     * Renders the vertex as {@code "<id> (<label>): <SimpleClassName>(<value>)"}.
     *
     * <p>The label part is omitted when there is no label, and the value part is omitted when
     * {@link #hasValue()} is false.
     *
     * @return a human-readable description of this vertex
     */
    @Override
    public String toString() {
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append(this.getId());
        if (this.getLabel() != null) {
            stringBuilder.append(" (").append(this.getLabel()).append(")");
        }
        stringBuilder.append(": ");
        stringBuilder.append(this.getClass().getSimpleName());
        if (this.hasValue()) {
            stringBuilder.append('(').append(this.getValue()).append(')');
        }
        return stringBuilder.toString();
    }

    /**
     * Saves this vertex using the given saver.
     *
     * @param netSaver destination saver
     */
    public void save(NetworkSaver netSaver) {
        netSaver.save(this);
    }

    /**
     * Saves the value of this vertex using the given saver.
     *
     * @param netSaver destination saver
     */
    public void saveValue(NetworkSaver netSaver) {
        netSaver.saveValue(this);
    }

    /**
     * Loads the value of this vertex using the given loader.
     *
     * @param loader source loader
     */
    public void loadValue(NetworkLoader loader) {
        loader.loadValue(this);
    }
}

