package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.Distribution;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * A categorical distribution over an arbitrary set of category values.
 *
 * <p>Each category is associated with a {@link DoubleTensor} of weights. Weights are divided by
 * their element-wise total before use, so they need not sum to one, but that total must be
 * strictly positive everywhere (otherwise {@link IllegalArgumentException} is thrown).
 *
 * @param <CATEGORY> the type of the category values
 */
public class Categorical<CATEGORY> implements Distribution<GenericTensor<CATEGORY>> {

    /** Category to weight tensor, kept in the iteration order of the map passed at construction. */
    private final Map<CATEGORY, DoubleTensor> selectableValues;

    /** Categories in the order they are visited when sampling (same order as selectableValues). */
    private final List<CATEGORY> categoryOrder;

    /**
     * Creates a categorical distribution from per-category weights.
     *
     * @param map category to weight tensor; the map itself is copied (the tensors are not) and its
     *     iteration order is retained
     * @param <CAT> the category type
     * @return a new distribution over the keys of {@code map}
     */
    public static <CAT> Categorical<CAT> withParameters(Map<CAT, DoubleTensor> map) {
        return new Categorical<>(map);
    }

    private Categorical(Map<CATEGORY, DoubleTensor> map) {
        // Copy into a LinkedHashMap so later changes to the caller's map do not affect this
        // distribution, while preserving the caller's iteration order.
        this.selectableValues = new LinkedHashMap<>(map);
        this.categoryOrder = new ArrayList<>(this.selectableValues.keySet());
    }

    /**
     * Draws a sample of the given shape by inverse-CDF sampling: for each element a random number
     * {@code u} is drawn, and the first category (in {@code categoryOrder}) whose cumulative
     * normalised weight exceeds {@code u} is selected.
     *
     * @param shape the shape of the sample to draw
     * @param random the source of randomness
     * @return a tensor of sampled categories with the given shape
     * @throws IllegalArgumentException if the weights sum to a non-positive value anywhere
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public GenericTensor<CATEGORY> sample(long[] shape, KeanuRandom random) {
        DoubleTensor sumOfProbabilities = getSumOfProbabilities(shape);
        // NOTE(review): presumably uniform on [0, 1) -- confirm against KeanuRandom.
        DoubleTensor uniformDraws = random.nextDouble(shape);
        DoubleTensor cumulativeProbability = DoubleTensor.zeros(shape);

        // Default every element to the last category so that, if rounding leaves the final
        // cumulative probability just below u, the element still receives a valid category.
        CATEGORY lastCategory = this.categoryOrder.get(this.categoryOrder.size() - 1);
        GenericTensor<CATEGORY> result = GenericTensor.createFilled(lastCategory, shape);
        BooleanTensor alreadySelected = BooleanTensor.falses(shape);

        for (CATEGORY category : this.categoryOrder) {
            cumulativeProbability =
                cumulativeProbability.plus(this.selectableValues.get(category).div(sumOfProbabilities));
            // With non-negative weights the cumulative sum never decreases, so XOR with the
            // already-selected mask keeps only elements that pass u for the first time.
            BooleanTensor newlySelected = alreadySelected.xor(cumulativeProbability.greaterThan(uniformDraws));
            result = (GenericTensor<CATEGORY>) newlySelected.where(Tensor.scalar(category), result);
            alreadySelected.orInPlace(newlySelected);
            if (alreadySelected.allTrue()) {
                // Every element has a category; the remaining categories cannot change anything.
                break;
            }
        }
        return result;
    }

    /**
     * Computes the element-wise log probability of {@code x}.
     *
     * <p>An element's probability is its category's weight divided by the total weight. An element
     * equal to no category has probability zero, so its log probability is negative infinity.
     *
     * @param x the category values to evaluate
     * @return the element-wise log probability of {@code x}
     * @throws IllegalArgumentException if the weights sum to a non-positive value anywhere
     */
    @Override // io.improbable.keanu.distributions.Distribution
    public DoubleTensor logProb(GenericTensor<CATEGORY> x) {
        long[] shape = x.getShape();
        DoubleTensor sumOfProbabilities = getSumOfProbabilities(shape);
        DoubleTensor probability = DoubleTensor.zeros(shape);
        for (Map.Entry<CATEGORY, DoubleTensor> entry : this.selectableValues.entrySet()) {
            DoubleTensor isCategory =
                x.elementwiseEquals(GenericTensor.createFilled(entry.getKey(), shape)).toDoubleMask();
            // Accumulate masked probabilities and take the log once at the end. Multiplying the
            // mask by a per-category log probability instead yields 0 * log(0) = 0 * -inf = NaN
            // for every element as soon as any category has zero weight.
            probability = probability.plus(isCategory.timesInPlace(entry.getValue().div(sumOfProbabilities)));
        }
        return probability.logInPlace();
    }

    /** Returns true if any element of {@code tensor} is less than or equal to zero. */
    private boolean containsNonPositiveEntry(DoubleTensor tensor) {
        return !tensor.lessThanOrEqual(0.0).allFalse();
    }

    /**
     * Adds all category weight tensors element-wise, starting from a zero tensor of {@code shape}.
     *
     * @param shape the shape of the starting zero tensor
     * @return the element-wise total weight
     * @throws IllegalArgumentException if any element of the total is non-positive
     */
    private DoubleTensor getSumOfProbabilities(long[] shape) {
        DoubleTensor total = DoubleTensor.zeros(shape);
        for (DoubleTensor weights : this.selectableValues.values()) {
            total = total.plus(weights);
        }
        if (containsNonPositiveEntry(total)) {
            throw new IllegalArgumentException("Cannot sample from a zero probability setup.");
        }
        return total;
    }
}
