/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.network;

import io.improbable.keanu.network.Propagation;
import io.improbable.keanu.vertices.Vertex;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Collectors;

/**
 * Vertices collected by walking from one or more starting vertices along parent links
 * (upstream) or child links (downstream) via {@link Propagation#getVertices}, together with the
 * subset of those vertices that are observed or probabilistic.
 *
 * <p>Instances are created only through the static factory methods. The sets returned by the
 * getters are the instance's own sets, not copies.
 */
public final class TransitiveClosure {
    /** Accepts every vertex. */
    private static final Predicate<Vertex> ADD_ALL = vertex -> true;
    /** Accepts only vertices that are observed or probabilistic. */
    private static final Predicate<Vertex> PROBABILISTIC_OR_OBSERVED_ONLY = vertex -> vertex.isObserved() || vertex.isProbabilistic();

    /** Every vertex collected by the walk. Never null. */
    private final Set<Vertex> allVertices;
    /** The observed-or-probabilistic subset of {@link #allVertices}, computed once at construction. */
    private final Set<Vertex> latentAndObservedVertices;

    /**
     * @param allVertices the collected vertices; must not be null
     * @throws NullPointerException if {@code allVertices} is null
     */
    private TransitiveClosure(Set<Vertex> allVertices) {
        this.allVertices = Objects.requireNonNull(allVertices, "allVertices");
        this.latentAndObservedVertices = allVertices.stream()
            .filter(PROBABILISTIC_OR_OBSERVED_ONLY)
            .collect(Collectors.toSet());
    }

    /**
     * Collects the vertices found by following parent links from a single vertex.
     *
     * @param aVertex                 the starting vertex
     * @param includeNonProbabilistic if {@code true}, every visited vertex is collected;
     *                                otherwise only observed or probabilistic ones are
     * @return the resulting closure
     */
    public static TransitiveClosure getUpstreamVertices(Vertex<?> aVertex, boolean includeNonProbabilistic) {
        return getUpstreamVerticesForCollection(Collections.singletonList(aVertex), includeNonProbabilistic);
    }

    /**
     * Collects the vertices found by following child links from a single vertex.
     *
     * @param aVertex                 the starting vertex
     * @param includeNonProbabilistic if {@code true}, every visited vertex is collected;
     *                                otherwise only observed or probabilistic ones are
     * @return the resulting closure
     */
    public static TransitiveClosure getDownstreamVertices(Vertex<?> aVertex, boolean includeNonProbabilistic) {
        return getDownstreamVerticesForCollection(Collections.singletonList(aVertex), includeNonProbabilistic);
    }

    /**
     * Collects the vertices found by following parent links from each of the given vertices.
     *
     * @param vertices                the starting vertices
     * @param includeNonProbabilistic if {@code true}, every visited vertex is collected;
     *                                otherwise only observed or probabilistic ones are
     * @return the resulting closure
     */
    public static TransitiveClosure getUpstreamVerticesForCollection(List<Vertex> vertices, boolean includeNonProbabilistic) {
        // NOTE(review): the third argument appears to be a "stop here" test (never stops) and the
        // fourth the "collect this vertex" test -- confirm against Propagation.getVertices.
        Set<Vertex> upstreamVertices = Propagation.getVertices(
            vertices, Vertex::getParents, v -> false, inclusionFilter(includeNonProbabilistic));
        return new TransitiveClosure(upstreamVertices);
    }

    /**
     * Collects the vertices found by following child links from each of the given vertices.
     *
     * @param vertices                the starting vertices
     * @param includeNonProbabilistic if {@code true}, every visited vertex is collected;
     *                                otherwise only observed or probabilistic ones are
     * @return the resulting closure
     */
    public static TransitiveClosure getDownstreamVerticesForCollection(List<Vertex> vertices, boolean includeNonProbabilistic) {
        Set<Vertex> downstreamVertices = Propagation.getVertices(
            vertices, Vertex::getChildren, v -> false, inclusionFilter(includeNonProbabilistic));
        return new TransitiveClosure(downstreamVertices);
    }

    /** Picks the predicate that decides which visited vertices are collected. */
    private static Predicate<Vertex> inclusionFilter(boolean includeNonProbabilistic) {
        return includeNonProbabilistic ? ADD_ALL : PROBABILISTIC_OR_OBSERVED_ONLY;
    }

    /** @return every collected vertex (the internal set, not a copy) */
    public Set<Vertex> getAllVertices() {
        return this.allVertices;
    }

    /** @return the collected vertices that are observed or probabilistic (the internal set, not a copy) */
    public Set<Vertex> getLatentAndObservedVertices() {
        return this.latentAndObservedVertices;
    }

    /** Two closures are equal when both their vertex sets are equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof TransitiveClosure)) {
            return false;
        }
        TransitiveClosure other = (TransitiveClosure) o;
        return Objects.equals(this.allVertices, other.allVertices)
            && Objects.equals(this.latentAndObservedVertices, other.latentAndObservedVertices);
    }

    /** Hash consistent with {@link #equals}; keeps the original 59-based formula. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        // Both fields are non-null: allVertices is checked in the constructor and
        // latentAndObservedVertices comes from collect(), which never yields null.
        result = result * prime + this.allVertices.hashCode();
        result = result * prime + this.latentAndObservedVertices.hashCode();
        return result;
    }

    @Override
    public String toString() {
        return "TransitiveClosure(allVertices=" + this.getAllVertices() + ", latentAndObservedVertices=" + this.getLatentAndObservedVertices() + ")";
    }
}

