package io.improbable.keanu.algorithms.mcmc.nuts;

import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import java.util.Map;
import java.util.Objects;

/* loaded from: input_file:io/improbable/keanu/algorithms/mcmc/nuts/LeapfrogState.class */
/**
 * State of the leapfrog integrator at a single point: the position and momentum, the gradient
 * supplied by the caller, and quantities derived from them.
 *
 * <p>Derived quantities, when built through
 * {@link #LeapfrogState(Map, Map, Map, double, Potential)}:
 * <ul>
 *   <li>{@code velocity} is {@code potential.getVelocity(momentum)};</li>
 *   <li>{@code kineticEnergy} is {@code potential.getKineticEnergy(momentum, velocity)};</li>
 *   <li>{@code energy} is {@code kineticEnergy - logProb}, i.e. the potential-energy term is
 *       taken to be {@code -logProb}.</li>
 * </ul>
 *
 * <p>All fields are {@code final}, but the maps are stored and returned <em>by reference</em>
 * (no defensive copies are made). Callers must therefore not mutate maps passed to or obtained
 * from this object if they rely on it staying unchanged.
 */
public final class LeapfrogState {
    private static final int HASH_PRIME = 59;
    private static final int NULL_HASH = 43;

    private final Map<VariableReference, DoubleTensor> position;
    private final Map<VariableReference, DoubleTensor> momentum;
    private final Map<VariableReference, DoubleTensor> velocity;
    private final Map<? extends VariableReference, DoubleTensor> gradient;
    private final double kineticEnergy;
    private final double logProb;
    private final double energy;

    /**
     * Creates a state and derives velocity, kinetic energy and total energy from {@code potential}.
     *
     * @param position  position at this step (stored by reference)
     * @param momentum  momentum at this step (stored by reference)
     * @param gradient  gradient associated with this position (stored by reference)
     * @param logProb   log probability at {@code position}
     * @param potential used to compute velocity and kinetic energy from {@code momentum}
     * @throws NullPointerException if {@code potential} is {@code null}
     */
    public LeapfrogState(Map<VariableReference, DoubleTensor> position,
                         Map<VariableReference, DoubleTensor> momentum,
                         Map<? extends VariableReference, DoubleTensor> gradient,
                         double logProb,
                         Potential potential) {
        Objects.requireNonNull(potential, "potential");
        this.position = position;
        this.momentum = momentum;
        // Kinetic energy depends on the velocity, so velocity must be computed first.
        this.velocity = potential.getVelocity(momentum);
        this.gradient = gradient;
        this.kineticEnergy = potential.getKineticEnergy(momentum, this.velocity);
        this.logProb = logProb;
        this.energy = this.kineticEnergy - logProb;
    }

    /**
     * Creates a state from explicitly supplied values. No consistency between the arguments is
     * checked; for example, {@code energy} is not recomputed from {@code kineticEnergy} and
     * {@code logProb}.
     *
     * @param position      position (stored by reference)
     * @param momentum      momentum (stored by reference)
     * @param velocity      velocity (stored by reference)
     * @param gradient      gradient (stored by reference)
     * @param kineticEnergy kinetic energy
     * @param logProb       log probability
     * @param energy        total energy
     */
    public LeapfrogState(Map<VariableReference, DoubleTensor> position,
                         Map<VariableReference, DoubleTensor> momentum,
                         Map<VariableReference, DoubleTensor> velocity,
                         Map<? extends VariableReference, DoubleTensor> gradient,
                         double kineticEnergy,
                         double logProb,
                         double energy) {
        this.position = position;
        this.momentum = momentum;
        this.velocity = velocity;
        this.gradient = gradient;
        this.kineticEnergy = kineticEnergy;
        this.logProb = logProb;
        this.energy = energy;
    }

    /** @return the position map (the same instance passed at construction) */
    public Map<VariableReference, DoubleTensor> getPosition() {
        return position;
    }

    /** @return the momentum map (the same instance passed at construction) */
    public Map<VariableReference, DoubleTensor> getMomentum() {
        return momentum;
    }

    /** @return the velocity map (computed by the potential, or passed at construction) */
    public Map<VariableReference, DoubleTensor> getVelocity() {
        return velocity;
    }

    /** @return the gradient map (the same instance passed at construction) */
    public Map<? extends VariableReference, DoubleTensor> getGradient() {
        return gradient;
    }

    /** @return the kinetic energy */
    public double getKineticEnergy() {
        return kineticEnergy;
    }

    /** @return the log probability */
    public double getLogProb() {
        return logProb;
    }

    /** @return the total energy */
    public double getEnergy() {
        return energy;
    }

    /**
     * Two states are equal when all four maps are equal (null-safe) and all three doubles compare
     * equal under {@link Double#compare}. Under that rule, {@code NaN} equals {@code NaN}, and
     * {@code 0.0} differs from {@code -0.0}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LeapfrogState)) {
            return false;
        }
        LeapfrogState other = (LeapfrogState) obj;
        return Objects.equals(position, other.position)
            && Objects.equals(momentum, other.momentum)
            && Objects.equals(velocity, other.velocity)
            && Objects.equals(gradient, other.gradient)
            && Double.compare(kineticEnergy, other.kineticEnergy) == 0
            && Double.compare(logProb, other.logProb) == 0
            && Double.compare(energy, other.energy) == 0;
    }

    /**
     * Hash consistent with {@link #equals(Object)}. It uses the same multiply-by-59 combination
     * (with 43 for null fields) as the original implementation, so hash values are unchanged.
     */
    @Override
    public int hashCode() {
        int result = 1;
        result = result * HASH_PRIME + nullSafeHash(position);
        result = result * HASH_PRIME + nullSafeHash(momentum);
        result = result * HASH_PRIME + nullSafeHash(velocity);
        result = result * HASH_PRIME + nullSafeHash(gradient);
        // Double.hashCode(d) == (int) (bits ^ (bits >>> 32)) with bits = doubleToLongBits(d).
        result = result * HASH_PRIME + Double.hashCode(kineticEnergy);
        result = result * HASH_PRIME + Double.hashCode(logProb);
        result = result * HASH_PRIME + Double.hashCode(energy);
        return result;
    }

    /** Hash of {@code o}, or {@link #NULL_HASH} when {@code o} is {@code null}. */
    private static int nullSafeHash(Object o) {
        return o == null ? NULL_HASH : o.hashCode();
    }

    @Override
    public String toString() {
        return "LeapfrogState(position=" + position
            + ", momentum=" + momentum
            + ", velocity=" + velocity
            + ", gradient=" + gradient
            + ", kineticEnergy=" + kineticEnergy
            + ", logProb=" + logProb
            + ", energy=" + energy + ")";
    }
}
