/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification.utils;

import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.IndexableFieldType;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.GroupingSearch;
import org.apache.lucene.search.grouping.TopGroups;
import org.apache.lucene.store.Directory;

/**
 * Splits the documents of an existing index into three new indexes (test,
 * cross-validation and training), working class by class so that each class
 * value of the chosen field is distributed across the three outputs.
 */
public class DatasetSplitter {
    private final double crossValidationRatio;
    private final double testRatio;

    /**
     * Creates a splitter.
     *
     * @param testRatio            multiplied by the number of documents of each class to
     *                             obtain the cap on documents sent to the test index
     * @param crossValidationRatio multiplied by the number of documents of each class to
     *                             obtain the cap on documents sent to the cross-validation index
     */
    public DatasetSplitter(double testRatio, double crossValidationRatio) {
        this.crossValidationRatio = crossValidationRatio;
        this.testRatio = testRatio;
    }

    /**
     * Groups every document of {@code originalIndex} by {@code classFieldName} and copies
     * each one into exactly one of the three target directories.
     *
     * <p>Within each class group a document goes to the test index while the running
     * document counter is even and the test cap has not been reached; otherwise to the
     * cross-validation index while its cap has not been reached; otherwise to the training
     * index. The counter runs across all groups, it is not reset per class.
     *
     * <p>All three writers and {@code originalIndex} are always closed before this method
     * returns, whether it completes normally or fails at any stage.
     *
     * @param originalIndex        reader over the index to split; closed by this method
     * @param trainingIndex        directory receiving the training documents
     * @param testIndex            directory receiving the test documents
     * @param crossValidationIndex directory receiving the cross-validation documents
     * @param analyzer             analyzer used by all three index writers
     * @param termVectors          whether copied fields store term vectors (with offsets and positions)
     * @param classFieldName       field whose values define the classes used for grouping
     * @param fieldNames           fields to copy; when {@code null} or empty every stored field is copied
     * @throws IOException if opening, writing or closing an index fails; any exception raised
     *                     while grouping, copying, committing or merging is wrapped in an
     *                     {@code IOException}
     */
    public void split(IndexReader originalIndex, Directory trainingIndex, Directory testIndex, Directory crossValidationIndex, Analyzer analyzer, boolean termVectors, String classFieldName, String ... fieldNames) throws IOException {
        IndexWriter testWriter = null;
        IndexWriter cvWriter = null;
        IndexWriter trainingWriter = null;
        try {
            // Opened inside the try so that a failure on a later writer (or while counting
            // classes) still releases the writers that were already opened.
            testWriter = new IndexWriter(testIndex, new IndexWriterConfig(analyzer));
            cvWriter = new IndexWriter(crossValidationIndex, new IndexWriterConfig(analyzer));
            trainingWriter = new IndexWriter(trainingIndex, new IndexWriterConfig(analyzer));
            int noOfClasses = DatasetSplitter.countClasses(originalIndex, classFieldName);
            try {
                IndexSearcher indexSearcher = new IndexSearcher(originalIndex);
                GroupingSearch gs = new GroupingSearch(classFieldName);
                gs.setGroupSort(Sort.INDEXORDER);
                gs.setSortWithinGroup(Sort.INDEXORDER);
                gs.setAllGroups(true);
                // Every document of a group must be returned, so allow up to maxDoc per group.
                gs.setGroupDocsLimit(originalIndex.maxDoc());
                TopGroups<?> topGroups = gs.search(indexSearcher, new MatchAllDocsQuery(), 0, noOfClasses);
                FieldType ft = new FieldType(TextField.TYPE_STORED);
                if (termVectors) {
                    ft.setStoreTermVectors(true);
                    ft.setStoreTermVectorOffsets(true);
                    ft.setStoreTermVectorPositions(true);
                }
                // Running document counter shared by all groups; drives the even/odd alternation.
                int b = 0;
                for (GroupDocs<?> group : topGroups.groups) {
                    long totalHits = group.totalHits;
                    double testSize = (double)totalHits * this.testRatio;
                    int tc = 0;
                    double cvSize = (double)totalHits * this.crossValidationRatio;
                    int cvc = 0;
                    for (ScoreDoc scoreDoc : group.scoreDocs) {
                        Document doc = this.createNewDoc(originalIndex, ft, scoreDoc, fieldNames);
                        if (b % 2 == 0 && (double)tc < testSize) {
                            testWriter.addDocument(doc);
                            ++tc;
                        } else if ((double)cvc < cvSize) {
                            cvWriter.addDocument(doc);
                            ++cvc;
                        } else {
                            trainingWriter.addDocument(doc);
                        }
                        ++b;
                    }
                }
                testWriter.commit();
                cvWriter.commit();
                trainingWriter.commit();
                testWriter.forceMerge(3);
                cvWriter.forceMerge(3);
                trainingWriter.forceMerge(3);
            }
            catch (Exception e) {
                throw new IOException(e);
            }
        }
        finally {
            DatasetSplitter.closeAll(testWriter, cvWriter, trainingWriter, originalIndex);
        }
    }

