package io.improbable.keanu.vertices;

import com.google.common.collect.ImmutableSet;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.algorithms.graphtraversal.DiscoverGraph;
import io.improbable.keanu.algorithms.graphtraversal.VertexValuePropagation;
import io.improbable.keanu.network.NetworkLoader;
import io.improbable.keanu.network.NetworkSaver;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.vertices.dbl.Differentiable;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.generic.nonprobabilistic.PrintVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;

/* loaded from: input_file:io/improbable/keanu/vertices/Vertex.class */
/**
 * Base class for a node in a graph of vertices.
 *
 * <p>Each vertex owns a {@link VertexId}, an optional {@link VertexLabel}, a {@link VertexState}
 * (its value together with an "observed" flag) and the sets of vertices it is connected to as
 * parents and children.
 *
 * <p>Equality and hashing are based on the concrete class and the id only.
 *
 * @param <T> the type of value held by this vertex
 */
public abstract class Vertex<T> implements Observable<T>, Variable<T, VertexState<T>> {
    private final VertexId id;
    // Shape reported by getShape() while the current value is not a Tensor.
    private final long[] initialShape;
    // Mutable, ordered by vertex id; exposed to callers only through an unmodifiable view.
    private Set<Vertex> children;
    // Only ever assigned Collections.emptySet() or an ImmutableSet, so it is safe to hand out.
    private Set<Vertex> parents;
    private VertexState<T> state;
    private VertexLabel label;

    /** Creates a vertex whose initial shape is {@code Tensor.SCALAR_SHAPE}. */
    public Vertex() {
        this(Tensor.SCALAR_SHAPE);
    }

    /**
     * Creates a vertex with a fresh id, no label, no parents, no children and the state returned
     * by {@code VertexState.nullState()}.
     *
     * @param initialShape shape reported by {@link #getShape()} while the value is not a
     *     {@code Tensor}; the array is stored as-is (not copied)
     */
    public Vertex(long[] initialShape) {
        this.id = new VertexId();
        // Explicit type witness: without it the key extractor is inferred against Object and
        // getId() cannot be resolved.
        this.children = new TreeSet<>(Comparator.<Vertex, VertexId>comparing(Vertex::getId));
        this.parents = Collections.emptySet();
        this.label = null;
        this.initialShape = initialShape;
        this.state = VertexState.nullState();
    }

    /**
     * Sets the label of this vertex.
     *
     * @param label the new label (may be {@code null}, which is equivalent to removing it)
     * @param <V> the vertex type the caller expects; the cast is unchecked, so a wrong choice
     *     surfaces as a {@link ClassCastException} at the call site
     * @return this vertex
     */
    @SuppressWarnings("unchecked") // 'this' is returned as the caller-chosen subtype
    public <V extends Vertex<T>> V setLabel(VertexLabel label) {
        this.label = label;
        return (V) this;
    }

    /**
     * Sets the label of this vertex to a new {@link VertexLabel} built from the given string.
     *
     * @param label text passed to the {@code VertexLabel} constructor
     * @param <V> the vertex type the caller expects (unchecked)
     * @return this vertex
     */
    public <V extends Vertex<T>> V setLabel(String label) {
        return this.<V>setLabel(new VertexLabel(label));
    }

    /**
     * @return the current label, or {@code null} if none is set
     */
    public VertexLabel getLabel() {
        return this.label;
    }

    /**
     * Clears the label of this vertex.
     *
     * @param <V> the vertex type the caller expects (unchecked)
     * @return this vertex
     */
    @SuppressWarnings("unchecked") // 'this' is returned as the caller-chosen subtype
    public <V extends Vertex<T>> V removeLabel() {
        this.label = null;
        return (V) this;
    }

    /**
     * Runs {@code VertexValuePropagation.lazyEval} on this vertex and then returns
     * {@link #getValue()}.
     *
     * @return the value of this vertex after lazy evaluation
     */
    public final T lazyEval() {
        VertexValuePropagation.lazyEval(this);
        return getValue();
    }

    /**
     * Runs {@code VertexValuePropagation.eval} on this vertex and then returns
     * {@link #getValue()}.
     *
     * @return the value of this vertex after evaluation
     */
    public final T eval() {
        VertexValuePropagation.eval(this);
        return getValue();
    }

