/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.dbl.nonprobabilistic.diff;

import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.vertices.VertexId;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialDerivative;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.diff.PartialsOf;
import java.util.HashMap;
import java.util.Map;

public class LogProbGradients {
    private final Map<VertexId, DoubleTensor> partials = new HashMap<VertexId, DoubleTensor>();

    public LogProbGradients add(LogProbGradients addition) {
        return this.add(addition.partials);
    }

    public LogProbGradients add(Map<VertexId, DoubleTensor> addition) {
        for (Map.Entry<VertexId, DoubleTensor> entry : addition.entrySet()) {
            this.putPartial(entry.getKey(), entry.getValue());
        }
        return this;
    }

    public LogProbGradients add(PartialsOf addition) {
        for (Map.Entry<VertexId, PartialDerivative> entry : addition.asMap().entrySet()) {
            this.putPartial(entry.getKey(), entry.getValue().get());
        }
        return this;
    }

    private void putPartial(VertexId id, DoubleTensor value) {
        DoubleTensor existingPartialDerivative = this.partials.get(id);
        if (existingPartialDerivative == null) {
            this.partials.put(id, (DoubleTensor)value.duplicate());
        } else {
            this.partials.put(id, existingPartialDerivative.plusInPlace(value));
        }
    }

    public DoubleTensor getWithRespectTo(VertexId id) {
        return this.partials.get(id);
    }

    public void putWithRespectTo(VertexId id, DoubleTensor partial) {
        this.partials.put(id, partial);
    }

    public Map<VertexId, DoubleTensor> getPartials() {
        return this.partials;
    }
}

