/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.algorithms;

import com.google.common.base.Preconditions;
import io.improbable.keanu.algorithms.NetworkSample;
import io.improbable.keanu.algorithms.Samples;
import io.improbable.keanu.algorithms.Variable;
import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.network.NetworkState;
import io.improbable.keanu.tensor.bool.BooleanTensor;
import io.improbable.keanu.tensor.dbl.DoubleTensor;
import io.improbable.keanu.tensor.intgr.IntegerTensor;
import io.improbable.keanu.vertices.bool.BooleanVertex;
import io.improbable.keanu.vertices.bool.BooleanVertexSamples;
import io.improbable.keanu.vertices.dbl.DoubleVertex;
import io.improbable.keanu.vertices.dbl.DoubleVertexSamples;
import io.improbable.keanu.vertices.intgr.IntegerVertex;
import io.improbable.keanu.vertices.intgr.IntegerVertexSamples;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A set of network samples stored column-wise: for every {@link VariableReference} a list of the values that variable
 * took in each sample, plus the log of the master probability recorded for each sample.
 *
 * <p>The collections passed to the constructor are stored as-is (not copied). Derived instances from {@link #drop(int)}
 * share storage with this instance via {@link List#subList(int, int)} views.
 */
public class NetworkSamples {
    private static final Logger log = LoggerFactory.getLogger(NetworkSamples.class);

    // Per-variable sample values, in sample order. The raw element type is kept for compatibility with the public
    // constructor's signature.
    private final Map<VariableReference, ? extends List> samplesByVariable;
    private final List<Double> logOfMasterPForEachSample;
    private final int size;

    /**
     * @param samplesByVariable         the values of each variable, one list entry per sample, in sample order
     * @param logOfMasterPForEachSample the log master probability for each sample, in sample order
     * @param size                      the number of samples; expected to match the length of the lists above
     */
    public NetworkSamples(Map<VariableReference, ? extends List> samplesByVariable,
                          List<Double> logOfMasterPForEachSample,
                          int size) {
        this.samplesByVariable = samplesByVariable;
        this.logOfMasterPForEachSample = logOfMasterPForEachSample;
        this.size = size;
    }

    /**
     * Transposes a list of individual network samples into per-variable sample lists.
     *
     * <p>NOTE(review): a variable absent from some samples ends up with a shorter list whose indices no longer line up
     * with sample numbers; presumably every sample contains the same variables — confirm with callers.
     *
     * @param networkSamples the samples, in order
     * @return the column-wise samples; its size is {@code networkSamples.size()}
     */
    public static NetworkSamples from(List<NetworkSample> networkSamples) {
        Map<VariableReference, List<Object>> samplesByVariable = new HashMap<>();
        List<Double> logOfMasterPForEachSample = new ArrayList<>(networkSamples.size());
        for (NetworkSample networkSample : networkSamples) {
            addSamplesForNetworkSample(networkSample, samplesByVariable);
            logOfMasterPForEachSample.add(networkSample.getLogOfMasterP());
        }
        return new NetworkSamples(samplesByVariable, logOfMasterPForEachSample, networkSamples.size());
    }

    /** Appends every variable value of one network sample to that variable's list. */
    private static void addSamplesForNetworkSample(NetworkSample networkSample,
                                                   Map<VariableReference, List<Object>> samplesByVariable) {
        for (VariableReference variableReference : networkSample.getVariableReferences()) {
            addSampleForVariable(variableReference, networkSample.get(variableReference), samplesByVariable);
        }
    }

    /** Appends one value to the list for {@code variableReference}, creating the list on first use. */
    private static void addSampleForVariable(VariableReference variableReference,
                                             Object value,
                                             Map<VariableReference, List<Object>> samples) {
        samples.computeIfAbsent(variableReference, ignored -> new ArrayList<>()).add(value);
    }

    /** @return the number of samples held */
    public int size() {
        return this.size;
    }

    /**
     * Returns the samples of {@code variable}, using the specialised tensor sample types for double, integer and
     * boolean vertices.
     */
    @SuppressWarnings("unchecked")
    public <T> Samples<T> get(Variable<T, ?> variable) {
        VariableReference reference = variable.getReference();
        if (variable instanceof DoubleVertex) {
            return (Samples<T>) getDoubleTensorSamples(reference);
        }
        if (variable instanceof IntegerVertex) {
            return (Samples<T>) getIntegerTensorSamples(reference);
        }
        if (variable instanceof BooleanVertex) {
            return (Samples<T>) getBooleanTensorSamples(reference);
        }
        return get(reference);
    }

    /**
     * Returns the samples stored for {@code variableReference}.
     * NOTE(review): an unknown reference passes {@code null} to the {@code Samples} constructor — confirm intended.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public <T> Samples<T> get(VariableReference variableReference) {
        return new Samples(samplesByVariable.get(variableReference));
    }

    public DoubleVertexSamples getDoubleTensorSamples(Variable<DoubleTensor, ?> variable) {
        return getDoubleTensorSamples(variable.getReference());
    }

    public DoubleVertexSamples getDoubleTensorSamples(VariableReference variableReference) {
        return new DoubleVertexSamples(samplesByVariable.get(variableReference));
    }

    public IntegerVertexSamples getIntegerTensorSamples(Variable<IntegerTensor, ?> variable) {
        return getIntegerTensorSamples(variable.getReference());
    }

    public IntegerVertexSamples getIntegerTensorSamples(VariableReference variableReference) {
        return new IntegerVertexSamples(samplesByVariable.get(variableReference));
    }

    public BooleanVertexSamples getBooleanTensorSamples(Variable<BooleanTensor, ?> variable) {
        return getBooleanTensorSamples(variable.getReference());
    }

    public BooleanVertexSamples getBooleanTensorSamples(VariableReference variableReference) {
        return new BooleanVertexSamples(samplesByVariable.get(variableReference));
    }

    /**
     * Discards the first {@code dropCount} samples (e.g. burn-in).
     *
     * @param dropCount number of leading samples to discard; must be in {@code [0, size()]}
     * @return {@code this} when {@code dropCount == 0}, otherwise a new instance backed by views of this one's lists
     * @throws IllegalArgumentException if {@code dropCount} is negative or greater than {@link #size()}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public NetworkSamples drop(int dropCount) {
        Preconditions.checkArgument(dropCount >= 0,
            "Cannot drop %s samples. Drop count must be non-negative.", dropCount);
        Preconditions.checkArgument(dropCount <= size,
            "Cannot drop %s samples when there are only %s samples.", dropCount, size);
        if (dropCount == 0) {
            return this;
        }
        // subList is an O(1) view, so a parallel stream would only add overhead here.
        Map<VariableReference, List> withSamplesDropped = samplesByVariable.entrySet().stream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> ((List) e.getValue()).subList(dropCount, size)));
        List<Double> withLogProbsDropped = logOfMasterPForEachSample.subList(dropCount, size);
        return new NetworkSamples(withSamplesDropped, withLogProbsDropped, size - dropCount);
    }

    /**
     * Keeps every {@code downSampleInterval}-th sample, starting with the first (indices 0, k, 2k, ...).
     *
     * @param downSampleInterval the stride k; must be positive
     * @return a new instance holding {@code ceil(size() / k)} samples
     * @throws IllegalArgumentException if {@code downSampleInterval <= 0}
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    public NetworkSamples downSample(int downSampleInterval) {
        Preconditions.checkArgument(downSampleInterval > 0,
            "Down sample interval of %s is invalid. Sample interval must be positive.", downSampleInterval);
        Map<VariableReference, List> withSamplesDownSampled = samplesByVariable.entrySet().parallelStream()
            .collect(Collectors.toMap(Map.Entry::getKey, e -> downSample((List) e.getValue(), downSampleInterval)));
        List<Double> withLogProbsDownSampled = downSample(logOfMasterPForEachSample, downSampleInterval);
        // Index 0 is always kept, so the count is the ceiling of size / interval (written without overflow).
        int downSampledSize = size / downSampleInterval + (size % downSampleInterval == 0 ? 0 : 1);
        return new NetworkSamples(withSamplesDownSampled, withLogProbsDownSampled, downSampledSize);
    }

    /** Returns a new list with the elements at indices 0, k, 2k, ... of {@code samples}, where k is the interval. */
    private static <T> List<T> downSample(List<T> samples, int downSampleInterval) {
        List<T> downSampled = new ArrayList<>(samples.size() / downSampleInterval + 1);
        int index = 0;
        // Iterate rather than index so non-random-access lists stay linear.
        for (T sample : samples) {
            if (index % downSampleInterval == 0) {
                downSampled.add(sample);
            }
            index++;
        }
        return downSampled;
    }

    /**
     * Estimates the probability of {@code predicate} as the fraction of samples whose network state satisfies it.
     * Returns NaN when there are no samples (0 / 0).
     */
    public double probability(Function<NetworkState, Boolean> predicate) {
        List<NetworkState> networkStates = toNetworkStates();
        long trueCount = networkStates.parallelStream().filter(predicate::apply).count();
        return (double) trueCount / networkStates.size();
    }

    /** Returns a view of sample number {@code sample}, backed by this instance's per-variable lists. */
    public NetworkState getNetworkState(int sample) {
        return new SamplesBackedNetworkState(samplesByVariable, sample);
    }

    public double getLogOfMasterP(int sample) {
        return logOfMasterPForEachSample.get(sample);
    }

    /** Returns one backed {@link NetworkState} per sample, in sample order. */
    public List<NetworkState> toNetworkStates() {
        List<NetworkState> states = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            states.add(getNetworkState(i));
        }
        return states;
    }

    /**
     * Returns the sample with the highest log master probability; the earliest such sample on ties.
     *
     * @throws IllegalStateException if there are no samples
     */
    public NetworkState getMostProbableState() {
        Preconditions.checkState(!logOfMasterPForEachSample.isEmpty(),
            "Cannot find the most probable state when there are no samples.");
        int sampleNumberWithMostProbableState = IntStream.range(0, logOfMasterPForEachSample.size())
            .boxed()
            .max(Comparator.comparingDouble(i -> logOfMasterPForEachSample.get(i)))
            .orElseThrow(IllegalStateException::new);
        // Guard so the message is only formatted when debug logging is enabled.
        if (log.isDebugEnabled()) {
            log.debug(String.format("Most probable state is %d: %.4f", sampleNumberWithMostProbableState,
                logOfMasterPForEachSample.get(sampleNumberWithMostProbableState)));
        }
        return getNetworkState(sampleNumberWithMostProbableState);
    }

    /** A {@link NetworkState} that reads the values at one fixed index of the per-variable sample lists. */
    private static final class SamplesBackedNetworkState implements NetworkState {
        private final Map<VariableReference, ? extends List> samplesByVariable;
        private final int index;

        SamplesBackedNetworkState(Map<VariableReference, ? extends List> samplesByVariable, int index) {
            this.samplesByVariable = samplesByVariable;
            this.index = index;
        }

        @Override
        public <T> T get(Variable<T, ?> variable) {
            return get(variable.getReference());
        }

        @Override
        @SuppressWarnings("unchecked")
        public <T> T get(VariableReference variableReference) {
            return (T) samplesByVariable.get(variableReference).get(index);
        }

        /** Returns a fresh copy of the variable references, so callers cannot mutate the backing map's key set. */
        @Override
        public Set<VariableReference> getVariableReferences() {
            return new HashSet<>(samplesByVariable.keySet());
        }
    }
}