    /**
     * @return {@code true} if this vertex implements {@code Probabilistic}
     */
    public final boolean isProbabilistic() {
        return this instanceof Probabilistic;
    }

    /**
     * @return {@code true} if this vertex implements {@code Differentiable}
     */
    public final boolean isDifferentiable() {
        return this instanceof Differentiable;
    }

    /**
     * Sets the value of this vertex unless it is currently observed, in which case the call is
     * silently ignored.
     *
     * @param value the new value
     */
    public void setValue(T value) {
        if (this.state.isObserved()) {
            return;
        }
        this.state = new VertexState<>(value, false);
    }

    /**
     * Returns the stored value, or the result of {@link #lazyEval()} when no value is stored.
     *
     * @return the value of this vertex
     */
    @Override
    public T getValue() {
        return hasValue() ? this.state.getValue() : lazyEval();
    }

    /**
     * @return the current state object (value plus observed flag)
     */
    @Override
    public VertexState<T> getState() {
        return this.state;
    }

    /**
     * Replaces the state of this vertex without any observed-ness check.
     *
     * @param state the new state
     */
    public void setState(VertexState<T> state) {
        this.state = state;
    }

    /**
     * @return {@code true} if the stored value is not {@code null}
     */
    public boolean hasValue() {
        return this.state.getValue() != null;
    }

    /**
     * Returns the shape of the current value when it is a {@code Tensor}, otherwise the initial
     * shape supplied at construction.
     *
     * @return the shape of this vertex
     */
    @Override
    public long[] getShape() {
        return this.state.getValue() instanceof Tensor
            ? ((Tensor) this.state.getValue()).getShape()
            : this.initialShape;
    }

    /**
     * @return the length of {@link #getShape()}
     */
    public int getRank() {
        return getShape().length;
    }

    /**
     * Constructs a {@code PrintVertex} for this vertex. The new vertex is not returned.
     *
     * @param <V> the vertex type the caller expects (unchecked)
     * @return this vertex
     */
    @SuppressWarnings("unchecked") // 'this' is returned as the caller-chosen subtype
    public <V extends Vertex<T>> V print() {
        new PrintVertex(this);
        return (V) this;
    }

    /**
     * Constructs a {@code PrintVertex} for this vertex with the given arguments. The new vertex
     * is not returned.
     *
     * @param message forwarded as the second {@code PrintVertex} constructor argument
     * @param printData forwarded as the third {@code PrintVertex} constructor argument
     * @param <V> the vertex type the caller expects (unchecked)
     * @return this vertex
     */
    @SuppressWarnings("unchecked") // 'this' is returned as the caller-chosen subtype
    public <V extends Vertex<T>> V print(String message, boolean printData) {
        new PrintVertex(this, message, printData);
        return (V) this;
    }

    /**
     * Calls {@link #setValue(Object)} and then {@code VertexValuePropagation.cascadeUpdate} on
     * this vertex.
     *
     * @param value the new value
     */
    public void setAndCascade(T value) {
        setValue(value);
        VertexValuePropagation.cascadeUpdate(this);
    }

    /**
     * Marks this vertex as observed with the given value.
     *
     * @param value the observed value
     * @throws UnsupportedOperationException if this vertex's class is not observable (see
     *     {@link #isObservable(Class)})
     */
    @Override
    public void observe(T value) {
        if (!isObservable(getClass())) {
            throw new UnsupportedOperationException("This type of vertex does not support being observed");
        }
        this.state = new VertexState<>(value, true);
    }

    /**
     * A vertex class is observable when it is {@code Probabilistic}, or when it is neither an
     * {@code IntegerVertex} nor a {@code DoubleVertex}.
     */
    private static boolean isObservable(Class<? extends Vertex> vertexClass) {
        if (Probabilistic.class.isAssignableFrom(vertexClass)) {
            return true;
        }
        return !IntegerVertex.class.isAssignableFrom(vertexClass)
            && !DoubleVertex.class.isAssignableFrom(vertexClass);
    }

    /** Observes this vertex with the value returned by {@link #getValue()}. */
    public void observeOwnValue() {
        observe(getValue());
    }

    /** Clears the observed flag while keeping the currently stored value. */
    @Override
    public void unobserve() {
        this.state = new VertexState<>(this.state.getValue(), false);
    }

