/*
 * Decompiled with CFR 0.152.
 */
package io.improbable.keanu.model;

import io.improbable.keanu.algorithms.NetworkSamples;
import io.improbable.keanu.algorithms.PosteriorSamplingAlgorithm;
import io.improbable.keanu.model.ModelFitter;
import io.improbable.keanu.model.ModelGraph;
import io.improbable.keanu.network.KeanuProbabilisticModel;
import io.improbable.keanu.network.NetworkState;
import java.util.Objects;
import java.util.function.Function;

/**
 * A {@link ModelFitter} that fits a {@link ModelGraph} by drawing posterior samples from its
 * Bayesian network and then setting that network's state to the most probable sampled state.
 *
 * <p>The sampling algorithm is built freshly on every call to {@link #fit(ModelGraph)} by the
 * supplied generator function. The samples from the most recent fit are retained and can be
 * read back with {@link #getNetworkSamples()}.
 */
public class SamplingModelFitter
implements ModelFitter {
    private final Function<KeanuProbabilisticModel, PosteriorSamplingAlgorithm> samplingAlgorithmGenerator;
    private final int sampleCount;
    // Samples produced by the most recent call to fit(); stays null until fit() has been called.
    private NetworkSamples posteriorSamples;

    /**
     * Creates a fitter that samples with an algorithm built by {@code samplingAlgorithmGenerator}.
     *
     * @param samplingAlgorithmGenerator builds the sampling algorithm for a given probabilistic
     *     model; must not be null
     * @param sampleCount number of posterior samples requested from the algorithm on each fit
     * @throws NullPointerException if {@code samplingAlgorithmGenerator} is null
     */
    public SamplingModelFitter(Function<KeanuProbabilisticModel, PosteriorSamplingAlgorithm> samplingAlgorithmGenerator, int sampleCount) {
        // Fail fast here rather than with an unexplained NPE later inside fit().
        this.samplingAlgorithmGenerator = Objects.requireNonNull(samplingAlgorithmGenerator, "samplingAlgorithmGenerator");
        this.sampleCount = sampleCount;
    }

    /**
     * Samples the posterior of {@code modelGraph}'s Bayesian network and sets the network's state
     * to the most probable state found among the samples.
     *
     * @param modelGraph the model whose Bayesian network is sampled and then updated; must not be null
     * @throws NullPointerException if {@code modelGraph} is null, or if the generator supplied at
     *     construction returns null for the model
     */
    @Override
    public void fit(ModelGraph modelGraph) {
        Objects.requireNonNull(modelGraph, "modelGraph");
        KeanuProbabilisticModel probabilisticModel = new KeanuProbabilisticModel(modelGraph.getBayesianNetwork());
        PosteriorSamplingAlgorithm algorithm = Objects.requireNonNull(
            this.samplingAlgorithmGenerator.apply(probabilisticModel),
            "samplingAlgorithmGenerator returned null");
        this.posteriorSamples = algorithm.getPosteriorSamples(probabilisticModel, this.sampleCount);
        NetworkState mostProbableState = this.posteriorSamples.getMostProbableState();
        modelGraph.getBayesianNetwork().setState(mostProbableState);
    }

    /**
     * Returns the posterior samples drawn by the most recent call to {@link #fit(ModelGraph)}.
     *
     * @return the latest samples, or {@code null} if {@code fit} has not been called yet
     */
    public NetworkSamples getNetworkSamples() {
        return this.posteriorSamples;
    }
}

