/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.distributions.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.Distribution;
import io.improbable.keanu.tensor.Tensor;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/**
 * A categorical distribution over an arbitrary set of category values.
 *
 * <p>Each category is mapped to a tensor of (possibly unnormalized) probabilities. At every tensor
 * position the probabilities of all categories are summed and each category's value is divided by
 * that sum, so the supplied values need not add up to one.
 *
 * <p>The supplied map is copied into a {@link LinkedHashMap}. Its iteration order is therefore
 * preserved, and it fixes the order in which categories are visited when sampling.
 *
 * @param <CATEGORY> the type of the category values
 */
public class Categorical<CATEGORY> implements Distribution<GenericTensor<CATEGORY>> {

    /** Category to unnormalized probability tensor, in the caller's iteration order. */
    private final Map<CATEGORY, DoubleTensor> selectableValues;

    /** Categories in the same order as {@link #selectableValues}. */
    private final List<CATEGORY> categoryOrder;

    /**
     * Creates a categorical distribution.
     *
     * @param selectableValues map from each category to its unnormalized probabilities. The map is
     *     copied, so later changes to it do not affect this distribution.
     * @param <CAT> the category type
     * @return a new categorical distribution over the map's keys
     */
    public static <CAT> Categorical<CAT> withParameters(Map<CAT, DoubleTensor> selectableValues) {
        return new Categorical<>(selectableValues);
    }

    private Categorical(Map<CATEGORY, DoubleTensor> selectableValues) {
        this.selectableValues = new LinkedHashMap<>(selectableValues);
        this.categoryOrder = new ArrayList<>(this.selectableValues.keySet());
    }

    /**
     * Draws a tensor of categories by walking the cumulative distribution at each position.
     *
     * <p>A random value {@code p} is drawn per position from {@code random.nextDouble}. The
     * categories are then visited in insertion order while their normalized probabilities are
     * accumulated. Each position receives the first category whose cumulative probability exceeds
     * its {@code p}.
     *
     * <p>Positions that are never assigned keep the last category, which the result is pre-filled
     * with. This can happen when rounding leaves the cumulative total just below {@code p}.
     *
     * @param shape the shape of the tensor to sample
     * @param random the source of randomness
     * @return a tensor of sampled categories with the requested shape
     * @throws IllegalArgumentException if no categories were supplied, or if the probabilities sum
     *     to a non-positive value at any position
     */
    @Override
    public GenericTensor<CATEGORY> sample(long[] shape, KeanuRandom random) {
        DoubleTensor sumOfProbabilities = this.getSumOfProbabilities(shape);
        DoubleTensor p = random.nextDouble(shape);
        DoubleTensor cumulativeProbability = DoubleTensor.zeros(shape);
        CATEGORY lastValue = this.categoryOrder.get(this.categoryOrder.size() - 1);
        GenericTensor<CATEGORY> sample = GenericTensor.createFilled(lastValue, shape);
        BooleanTensor sampleValuesSetSoFar = BooleanTensor.falses(shape);
        for (CATEGORY category : this.categoryOrder) {
            DoubleTensor normalizedProbabilities =
                this.selectableValues.get(category).div(sumOfProbabilities);
            cumulativeProbability = cumulativeProbability.plus(normalizedProbabilities);
            // Assuming non-negative probabilities, the cumulative total only grows, so "> p" stays
            // true once reached. XOR-ing with the already-assigned mask isolates the positions that
            // first cross p at this category.
            BooleanTensor maskForUnassignedSampleValues =
                sampleValuesSetSoFar.xor(cumulativeProbability.greaterThan(p));
            sample = maskForUnassignedSampleValues.where(Tensor.scalar(category), sample);
            sampleValuesSetSoFar.orInPlace(maskForUnassignedSampleValues);
            if (sampleValuesSetSoFar.allTrue()) {
                // Every position has a category; the remaining ones cannot change the result.
                break;
            }
        }
        return sample;
    }

    /**
     * Computes the elementwise log probability of {@code x}.
     *
     * <p>At each position, the normalized probability of the category stored there is selected and
     * its natural log is returned. A position whose value equals none of the categories has
     * probability zero and therefore log probability {@code -Infinity}.
     *
     * @param x the tensor of observed categories
     * @return a tensor of log probabilities with the shape of {@code x}
     * @throws IllegalArgumentException if no categories were supplied, or if the probabilities sum
     *     to a non-positive value at any position
     */
    @Override
    public DoubleTensor logProb(GenericTensor<CATEGORY> x) {
        DoubleTensor sumOfProbabilities = this.getSumOfProbabilities(x.getShape());
        // Select the observed category's probability first and take the log only once at the end.
        // Masking log-probabilities directly would evaluate 0 * log(0) = 0 * -Infinity = NaN
        // whenever some other category has zero probability at that position.
        DoubleTensor probabilityOfX = DoubleTensor.zeros(x.getShape());
        for (Map.Entry<CATEGORY, DoubleTensor> entry : this.selectableValues.entrySet()) {
            // Compare each element of x against the category value itself (scalar comparison).
            DoubleTensor xEqualToEntryKeyMask = x.elementwiseEquals(entry.getKey()).toDoubleMask();
            DoubleTensor normalizedProbability = entry.getValue().div(sumOfProbabilities);
            probabilityOfX = probabilityOfX.plus(xEqualToEntryKeyMask.timesInPlace(normalizedProbability));
        }
        return probabilityOfX.logInPlace();
    }

    /** Returns true if any element of {@code sumOfProbabilities} is less than or equal to zero. */
    private boolean containsNonPositiveEntry(DoubleTensor sumOfProbabilities) {
        return !sumOfProbabilities.lessThanOrEqual(0.0).allFalse();
    }

    /**
     * Sums every category's probabilities, broadcast onto a zero tensor of {@code shape}.
     *
     * @throws IllegalArgumentException if any position's total is non-positive, which includes the
     *     case where no categories were supplied
     */
    private DoubleTensor getSumOfProbabilities(long[] shape) {
        DoubleTensor sumOfProbabilities = DoubleTensor.zeros(shape);
        for (DoubleTensor probabilities : this.selectableValues.values()) {
            sumOfProbabilities = sumOfProbabilities.plus(probabilities);
        }
        if (this.containsNonPositiveEntry(sumOfProbabilities)) {
            throw new IllegalArgumentException("Cannot sample from a zero probability setup.");
        }
        return sumOfProbabilities;
    }
}

