/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.network.grouping;

import io.improbable.keanu.algorithms.VariableReference;
import io.improbable.keanu.network.NetworkState;
import io.improbable.keanu.network.SimpleNetworkState;
import io.improbable.keanu.network.grouping.ContinuousPoint;
import io.improbable.keanu.network.grouping.DiscretePoint;
import io.improbable.keanu.network.grouping.continuouspointgroupers.ContinuousPointGrouper;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Partitions a collection of {@link NetworkState}s into groups.
 *
 * <p>States are first bucketed by the exact values they hold for a set of discrete
 * variables. Within each bucket, the values of the continuous variables are turned into
 * {@link ContinuousPoint}s and handed to the configured {@link ContinuousPointGrouper},
 * which decides how those points are split into groups. Every resulting group is then
 * converted back into a list of {@link SimpleNetworkState}s.
 */
public class NetworkStateGrouper {
    private final ContinuousPointGrouper continuousPointGrouper;

    /**
     * @param continuousPointGrouper strategy used to group the continuous points of all
     *                               states that share the same discrete values
     */
    public NetworkStateGrouper(ContinuousPointGrouper continuousPointGrouper) {
        this.continuousPointGrouper = continuousPointGrouper;
    }

    /**
     * Groups the given network states by their discrete values and, within each discrete
     * bucket, by whatever grouping the {@link ContinuousPointGrouper} produces.
     *
     * <p>The returned states are rebuilt from scratch: each one contains only the variables
     * listed in {@code discreteVertexIds} and {@code continuousVertexIds}. The order of the
     * outer list follows the iteration order of the intermediate maps, so callers should not
     * rely on it.
     *
     * @param networkStates       the states to group
     * @param discreteVertexIds   references of the variables whose values must match exactly
     *                            for two states to land in the same bucket
     * @param continuousVertexIds references of the variables whose values are read as
     *                            {@code Double} and passed to the continuous point grouper
     * @return one inner list of network states per group
     */
    public List<List<NetworkState>> groupNetworkStates(List<NetworkState> networkStates, List<VariableReference> discreteVertexIds, List<VariableReference> continuousVertexIds) {
        // Stage 1: bucket states by their discrete values.
        // NOTE(review): relies on DiscretePoint providing value-based equals/hashCode - confirm.
        Map<DiscretePoint, List<NetworkState>> statesGroupedByDiscretePoint = networkStates.stream()
            .collect(Collectors.groupingBy(state -> toDiscretePoint(state, discreteVertexIds)));

        // Stage 2: within each bucket, let the grouper split the continuous points.
        Map<DiscretePoint, List<List<ContinuousPoint>>> continuousPointsGroupedByDiscretePoint = statesGroupedByDiscretePoint.entrySet().stream()
            .collect(Collectors.toMap(
                Map.Entry::getKey,
                e -> continuousPointGrouper.groupContinuousPoints(toContinuousPoints(e.getValue(), continuousVertexIds))));

        // Stage 3: flatten every (discrete point, continuous group) pair back into network states.
        return continuousPointsGroupedByDiscretePoint.entrySet().stream()
            .flatMap(e -> toListOfNetworkStates(e.getKey(), e.getValue(), discreteVertexIds, continuousVertexIds))
            .collect(Collectors.toList());
    }

    /**
     * Converts every continuous group belonging to one discrete point into a list of
     * network states; the discrete values are shared by all produced states.
     */
    private Stream<List<NetworkState>> toListOfNetworkStates(DiscretePoint discretePoint, List<List<ContinuousPoint>> continuousPoints, List<VariableReference> discreteVertexIds, List<VariableReference> continuousVertexIds) {
        Map<VariableReference, ?> discreteValues = fromDiscretePoint(discretePoint, discreteVertexIds);
        return continuousPoints.stream()
            .map(groupedPoints -> toNetworkState(discreteValues, groupedPoints, continuousVertexIds));
    }

    /**
     * Builds one network state per continuous point, each combined with the same discrete values.
     */
    private List<NetworkState> toNetworkState(Map<VariableReference, ?> discreteValues, List<ContinuousPoint> continuousPoints, List<VariableReference> continuousVertexIds) {
        return continuousPoints.stream()
            .map(point -> combineIntoNetworkState(discreteValues, point, continuousVertexIds))
            .collect(Collectors.toList());
    }

    /**
     * Merges the values of a single continuous point with the discrete values into one state.
     */
    private NetworkState combineIntoNetworkState(Map<VariableReference, ?> discreteValues, ContinuousPoint point, List<VariableReference> continuousVertexIds) {
        // Values are heterogeneous (Double for continuous, arbitrary objects for discrete),
        // so the map must be typed on Object rather than Double.
        Map<VariableReference, Object> networkState = new HashMap<>();
        networkState.putAll(fromContinuousPoint(point, continuousVertexIds));
        // Discrete values are written last, so they win if a reference appears in both lists.
        networkState.putAll(discreteValues);
        return new SimpleNetworkState(networkState);
    }

    /**
     * Reads the values of the discrete variables from a state, in the order of {@code discreteVertexIds}.
     */
    private DiscretePoint toDiscretePoint(NetworkState networkState, List<VariableReference> discreteVertexIds) {
        Object[] point = new Object[discreteVertexIds.size()];
        for (int vertex = 0; vertex < discreteVertexIds.size(); ++vertex) {
            point[vertex] = networkState.get(discreteVertexIds.get(vertex));
        }
        return new DiscretePoint(point);
    }

    /**
     * Inverse of {@link #toDiscretePoint}: pairs each reference with the value stored at the same index.
     */
    private Map<VariableReference, ?> fromDiscretePoint(DiscretePoint discretePoint, List<VariableReference> discreteVertexIds) {
        Map<VariableReference, Object> discreteStates = new HashMap<>();
        Object[] discreteValues = discretePoint.getPoint();
        for (int i = 0; i < discreteVertexIds.size(); ++i) {
            discreteStates.put(discreteVertexIds.get(i), discreteValues[i]);
        }
        return discreteStates;
    }

    /**
     * Converts every state into a continuous point, preserving the order of {@code networkStates}.
     */
    private List<ContinuousPoint> toContinuousPoints(List<NetworkState> networkStates, List<VariableReference> continuousVertexIds) {
        return networkStates.stream()
            .map(state -> toContinuousPoint(state, continuousVertexIds))
            .collect(Collectors.toList());
    }

    /**
     * Reads the values of the continuous variables from a state, in the order of
     * {@code continuousVertexIds}. Each value is cast to {@code Double} and unboxed, so a
     * missing (null) or non-Double value fails here with an exception.
     */
    private ContinuousPoint toContinuousPoint(NetworkState networkState, List<VariableReference> continuousVertexIds) {
        double[] point = new double[continuousVertexIds.size()];
        int i = 0;
        for (VariableReference vertexId : continuousVertexIds) {
            point[i] = (Double) networkState.get(vertexId);
            ++i;
        }
        return new ContinuousPoint(point);
    }

    /**
     * Inverse of {@link #toContinuousPoint}: pairs each reference with the coordinate stored at the same index.
     */
    private Map<VariableReference, Double> fromContinuousPoint(ContinuousPoint point, List<VariableReference> continuousVertexIds) {
        Map<VariableReference, Double> continuousStates = new HashMap<>();
        double[] continuousValues = point.getPoint();
        for (int i = 0; i < continuousVertexIds.size(); ++i) {
            continuousStates.put(continuousVertexIds.get(i), continuousValues[i]);
        }
        return continuousStates;
    }
}

