/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.experiment.dataset.split;

import java.util.Iterator;
import java.util.Map;
import java.util.NoSuchElementException;
import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.dataset.split.TestSplitProvider;
import org.openimaj.experiment.dataset.split.TrainSplitProvider;
import org.openimaj.experiment.dataset.split.ValidateSplitProvider;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.experiment.validation.cross.CrossValidationIterable;

/**
 * Randomly splits every group of a {@link GroupedDataset} into training,
 * validation and testing subsets.
 * <p>
 * For each group, a set of unique random indices is drawn with
 * {@link RandomData#getUniqueRandomInts}. The first {@code numTraining} of
 * them select the training instances and the next {@code numValidation}
 * select the validation instances. The remaining drawn indices (at most
 * {@code numTesting}, fewer if the group is too small) select the testing
 * instances.
 * <p>
 * The splits are computed in the constructor and redrawn on every call to
 * {@link #recomputeSubsets()}. This class performs no synchronization.
 *
 * @param <KEY> type of the group keys
 * @param <INSTANCE> type of the instances within each group
 */
public class GroupedRandomSplitter<KEY, INSTANCE>
        implements TrainSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>,
        TestSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>,
        ValidateSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {
    private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> trainingSplit;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validationSplit;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> testingSplit;
    private final int numTraining;
    private final int numValidation;
    private final int numTesting;

    /**
     * Creates the splitter and immediately computes an initial random split.
     *
     * @param dataset the grouped dataset to split
     * @param numTraining number of training instances to take from each group
     * @param numValidation number of validation instances to take from each group
     * @param numTesting maximum number of testing instances to take from each
     *        group (truncated if the group has fewer instances left over)
     * @throws IllegalArgumentException if any of the counts is negative
     * @throws RuntimeException if a group is too small; see {@link #recomputeSubsets()}
     */
    public GroupedRandomSplitter(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset, int numTraining, int numValidation, int numTesting) {
        if (numTraining < 0 || numValidation < 0 || numTesting < 0) {
            throw new IllegalArgumentException("Instance counts must be non-negative: numTraining=" + numTraining
                    + ", numValidation=" + numValidation + ", numTesting=" + numTesting);
        }
        this.dataset = dataset;
        this.numTraining = numTraining;
        this.numValidation = numValidation;
        this.numTesting = numTesting;
        this.recomputeSubsets();
    }

    /**
     * Draws new random training, validation and testing subsets for every
     * group of the underlying dataset.
     * <p>
     * New dataset objects are created on every call, so datasets returned
     * earlier by the getters are left untouched. If an exception is thrown,
     * the previously computed splits are kept.
     *
     * @throws RuntimeException if a group has fewer than {@code numTraining}
     *         instances (plus one spare when validation or testing instances
     *         are requested), or fewer than {@code numTraining + numValidation}
     *         instances (plus one spare when testing instances are requested)
     */
    public void recomputeSubsets() {
        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> training =
                new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>();
        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validation =
                new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>();
        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> testing =
                new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>();

        // long arithmetic so that very large requested counts cannot overflow
        final long validationEnd = (long) this.numTraining + this.numValidation;
        final long requested = validationEnd + this.numTesting;

        // A spare instance is only required for a split that was actually requested.
        final int spareAfterTraining = (this.numValidation > 0 || this.numTesting > 0) ? 1 : 0;
        final int spareAfterValidation = this.numTesting > 0 ? 1 : 0;

        for (final Map.Entry<KEY, ? extends ListDataset<INSTANCE>> e : this.dataset.entrySet()) {
            final KEY key = e.getKey();
            final ListDataset<INSTANCE> allData = e.getValue();
            final int size = allData.size();

            if (size < (long) this.numTraining + spareAfterTraining) {
                throw new RuntimeException("Too many training examples; none would be available for validation or testing.");
            }
            if (size < validationEnd + spareAfterValidation) {
                throw new RuntimeException("Too many training and validation instances; none would be available for testing.");
            }

            // size >= validationEnd here, so ids covers every training and
            // validation position and the casts below cannot overflow; the
            // testing split receives whatever drawn indices remain.
            final int[] ids = RandomData.getUniqueRandomInts((int) Math.min(requested, size), 0, size);

            training.put(key, select(allData, ids, 0, this.numTraining));
            validation.put(key, select(allData, ids, this.numTraining, (int) validationEnd));
            testing.put(key, select(allData, ids, (int) validationEnd, ids.length));
        }

        // Publish only once every group has been split successfully.
        this.trainingSplit = training;
        this.validationSplit = validation;
        this.testingSplit = testing;
    }

    /** Copies the instances of {@code all} at positions {@code ids[from]} .. {@code ids[to - 1]} into a new list dataset. */
    private static <T> ListDataset<T> select(final ListDataset<T> all, final int[] ids, final int from, final int to) {
        final ListBackedDataset<T> subset = new ListBackedDataset<T>();
        for (int i = from; i < to; ++i) {
            subset.add(all.get(ids[i]));
        }
        return subset;
    }

    /**
     * @return the testing split produced by the most recent call to
     *         {@link #recomputeSubsets()}
     */
    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTestDataset() {
        return this.testingSplit;
    }

    /**
     * @return the training split produced by the most recent call to
     *         {@link #recomputeSubsets()}
     */
    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTrainingDataset() {
        return this.trainingSplit;
    }

    /**
     * @return the validation split produced by the most recent call to
     *         {@link #recomputeSubsets()}
     */
    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getValidationDataset() {
        return this.validationSplit;
    }

    /**
     * Creates repeated random-split cross-validation data. Each of the
     * {@code numIterations} rounds draws a fresh random split of every group
     * into {@code numTraining} training and {@code numValidation} validation
     * instances. No testing instances are drawn.
     * <p>
     * Each {@link ValidationData} returned by the iterator keeps its own
     * round's datasets, even after later rounds have been drawn.
     *
     * @param dataset the grouped dataset to split
     * @param numTraining number of training instances per group and round
     * @param numValidation number of validation instances per group and round
     * @param numIterations number of rounds produced by each iterator
     * @return the cross-validation iterable
     * @throws IllegalArgumentException if {@code numTraining} or {@code numValidation} is negative
     * @throws RuntimeException if a group is too small (checked eagerly, because
     *         the underlying splitter computes an initial split on creation)
     */
    public static <KEY, INSTANCE> CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createCrossValidationData(final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset, final int numTraining, final int numValidation, final int numIterations) {
        return new CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>() {
            // Built eagerly so that invalid sizes are reported at creation time.
            private final GroupedRandomSplitter<KEY, INSTANCE> splits =
                    new GroupedRandomSplitter<KEY, INSTANCE>(dataset, numTraining, numValidation, 0);

            @Override
            public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
                return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                    private int current = 0;

                    @Override
                    public boolean hasNext() {
                        return this.current < numIterations;
                    }

                    @Override
                    public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                        if (!hasNext()) {
                            throw new NoSuchElementException();
                        }
                        splits.recomputeSubsets();
                        ++this.current;

                        // Capture this round's splits now: the shared splitter
                        // replaces them on the next call, and previously returned
                        // ValidationData objects must keep returning their own round.
                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> training = splits.getTrainingDataset();
                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validation = splits.getValidationDataset();
                        return new ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>() {

                            @Override
                            public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTrainingDataset() {
                                return training;
                            }

                            @Override
                            public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getValidationDataset() {
                                return validation;
                            }
                        };
                    }

                    @Override
                    public void remove() {
                        throw new UnsupportedOperationException("Removal not supported");
                    }
                };
            }

            @Override
            public int numberIterations() {
                return numIterations;
            }
        };
    }
}

