package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.VertexId;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:io/improbable/keanu/vertices/dbl/nonprobabilistic/diff/LogProbGradients.class */
/**
 * Accumulates partial derivatives of a log probability, keyed by the id of the
 * vertex each derivative is taken with respect to.
 *
 * <p>Contributions passed to the {@code add} overloads are summed per vertex id.
 * This class performs no synchronization of its own.
 */
public class LogProbGradients {
    /** Accumulated gradient for each vertex id seen so far. */
    private final Map<VertexId, DoubleTensor> partials = new HashMap<>();

    /**
     * Adds every partial held by another instance into this one.
     *
     * @param logProbGradients the gradients to accumulate; must not be {@code null}
     * @return this instance, to allow call chaining
     */
    public LogProbGradients add(LogProbGradients logProbGradients) {
        return add(logProbGradients.partials);
    }

    /**
     * Adds every entry of the given map into this instance, summing with any
     * partial already stored under the same vertex id.
     *
     * @param map vertex id to partial derivative; must not be {@code null}
     * @return this instance, to allow call chaining
     */
    public LogProbGradients add(Map<VertexId, DoubleTensor> map) {
        for (Map.Entry<VertexId, DoubleTensor> entry : map.entrySet()) {
            putPartial(entry.getKey(), entry.getValue());
        }
        return this;
    }

    /**
     * Adds every partial derivative exposed by {@code partialsOf.asMap()} into
     * this instance, unwrapping each {@link PartialDerivative} via {@code get()}.
     *
     * @param partialsOf the partial derivatives to accumulate; must not be {@code null}
     * @return this instance, to allow call chaining
     */
    public LogProbGradients add(PartialsOf partialsOf) {
        for (Map.Entry<VertexId, PartialDerivative> entry : partialsOf.asMap().entrySet()) {
            putPartial(entry.getKey(), entry.getValue().get());
        }
        return this;
    }

    /**
     * Accumulates one contribution under {@code vertexId}: stores a duplicate of
     * it when nothing is stored yet, otherwise adds it to the stored tensor.
     */
    private void putPartial(VertexId vertexId, DoubleTensor contribution) {
        DoubleTensor existing = this.partials.get(vertexId);
        if (existing == null) {
            // Store a duplicate rather than the argument itself; presumably this keeps later
            // in-place additions from modifying the caller's tensor -- confirm against DoubleTensor.
            this.partials.put(vertexId, contribution.duplicate());
        } else {
            // Re-store whatever plusInPlace returns instead of assuming it is `existing`.
            this.partials.put(vertexId, existing.plusInPlace(contribution));
        }
    }

    /**
     * Looks up the accumulated partial for a vertex.
     *
     * @param vertexId the vertex id to look up
     * @return the stored tensor, or {@code null} if nothing is stored for that id
     */
    public DoubleTensor getWithRespectTo(VertexId vertexId) {
        return this.partials.get(vertexId);
    }

    /**
     * Replaces whatever is stored for {@code vertexId} with the given tensor.
     *
     * <p>Unlike the {@code add} overloads, the tensor is stored as-is without being
     * duplicated. NOTE(review): a later {@code add} for the same id calls
     * {@code plusInPlace} on this stored tensor, which presumably modifies the
     * caller's instance -- confirm this aliasing is intended.
     *
     * @param vertexId the vertex id to store under
     * @param doubleTensor the tensor to store
     */
    public void putWithRespectTo(VertexId vertexId, DoubleTensor doubleTensor) {
        this.partials.put(vertexId, doubleTensor);
    }

    /**
     * Returns the backing map itself, not a copy; changes made through the
     * returned map are visible to this instance and vice versa.
     *
     * @return the live map of vertex id to accumulated partial
     */
    public Map<VertexId, DoubleTensor> getPartials() {
        return this.partials;
    }
}