    /**
     * Computes an upper bound on the number of distinct class values, used as the
     * number of groups requested from the grouping search.
     *
     * <p>Per segment, the doc-values value count is used when available; when the field
     * has no sorted doc values the number of indexed terms is added as an approximation.
     * The result may over-count (the same value can appear in several segments).
     *
     * @return the summed estimate, clamped to {@code Integer.MAX_VALUE}
     */
    private static int countClasses(IndexReader originalIndex, String classFieldName) throws IOException {
        long total = 0L;
        for (LeafReaderContext leave : originalIndex.leaves()) {
            SortedDocValues classValues = leave.reader().getSortedDocValues(classFieldName);
            if (classValues != null) {
                total += classValues.getValueCount();
                continue;
            }
            SortedSetDocValues sortedSetDocValues = leave.reader().getSortedSetDocValues(classFieldName);
            if (sortedSetDocValues != null) {
                total += sortedSetDocValues.getValueCount();
            }
            // Approximate with the number of indexed terms. The segment may have no postings
            // for the field (null), and the term count may be reported as unknown (negative);
            // neither must break or shrink the estimate. The lookup is repeated instead of
            // being stored in a local to avoid needing an extra import for its type.
            if (leave.reader().terms(classFieldName) != null) {
                long termCount = leave.reader().terms(classFieldName).size();
                if (termCount > 0L) {
                    total += termCount;
                }
            }
        }
        return (int)Math.min(total, (long)Integer.MAX_VALUE);
    }

    /**
     * Closes every non-null resource, continuing after a failure so that one failing
     * {@code close()} cannot prevent the others from being released. The first failure is
     * rethrown once all resources were attempted; later ones are attached as suppressed.
     */
    private static void closeAll(Closeable ... resources) throws IOException {
        Exception firstFailure = null;
        for (Closeable resource : resources) {
            if (resource == null) {
                continue;
            }
            try {
                resource.close();
            }
            catch (IOException | RuntimeException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure instanceof IOException) {
            throw (IOException)firstFailure;
        }
        if (firstFailure != null) {
            throw (RuntimeException)firstFailure;
        }
    }

    /**
     * Builds the document to be written to one of the target indexes from the stored
     * fields of the hit {@code scoreDoc}.
     *
     * <p>When {@code fieldNames} is non-empty, only those fields are copied (first value
     * returned by {@code getField}, as a string) and names absent from the stored document
     * are skipped. Otherwise every stored field is copied using the first non-null of its
     * reader, binary, string or numeric value (numeric values are stored as their string form).
     * All copied fields use the field type {@code ft}.
     */
    private Document createNewDoc(IndexReader originalIndex, FieldType ft, ScoreDoc scoreDoc, String[] fieldNames) throws IOException {
        Document doc = new Document();
        Document document = originalIndex.document(scoreDoc.doc);
        if (fieldNames != null && fieldNames.length > 0) {
            for (String fieldName : fieldNames) {
                IndexableField field = document.getField(fieldName);
                if (field == null) {
                    continue;
                }
                // NOTE(review): stringValue() is passed through unchecked; presumably the
                // requested fields are textual. Confirm behaviour for binary stored fields.
                doc.add(new Field(fieldName, field.stringValue(), ft));
            }
        } else {
            for (IndexableField field : document.getFields()) {
                if (field.readerValue() != null) {
                    doc.add(new Field(field.name(), field.readerValue(), ft));
                } else if (field.binaryValue() != null) {
                    doc.add(new Field(field.name(), field.binaryValue(), ft));
                } else if (field.stringValue() != null) {
                    doc.add(new Field(field.name(), field.stringValue(), ft));
                } else if (field.numericValue() != null) {
                    doc.add(new Field(field.name(), field.numericValue().toString(), ft));
                }
            }
        }
        return doc;
    }
}

