/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import com.google.common.base.Preconditions;
import io.improbable.keanu.network.LambdaSection;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.dbl.Differentiator;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.AutoDiffBroadcast;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.LogProbGradients;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialsOf;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

public class LogProbGradientCalculator {
    private final Set<? extends Vertex<?>> logProbOfVertices;
    private final Set<? extends Vertex<?>> wrtVertices;
    private final Map<Vertex, Set<DoubleVertex>> parentToLatentLookup;
    private final Map<Vertex, Set<DoubleVertex>> verticesWithNonzeroDiffWrtLatent;

    public LogProbGradientCalculator(List<? extends Vertex> logProbOfVerticesList, List<? extends Vertex<?>> wrtVerticesList) {
        this.logProbOfVertices = new HashSet<Vertex>(logProbOfVerticesList);
        this.wrtVertices = new HashSet(wrtVerticesList);
        this.parentToLatentLookup = this.getParentsThatAreConnectedToWrtVertices(this.logProbOfVertices);
        this.verticesWithNonzeroDiffWrtLatent = this.getVerticesWithNonzeroDiffWrt(this.logProbOfVertices, this.parentToLatentLookup);
    }

    public Map<VertexId, DoubleTensor> getJointLogProbGradientWrtLatents() {
        LogProbGradients totalLogProbGradients = new LogProbGradients();
        for (Vertex<?> ofVertex : this.logProbOfVertices) {
            LogProbGradients logProbGradientOfVertex = this.reverseModeLogProbGradientWrtLatents(ofVertex);
            totalLogProbGradients.add(logProbGradientOfVertex);
        }
        return totalLogProbGradients.getPartials();
    }

    private Map<Vertex, Set<DoubleVertex>> getVerticesWithNonzeroDiffWrt(Set<? extends Vertex<?>> ofVertices, Map<Vertex, Set<DoubleVertex>> parentToWrtVertices) {
        return ofVertices.stream().collect(Collectors.toMap(v -> v, v -> {
            Set parents = v.getParents().stream().filter(parent -> parent instanceof DoubleVertex).map(parent -> (DoubleVertex)parent).filter(parentToWrtVertices::containsKey).collect(Collectors.toSet());
            if (!v.isObserved()) {
                parents.add((DoubleVertex)v);
            }
            return parents;
        }));
    }

    private Map<Vertex, Set<DoubleVertex>> getParentsThatAreConnectedToWrtVertices(Set<? extends Vertex> ofVertices) {
        HashMap<Vertex, Set<DoubleVertex>> probabilisticParentLookup = new HashMap<Vertex, Set<DoubleVertex>>();
        for (Vertex vertex : ofVertices) {
            Set<Vertex> parents = vertex.getParents();
            for (Vertex parent : parents) {
                LambdaSection upstreamLambdaSection = LambdaSection.getUpstreamLambdaSection(parent, false);
                Set<Vertex> latentAndObservedVertices = upstreamLambdaSection.getLatentAndObservedVertices();
                Set latentVertices = latentAndObservedVertices.stream().filter(this::isLatentDoubleVertexAndInWrtTo).map(v -> (DoubleVertex)v).collect(Collectors.toSet());
                if (latentVertices.isEmpty()) continue;
                probabilisticParentLookup.put(parent, latentVertices);
            }
        }
        return probabilisticParentLookup;
    }

    private boolean isLatentDoubleVertexAndInWrtTo(Vertex v) {
        return !v.isObserved() && this.wrtVertices.contains(v) && v instanceof DoubleVertex;
    }

    private LogProbGradients reverseModeLogProbGradientWrtLatents(Vertex ofVertex) {
        Preconditions.checkArgument((boolean)(ofVertex instanceof Probabilistic), (String)"Cannot get logProb gradient on non-probabilistic vertex %s", (Object)ofVertex);
        Set<DoubleVertex> verticesWithNonzeroDiff = this.verticesWithNonzeroDiffWrtLatent.get(ofVertex);
        Map<Vertex, DoubleTensor> dlogProbOfVertexWrtVertices = ((Probabilistic)((Object)ofVertex)).dLogProbAtValue(verticesWithNonzeroDiff);
        LogProbGradients dOfWrtLatentsAccumulated = new LogProbGradients();
        for (Map.Entry<Vertex, DoubleTensor> dlogProbWrtVertex : dlogProbOfVertexWrtVertices.entrySet()) {
            DoubleVertex vertexWithDiff = (DoubleVertex)dlogProbWrtVertex.getKey();
            DoubleTensor dLogProbOfWrtVertexWithDiff = dlogProbWrtVertex.getValue();
            if (vertexWithDiff.equals(ofVertex)) {
                dOfWrtLatentsAccumulated.putWithRespectTo(vertexWithDiff.getId(), dLogProbOfWrtVertexWithDiff);
                continue;
            }
            PartialDerivative partialWrtVertexWithDiff = new PartialDerivative(dLogProbOfWrtVertexWithDiff);
            PartialDerivative correctForScalarReverse = AutoDiffBroadcast.correctForBroadcastPartialReverse(partialWrtVertexWithDiff, ofVertex.getShape(), vertexWithDiff.getShape());
            PartialsOf dOfWrtLatentsContributionFromParent = Differentiator.reverseModeAutoDiff(vertexWithDiff, correctForScalarReverse, this.parentToLatentLookup.get(vertexWithDiff));
            dOfWrtLatentsAccumulated = dOfWrtLatentsAccumulated.add(dOfWrtLatentsContributionFromParent);
        }
        return dOfWrtLatentsAccumulated;
    }
}

