/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.vertices.generic.probabilistic.discrete;

import io.improbable.keanu.KeanuRandom;
import io.improbable.keanu.distributions.discrete.Categorical;
import io.improbable.keanu.tensor.TensorShape;
import io.improbable.keanu.tensor.TensorShapeValidation;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.generic.GenericTensor;
import io.improbable.keanu.vertices.ConstantVertex;
import io.improbable.keanu.vertices.NonSaveableVertex;
import io.improbable.keanu.vertices.Probabilistic;
import io.improbable.keanu.vertices.Vertex;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.nonprobabilistic.operators.unary.TakeVertex;
import io.improbable.keanu.vertices.dbl.probabilistic.DirichletVertex;
import io.improbable.keanu.vertices.generic.GenericTensorVertex;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * A probabilistic vertex whose value is a tensor of categories.
 *
 * <p>Each selectable category is paired with a {@link DoubleVertex}. When sampling or
 * computing a log-probability, the current values of those vertices are passed to
 * {@link Categorical#withParameters} keyed by category (presumably as per-category
 * selection probabilities/weights; see {@link Categorical} for the exact interpretation).
 * All of the paired vertices are registered as parents of this vertex.
 *
 * @param <CATEGORY> the type of the values this vertex can take
 */
public class CategoricalVertex<CATEGORY>
    extends GenericTensorVertex<CATEGORY>
    implements Probabilistic<GenericTensor<CATEGORY>>, NonSaveableVertex {

    /** Category to parameter-vertex mapping; stored by reference, not copied. */
    private final Map<CATEGORY, DoubleVertex> selectableValues;

    /**
     * Creates a vertex from fixed per-category values by wrapping each {@code Double}
     * in a constant vertex.
     *
     * @param selectableValues category to constant parameter value
     * @param <CATEGORY> the category type
     * @return a new categorical vertex over the keys of {@code selectableValues}
     */
    public static <CATEGORY> CategoricalVertex<CATEGORY> of(Map<CATEGORY, Double> selectableValues) {
        return new CategoricalVertex<>(toDoubleVertices(selectableValues));
    }

    /**
     * Converts every {@code Double} in the given map into a vertex via
     * {@link ConstantVertex#of}, keeping the same keys.
     */
    private static <CATEGORY> Map<CATEGORY, DoubleVertex> toDoubleVertices(Map<CATEGORY, Double> selectableValues) {
        return selectableValues.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> ConstantVertex.of(entry.getValue())));
    }

    /**
     * Creates a vertex whose i-th category is parameterised by a {@link TakeVertex}
     * selecting index {@code i} of the given Dirichlet vertex.
     *
     * @param vertex the Dirichlet vertex supplying one parameter per category
     * @param categories the categories, in the same order as the entries of {@code vertex}
     * @param <CATEGORY> the category type
     * @return a new categorical vertex over {@code categories}
     * @throws IllegalArgumentException if the number of categories differs from the
     *     length of {@code vertex}'s shape as reported by {@link TensorShape#getLength}
     * @throws IllegalStateException if {@code categories} contains duplicates
     *     (raised by {@link Collectors#toMap} on a duplicate key)
     */
    public static <CATEGORY> CategoricalVertex<CATEGORY> of(DirichletVertex vertex, List<CATEGORY> categories) {
        long length = TensorShape.getLength(vertex.getShape());
        if (length != (long) categories.size()) {
            throw new IllegalArgumentException("Categories must have length of vertex's size");
        }

        // Keep the key type as CATEGORY so the result is a properly parameterised
        // CategoricalVertex<CATEGORY> rather than an Object-keyed one.
        Map<CATEGORY, DoubleVertex> selectableValues = IntStream.range(0, categories.size())
            .boxed()
            .collect(Collectors.toMap(categories::get, index -> new TakeVertex(vertex, index.intValue())));
        return new CategoricalVertex<>(selectableValues);
    }

    /**
     * Creates a vertex over the integer categories {@code 0 .. n-1}, where {@code n} is
     * the length of the given Dirichlet vertex's shape.
     *
     * @param vertex the Dirichlet vertex supplying one parameter per category
     * @return a new categorical vertex whose categories are the indices into {@code vertex}
     * @throws ArithmeticException if the length of {@code vertex} does not fit in an {@code int}
     */
    public static CategoricalVertex<Integer> of(DirichletVertex vertex) {
        int categoriesCount = Math.toIntExact(TensorShape.getLength(vertex.getShape()));
        List<Integer> indices = IntStream.range(0, categoriesCount)
            .boxed()
            .collect(Collectors.toList());
        return CategoricalVertex.of(vertex, indices);
    }

    /**
     * Creates a vertex with an explicit shape.
     *
     * <p>The shapes of the parameter vertices are validated against {@code tensorShape}
     * by {@link TensorShapeValidation#checkTensorsMatchNonLengthOneShapeOrAreLengthOne};
     * the exact rules (and any exception thrown on mismatch) are defined there.
     *
     * @param tensorShape the shape of this vertex
     * @param selectableValues category to parameter vertex; every value becomes a parent
     */
    public CategoricalVertex(long[] tensorShape, Map<CATEGORY, DoubleVertex> selectableValues) {
        super(tensorShape);
        TensorShapeValidation.checkTensorsMatchNonLengthOneShapeOrAreLengthOne(
            tensorShape, selectableValuesShapes(selectableValues));
        this.selectableValues = selectableValues;
        this.setParents(this.selectableValues.values());
    }

    /**
     * Creates a vertex whose shape is derived from the shapes of the parameter vertices
     * via {@link TensorShapeValidation#checkHasOneNonLengthOneShapeOrAllLengthOne}.
     *
     * @param selectableValues category to parameter vertex; every value becomes a parent
     */
    public CategoricalVertex(Map<CATEGORY, DoubleVertex> selectableValues) {
        this(
            TensorShapeValidation.checkHasOneNonLengthOneShapeOrAllLengthOne(selectableValuesShapes(selectableValues)),
            selectableValues);
    }

    /**
     * @return the category to parameter-vertex map this vertex was constructed with
     *     (the same instance, not a copy)
     */
    public Map<CATEGORY, DoubleVertex> getSelectableValues() {
        return this.selectableValues;
    }

    /**
     * Draws a sample of this vertex's shape from a {@link Categorical} distribution built
     * from the current values of the parameter vertices.
     *
     * @param random the source of randomness
     * @return the sampled tensor of categories
     */
    @Override
    public GenericTensor<CATEGORY> sample(KeanuRandom random) {
        Map<CATEGORY, DoubleTensor> parameters = selectableValuesMappedToDoubleTensor();
        return Categorical.withParameters(parameters).sample(this.getShape(), random);
    }

    /**
     * Computes the log-probability of {@code value} under a {@link Categorical}
     * distribution built from the current values of the parameter vertices, summed over
     * all elements of the resulting tensor.
     *
     * @param value the tensor of categories to evaluate
     * @return the summed log-probability
     */
    @Override
    public double logProb(GenericTensor<CATEGORY> value) {
        Map<CATEGORY, DoubleTensor> parameters = selectableValuesMappedToDoubleTensor();
        // Cast retained: the declared return type of sum() lives outside this file.
        return (Double) Categorical.withParameters(parameters).logProb(value).sum();
    }

    /**
     * Always returns an empty map, regardless of the arguments.
     *
     * @param value ignored
     * @param withRespectTo ignored
     * @return an immutable empty map
     */
    @Override
    public Map<Vertex, DoubleTensor> dLogProb(GenericTensor<CATEGORY> value, Set<? extends Vertex> withRespectTo) {
        return Collections.emptyMap();
    }

    /** Snapshots the current value of every parameter vertex, keyed by category. */
    private Map<CATEGORY, DoubleTensor> selectableValuesMappedToDoubleTensor() {
        // Cast retained: the declared return type of getValue() lives outside this file.
        return this.selectableValues.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> (DoubleTensor) entry.getValue().getValue()));
    }

    /** Collects the shape of every parameter vertex into an array, in map-value iteration order. */
    private static long[][] selectableValuesShapes(Map<?, DoubleVertex> selectableValues) {
        return selectableValues.values().stream()
            .map(Vertex::getShape)
            .toArray(long[][]::new);
    }
}

