/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms.graphtraversal;

import io.improbable.keanu.algorithms.graphtraversal.BreadthFirstSearch;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

/**
 * Static helpers that decide whether a collection of vertices is differentiable with respect to
 * the unobserved probabilistic vertices upstream of it.
 *
 * <p>Non-instantiable utility class.
 */
public final class DifferentiableChecker {

    private DifferentiableChecker() {
        throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
    }

    /**
     * Checks whether {@code vertices} are differentiable with respect to their latent ancestors.
     *
     * <p>Returns {@code false} immediately if any probabilistic vertex in {@code vertices} is
     * neither a {@link DoubleVertex} nor observed. Otherwise, a breadth-first search is started
     * from the parents of {@code vertices}. Parents of probabilistic vertices are not expanded.
     * The search fails on the first vertex that is not differentiable and whose value is not
     * constant. A vertex counts as constant if it is a {@link ConstantVertex}, is observed, or
     * has no unobserved probabilistic ancestor.
     *
     * <p>NOTE(review): assumes {@code BreadthFirstSearch.bfsWithFailureCondition} returns
     * {@code true} when no visited vertex matches the failure predicate. Confirm.
     *
     * @param vertices the vertices to check
     * @return whether both conditions described above hold
     */
    public static boolean isDifferentiableWrtLatents(Collection<Vertex> vertices) {
        if (!allProbabilisticAreDoubleOrObserved(vertices)) {
            return false;
        }
        Set<Vertex> knownConstants = new HashSet<>();
        return upstreamIsDiffableOrConstant(collectParents(vertices), knownConstants);
    }

    /** Returns false as soon as a probabilistic vertex is found that is neither double nor observed. */
    private static boolean allProbabilisticAreDoubleOrObserved(Collection<Vertex> vertices) {
        for (Vertex candidate : vertices) {
            boolean doubleOrObserved = candidate instanceof DoubleVertex || candidate.isObserved();
            if (candidate.isProbabilistic() && !doubleOrObserved) {
                return false;
            }
        }
        return true;
    }

    /** Gathers the union of the direct parents of every vertex in {@code vertices}. */
    private static Set<Vertex> collectParents(Collection<Vertex> vertices) {
        Set<Vertex> parents = new HashSet<>();
        vertices.forEach(child -> parents.addAll(child.getParents()));
        return parents;
    }

    /**
     * Searches upward from {@code startVertices}, without expanding past probabilistic vertices.
     * The search fails on any vertex that is neither differentiable nor constant-valued.
     */
    private static boolean upstreamIsDiffableOrConstant(Collection<Vertex> startVertices, Set<Vertex> knownConstants) {
        return BreadthFirstSearch.bfsWithFailureCondition(
            startVertices,
            current -> isNeitherDiffableNorConstant(current, knownConstants),
            DifferentiableChecker::parentsUnlessProbabilistic,
            BreadthFirstSearch::doNothing
        );
    }

    /** Next-vertex function: probabilistic vertices end the walk; all others expose their parents. */
    private static Collection<Vertex> parentsUnlessProbabilistic(Vertex visiting) {
        if (visiting.isProbabilistic()) {
            return Collections.emptySet();
        }
        return visiting.getParents();
    }

    /** Differentiable vertices pass outright; only non-differentiable ones trigger the constant check. */
    private static boolean isNeitherDiffableNorConstant(Vertex vertex, Set<Vertex> knownConstants) {
        if (vertex.isDifferentiable()) {
            return false;
        }
        return !hasConstantValue(vertex, knownConstants);
    }

    /**
     * Decides whether {@code vertex} has a value that no unobserved probabilistic vertex can change.
     *
     * <p>A vertex that is already known to be constant returns {@code true} without any search.
     * Otherwise its ancestors are searched; expansion stops at vertices already known to be
     * constant, and the search fails on any unobserved probabilistic vertex. The visited
     * vertices are passed to {@code knownConstants::addAll}.
     *
     * <p>NOTE(review): adding every visited vertex to the cache is only sound if
     * {@code BreadthFirstSearch} calls that callback solely when the search did not fail. Confirm.
     */
    private static boolean hasConstantValue(Vertex vertex, Set<Vertex> knownConstants) {
        return isKnownConstant(vertex, knownConstants)
            || BreadthFirstSearch.bfsWithFailureCondition(
                Collections.singletonList(vertex),
                DifferentiableChecker::isUnobservedProbabilistic,
                current -> parentsUnlessKnownConstant(current, knownConstants),
                knownConstants::addAll
            );
    }

    /** Failure predicate for the constant-value search. */
    private static boolean isUnobservedProbabilistic(Vertex vertex) {
        if (!vertex.isProbabilistic()) {
            return false;
        }
        return !vertex.isObserved();
    }

    /** Next-vertex function: stops at vertices already known to be constant. */
    private static Collection<Vertex> parentsUnlessKnownConstant(Vertex visiting, Set<Vertex> knownConstants) {
        if (isKnownConstant(visiting, knownConstants)) {
            return Collections.emptySet();
        }
        return visiting.getParents();
    }

    /** Returns true if the vertex is a ConstantVertex, is observed, or is already cached as constant. */
    private static boolean isKnownConstant(Vertex vertex, Set<Vertex> knownConstants) {
        return vertex instanceof ConstantVertex
            || vertex.isObserved()
            || knownConstants.contains(vertex);
    }
}