    /**
     * @return {@code true} if the current state is flagged as observed
     */
    @Override
    public boolean isObserved() {
        return this.state.isObserved();
    }

    /**
     * @return the observed value as reported by the current state
     */
    @Override
    public Optional<T> getObservedValue() {
        return this.state.getObservedValue();
    }

    /**
     * @return the id of this vertex, used as its variable reference
     */
    @Override
    public VariableReference getReference() {
        return getId();
    }

    /**
     * @return the id assigned to this vertex at construction
     */
    public VertexId getId() {
        return this.id;
    }

    /**
     * @return the indentation reported by this vertex's id
     */
    public int getIndentation() {
        return this.id.getIndentation();
    }

    /**
     * @return an unmodifiable view of the children of this vertex
     */
    public Set<Vertex> getChildren() {
        return Collections.unmodifiableSet(this.children);
    }

    /**
     * Registers the given vertex as a child of this vertex. The child's parent set is not
     * touched by this method.
     *
     * @param child the vertex to add
     */
    public void addChild(Vertex<?> child) {
        this.children.add(child);
    }

    /**
     * Replaces the parents of this vertex with the given collection and registers this vertex as
     * a child of each new parent.
     *
     * <p>NOTE: this vertex is not removed from the child sets of its previous parents.
     *
     * @param parents the new parents
     */
    public void setParents(Collection<? extends Vertex> parents) {
        this.parents = Collections.emptySet();
        addParents(parents);
    }

    /**
     * Varargs form of {@link #setParents(Collection)}.
     *
     * @param parents the new parents
     */
    public void setParents(Vertex<?>... parents) {
        setParents(Arrays.asList(parents));
    }

    /**
     * Adds the given vertices to the parents of this vertex and registers this vertex as a child
     * of each of them.
     *
     * @param parents the parents to add
     */
    public void addParents(Collection<? extends Vertex> parents) {
        // The element type must be stated explicitly: as the receiver of a call chain,
        // builder() would otherwise be inferred as Builder<Object>.
        this.parents = ImmutableSet.<Vertex>builder()
            .addAll(getParents())
            .addAll(parents)
            .build();
        parents.forEach(parent -> parent.addChild(this));
    }

    /**
     * Adds a single parent; see {@link #addParents(Collection)}.
     *
     * @param parent the parent to add
     */
    public void addParent(Vertex<?> parent) {
        addParents(ImmutableSet.of(parent));
    }

    /**
     * @return the parents of this vertex (an immutable set)
     */
    public Set<Vertex> getParents() {
        return this.parents;
    }

    /**
     * @return the number of children plus the number of parents
     */
    public int getDegree() {
        return this.children.size() + this.parents.size();
    }

    /**
     * Two vertices are equal when they have exactly the same class and equal ids.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return this.id.equals(((Vertex<?>) obj).id);
    }

    /** Hash code of the id, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        return this.id.hashCode();
    }

    /**
     * @return the result of {@code DiscoverGraph.getEntireGraph} started from this vertex
     */
    public Set<Vertex> getConnectedGraph() {
        return DiscoverGraph.getEntireGraph(this);
    }

    /**
     * Formats this vertex as {@code <id>[ (<label>)]: <SimpleClassName>[(<value>)]}; the label
     * and value parts appear only when present.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(getId());
        if (getLabel() != null) {
            sb.append(" (").append(getLabel()).append(")");
        }
        sb.append(": ");
        sb.append(getClass().getSimpleName());
        if (hasValue()) {
            sb.append("(").append(getValue()).append(")");
        }
        return sb.toString();
    }

    /**
     * Passes this vertex to {@code NetworkSaver.save}.
     *
     * @param networkSaver the saver to delegate to
     */
    public void save(NetworkSaver networkSaver) {
        networkSaver.save(this);
    }

    /**
     * Passes this vertex to {@code NetworkSaver.saveValue}.
     *
     * @param networkSaver the saver to delegate to
     */
    public void saveValue(NetworkSaver networkSaver) {
        networkSaver.saveValue(this);
    }

    /**
     * Passes this vertex to {@code NetworkLoader.loadValue}.
     *
     * @param networkLoader the loader to delegate to
     */
    public void loadValue(NetworkLoader networkLoader) {
        networkLoader.loadValue(this);
    }
}
