package weka.classifiers.meta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Vector;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.IterativeClassifier;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.evaluation.Evaluation;
import weka.classifiers.evaluation.EvaluationMetricHelper;
import weka.core.AdditionalMeasureProducer;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.json.JSONInstances;

/* loaded from: input_file:weka/classifiers/meta/IterativeClassifierOptimizer.class */
public class IterativeClassifierOptimizer extends RandomizableClassifier implements AdditionalMeasureProducer {
    private static final long serialVersionUID = -3665485256313525864L;

    /** Metric used when none is configured, and the fallback when the configured one is unknown. */
    private static final String DEFAULT_EVAL_METRIC = "rmse";

    /**
     * One tag per name returned by {@link EvaluationMetricHelper#getAllMetricNames()}; each tag's ID
     * is its position in that list (see the static initializer).
     */
    public static Tag[] TAGS_EVAL;

    /** Iteration count that produced the best metric value during the last build. */
    protected int m_bestNumIts;
    /** The base classifier whose number of iterations is optimised. */
    protected IterativeClassifier m_IterativeClassifier = new LogitBoost();
    protected int m_NumFolds = 10;
    protected int m_NumRuns = 1;
    protected int m_StepSize = 1;
    protected boolean m_UseAverage = false;
    protected int m_lookAheadIterations = 50;
    protected String m_evalMetric = DEFAULT_EVAL_METRIC;
    protected int m_classValueIndex = -1;
    /** Per-class thresholds reported by the metric at the best iteration, or null if it reports none. */
    protected double[] m_thresholds = null;
    /** Best metric value found during the last build. */
    protected double m_bestResult = Double.MAX_VALUE;
    protected int m_numThreads = 1;
    protected int m_poolSize = 1;

    /**
     * Returns a short description of this classifier.
     *
     * @return the description shown in the GUI
     */
    public String globalInfo() {
        return "Optimizes the number of iterations of the given iterative classifier using cross-validation.";
    }

    /**
     * Returns the class name of the default base classifier.
     *
     * @return the fully-qualified class name
     */
    protected String defaultIterativeClassifierString() {
        return "weka.classifiers.meta.LogitBoost";
    }

    public String useAverageTipText() {
        return "If true, average estimates are used instead of one estimate from pooled predictions.";
    }

    public boolean getUseAverage() {
        return m_UseAverage;
    }

    public void setUseAverage(boolean z) {
        m_UseAverage = z;
    }

    public String numThreadsTipText() {
        return "The number of threads to use, which should be >= size of thread pool.";
    }

    public int getNumThreads() {
        return m_numThreads;
    }

    public void setNumThreads(int i) {
        m_numThreads = i;
    }

    public String poolSizeTipText() {
        return "The size of the thread pool, for example, the number of cores in the CPU.";
    }

    public int getPoolSize() {
        return m_poolSize;
    }

    public void setPoolSize(int i) {
        m_poolSize = i;
    }

    public String stepSizeTipText() {
        return "Step size for the evaluation, if evaluation is time consuming.";
    }

    public int getStepSize() {
        return m_StepSize;
    }

    public void setStepSize(int i) {
        m_StepSize = i;
    }

    public String numRunsTipText() {
        return "Number of runs for cross-validation.";
    }

    public int getNumRuns() {
        return m_NumRuns;
    }

    public void setNumRuns(int i) {
        m_NumRuns = i;
    }

    public String numFoldsTipText() {
        return "Number of folds for cross-validation.";
    }

    public int getNumFolds() {
        return m_NumFolds;
    }

    public void setNumFolds(int i) {
        m_NumFolds = i;
    }

    public String lookAheadIterationsTipText() {
        return "The number of iterations to look ahead for to find a better optimum.";
    }

    public int getLookAheadIterations() {
        return m_lookAheadIterations;
    }

    public void setLookAheadIterations(int i) {
        m_lookAheadIterations = i;
    }

    /**
     * Builds the model. Copies of the base classifier are trained on cross-validation folds and are
     * all advanced one iteration at a time. The chosen metric is evaluated every {@code stepSize}
     * iterations. The search stops after {@code lookAheadIterations} iterations without improvement,
     * or when a fold model cannot iterate further. Finally, the base classifier is trained on the full
     * data for the best iteration count found.
     *
     * @param instances the training data
     * @throws Exception if no base classifier is set, the data is not supported, or training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        if (m_IterativeClassifier == null) {
            throw new Exception("A base classifier has not been specified!");
        }
        getCapabilities().testWithFail(instances);

        Random random = new Random(m_Seed);
        Instances data = new Instances(instances);
        data.deleteWithMissingClass();

        // Use a local fold count so that a small training set does not permanently overwrite the
        // user-configured m_NumFolds (which getOptions() reports and later builds reuse).
        int numFolds = m_NumFolds;
        if (data.numInstances() < numFolds) {
            System.err.println("WARNING: reducing number of folds to number of instances in IterativeClassifierOptimizer");
            numFolds = data.numInstances();
        }

        Instances[][] trainFolds = new Instances[m_NumRuns][numFolds];
        Instances[][] testFolds = new Instances[m_NumRuns][numFolds];
        IterativeClassifier[][] models = new IterativeClassifier[m_NumRuns][numFolds];
        createFoldModels(data, random, numFolds, trainFolds, testFolds, models);

        EvaluationMetricHelper metricHelper = new EvaluationMetricHelper(new Evaluation(data));
        boolean maximise = metricHelper.metricIsMaximisable(m_evalMetric);
        // Start a maximised metric at -Double.MAX_VALUE. Double.MIN_VALUE is the smallest *positive*
        // double, which would prevent zero or negative metric values from ever counting as improvements.
        m_bestResult = maximise ? -Double.MAX_VALUE : Double.MAX_VALUE;
        m_thresholds = null;
        m_bestNumIts = 0;

        ExecutorService pool = Executors.newFixedThreadPool(Math.max(1, m_poolSize));
        try {
            searchBestNumIterations(pool, metricHelper, maximise, data, numFolds, trainFolds, testFolds, models);
        } finally {
            // Always release the worker threads, even if evaluation or training throws.
            pool.shutdown();
        }

        m_IterativeClassifier.initializeClassifier(instances);
        int performed = 0;
        while (performed < m_bestNumIts && m_IterativeClassifier.next()) {
            performed++;
        }
        m_IterativeClassifier.done();
    }

    /**
     * Splits {@code data} into cross-validation folds for each run and initializes one copy of the
     * base classifier on each training fold.
     */
    private void createFoldModels(Instances data, Random random, int numFolds, Instances[][] trainFolds,
            Instances[][] testFolds, IterativeClassifier[][] models) throws Exception {
        for (int run = 0; run < models.length; run++) {
            data.randomize(random);
            if (data.classAttribute().isNominal()) {
                data.stratify(numFolds);
            }
            for (int fold = 0; fold < numFolds; fold++) {
                trainFolds[run][fold] = data.trainCV(numFolds, fold, random);
                testFolds[run][fold] = data.testCV(numFolds, fold);
                models[run][fold] = (IterativeClassifier) AbstractClassifier.makeCopy(m_IterativeClassifier);
                models[run][fold].initializeClassifier(trainFolds[run][fold]);
            }
        }
    }

    /**
     * Advances the fold models iteration by iteration. Updates m_bestResult, m_bestNumIts and
     * m_thresholds whenever the metric improves.
     */
    private void searchBestNumIterations(ExecutorService pool, EvaluationMetricHelper metricHelper,
            boolean maximise, Instances data, int numFolds, Instances[][] trainFolds,
            Instances[][] testFolds, IterativeClassifier[][] models) throws Exception {
        // Guard against settings that would divide by zero or never advance the models.
        int stepSize = Math.max(1, m_StepSize);
        int numThreads = Math.max(1, m_numThreads);

        int iteration = 0;
        int sinceImprovement = -1;
        while (true) {
            if (iteration % stepSize == 0) {
                MetricEstimate estimate = m_UseAverage
                        ? estimateAveraged(metricHelper, data.numClasses(), numFolds, trainFolds, testFolds, models)
                        : estimatePooled(metricHelper, data, testFolds, models);
                if (m_Debug) {
                    System.err.println("Iteration: " + iteration + " Measure: " + estimate.value);
                    if (estimate.thresholds != null) {
                        System.err.print("Thresholds:");
                        for (double threshold : estimate.thresholds) {
                            System.err.print(" " + threshold);
                        }
                        System.err.println();
                    }
                }
                double shortfall = maximise ? m_bestResult - estimate.value : estimate.value - m_bestResult;
                if (shortfall < 0.0d) {
                    m_bestResult = estimate.value;
                    m_bestNumIts = iteration;
                    m_thresholds = estimate.thresholds;
                    sinceImprovement = -1;
                }
            }
            sinceImprovement++;
            iteration++;
            if (sinceImprovement >= m_lookAheadIterations) {
                return;
            }
            if (!advanceAll(pool, models, numFolds, numThreads)) {
                return;
            }
        }
    }

    /**
     * Evaluates each fold model separately and averages the metric (and any thresholds) over all
     * runs and folds.
     */
    private MetricEstimate estimateAveraged(EvaluationMetricHelper metricHelper, int numClasses, int numFolds,
            Instances[][] trainFolds, Instances[][] testFolds, IterativeClassifier[][] models) throws Exception {
        double sum = 0.0d;
        double[] thresholds = null;
        for (int run = 0; run < models.length; run++) {
            for (int fold = 0; fold < numFolds; fold++) {
                Evaluation evaluation = new Evaluation(trainFolds[run][fold]);
                metricHelper.setEvaluation(evaluation);
                evaluation.evaluateModel(models[run][fold], testFolds[run][fold]);
                sum += namedMetric(metricHelper);
                double[] foldThresholds = metricHelper.getNamedMetricThresholds(m_evalMetric);
                if (foldThresholds != null) {
                    if (thresholds == null) {
                        thresholds = new double[numClasses];
                    }
                    for (int c = 0; c < foldThresholds.length; c++) {
                        thresholds[c] += foldThresholds[c];
                    }
                }
            }
        }
        int count = models.length * numFolds;
        if (thresholds != null) {
            for (int c = 0; c < thresholds.length; c++) {
                thresholds[c] /= count;
            }
        }
        return new MetricEstimate(sum / count, thresholds);
    }

    /** Pools the predictions of all fold models into one evaluation and computes the metric from it. */
    private MetricEstimate estimatePooled(EvaluationMetricHelper metricHelper, Instances data,
            Instances[][] testFolds, IterativeClassifier[][] models) throws Exception {
        Evaluation evaluation = new Evaluation(data);
        metricHelper.setEvaluation(evaluation);
        for (int run = 0; run < models.length; run++) {
            for (int fold = 0; fold < models[run].length; fold++) {
                evaluation.evaluateModel(models[run][fold], testFolds[run][fold]);
            }
        }
        return new MetricEstimate(namedMetric(metricHelper), metricHelper.getNamedMetricThresholds(m_evalMetric));
    }

    /**
     * Returns the configured metric. If a non-negative class value index is set, the metric is
     * computed for that class value; otherwise the helper's default form is used.
     */
    private double namedMetric(EvaluationMetricHelper metricHelper) throws Exception {
        return getClassValueIndex() >= 0
                ? metricHelper.getNamedMetric(m_evalMetric, getClassValueIndex())
                : metricHelper.getNamedMetric(m_evalMetric);
    }

    /**
     * Calls {@code next()} once on every fold model, splitting the models into {@code numThreads}
     * contiguous chunks submitted to {@code pool}.
     *
     * @return true if every model iterated successfully; false if any model could not iterate or a
     *     task failed, in which case the search should stop
     */
    private boolean advanceAll(ExecutorService pool, final IterativeClassifier[][] models, final int numFolds,
            int numThreads) {
        int total = models.length * numFolds;
        int chunk = total / numThreads;
        List<Future<Boolean>> results = new ArrayList<Future<Boolean>>(numThreads);
        for (int t = 0; t < numThreads; t++) {
            final int start = t * chunk;
            // The last chunk also takes the remainder of the integer division.
            final int end = t < numThreads - 1 ? start + chunk : total;
            results.add(pool.submit(new Callable<Boolean>() {
                @Override
                public Boolean call() throws Exception {
                    for (int k = start; k < end; k++) {
                        if (!models[k / numFolds][k % numFolds].next()) {
                            if (IterativeClassifierOptimizer.this.m_Debug) {
                                System.err.println("Classifier failed to iterate in cross-validation.");
                            }
                            return false;
                        }
                    }
                    return true;
                }
            }));
        }
        try {
            for (Future<Boolean> result : results) {
                if (!result.get().booleanValue()) {
                    return false;
                }
            }
            return true;
        } catch (InterruptedException e) {
            // Restore the interrupt flag and stop: fold models may now be at different iterations.
            Thread.currentThread().interrupt();
            System.err.println("Classifiers could not be generated.");
            e.printStackTrace();
            return false;
        } catch (ExecutionException e) {
            // Stop rather than keep evaluating fold models that are now out of step with each other.
            System.err.println("Classifiers could not be generated.");
            e.printStackTrace();
            return false;
        }
    }

    /** Metric value for one iteration, plus per-class thresholds when the metric provides them. */
    private static final class MetricEstimate {
        final double value;
        final double[] thresholds;

        MetricEstimate(double value, double[] thresholds) {
            this.value = value;
            this.thresholds = thresholds;
        }
    }

    /**
     * Returns the base classifier's class distribution. If thresholds were found, the distribution is
     * replaced by an even split over the classes whose probability reaches their threshold.
     *
     * @param instance the instance to classify
     * @return the class distribution
     * @throws Exception if the base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        double[] distribution = m_IterativeClassifier.distributionForInstance(instance);
        if (m_thresholds == null) {
            return distribution;
        }
        double[] thresholded = new double[distribution.length];
        boolean anyAboveThreshold = false;
        for (int i = 0; i < distribution.length; i++) {
            if (distribution[i] >= m_thresholds[i]) {
                thresholded[i] = 1.0d;
                anyAboveThreshold = true;
            }
        }
        if (!anyAboveThreshold) {
            // No class reaches its threshold: normalizing an all-zero array would throw, so fall back
            // to the base classifier's distribution.
            return distribution;
        }
        Utils.normalize(thresholded);
        return thresholded;
    }

    /**
     * Describes the best result, the best iteration count, any thresholds, and the base classifier.
     *
     * @return the textual description
     */
    @Override
    public String toString() {
        if (m_IterativeClassifier == null) {
            return "No classifier built yet.";
        }
        StringBuilder text = new StringBuilder();
        text.append("Best value found: " + m_bestResult + "\n");
        text.append("Best number of iterations found: " + m_bestNumIts + "\n\n");
        if (m_thresholds != null) {
            text.append("Thresholds found: ");
            for (double threshold : m_thresholds) {
                text.append(threshold + " ");
            }
        }
        text.append("\n\n");
        text.append(m_IterativeClassifier.toString());
        return text.toString();
    }

    /**
     * Lists the options of this classifier, followed by those of the superclass and of the base
     * classifier.
     *
     * @return an enumeration of the options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> options = new Vector<Option>(11);
        options.addElement(new Option("\tIf set, average estimate is used rather than one estimate from pooled predictions.\n", "A", 0, "-A"));
        options.addElement(new Option("\t" + lookAheadIterationsTipText() + "\n\t(default 50)", "L", 1, "-L <num>"));
        options.addElement(new Option("\t" + poolSizeTipText() + "\n\t(default 1)", "P", 1, "-P <int>"));
        options.addElement(new Option("\t" + numThreadsTipText() + "\n\t(default 1)", "E", 1, "-E <int>"));
        options.addElement(new Option("\t" + stepSizeTipText() + "\n\t(default 1)", "I", 1, "-I <num>"));
        options.addElement(new Option("\tNumber of folds for cross-validation.\n\t(default 10)", "F", 1, "-F <num>"));
        options.addElement(new Option("\tNumber of runs for cross-validation.\n\t(default 1)", "R", 1, "-R <num>"));
        options.addElement(new Option("\tFull name of base classifier.\n\t(default: " + defaultIterativeClassifierString() + ")", "W", 1, "-W"));

        // Comma-separated metric names. A line break is inserted after the separator once the current
        // line holds at least 60 characters of names. Separators are written *before* each name, so the
        // list never ends in a dangling comma or newline and an empty list is handled.
        StringBuilder metrics = new StringBuilder();
        int lineLength = 0;
        for (String name : EvaluationMetricHelper.getAllMetricNames()) {
            if (metrics.length() > 0) {
                metrics.append(",");
                if (lineLength >= 60) {
                    metrics.append("\n\t");
                    lineLength = 0;
                }
            }
            metrics.append(name.toLowerCase());
            lineLength += name.length();
        }
        options.addElement(new Option("\tEvaluation metric to optimise (default rmse). Available metrics:\n\t" + metrics, "metric", 1, "-metric <name>"));
        options.addElement(new Option("\tClass value index to optimise. Ignored for all but information-retrieval\n\ttype metrics (such as roc area). If unspecified (or a negative value is supplied),\n\tand an information-retrieval metric is specified, then the class-weighted average\n\tmetric used. (default -1)", "class-value-index", 1, "-class-value-index <0-based index>"));
        options.addAll(Collections.list(super.listOptions()));
        options.addElement(new Option("", "", 0, "\nOptions specific to classifier " + m_IterativeClassifier.getClass().getName() + JSONInstances.SPARSE_SEPARATOR));
        options.addAll(Collections.list(((OptionHandler) m_IterativeClassifier).listOptions()));
        return options.elements();
    }

    /**
     * Parses the given options. Any option that is absent is reset to its default value.
     *
     * @param strArr the options
     * @throws Exception if an option value is invalid or the base classifier cannot be created
     */
    @Override
    public void setOptions(String[] strArr) throws Exception {
        super.setOptions(strArr);
        setUseAverage(Utils.getFlag('A', strArr));
        setLookAheadIterations(intOption('L', strArr, 50));
        setPoolSize(intOption('P', strArr, 1));
        setNumThreads(intOption('E', strArr, 1));
        setStepSize(intOption('I', strArr, 1));
        setNumFolds(intOption('F', strArr, 10));
        setNumRuns(intOption('R', strArr, 1));

        String metric = Utils.getOption("metric", strArr);
        if (metric.length() > 0) {
            boolean found = false;
            for (int i = 0; i < TAGS_EVAL.length; i++) {
                if (TAGS_EVAL[i].getIDStr().equalsIgnoreCase(metric)) {
                    setEvaluationMetric(new SelectedTag(i, TAGS_EVAL));
                    found = true;
                    break;
                }
            }
            if (!found) {
                throw new Exception("Unknown evaluation metric: " + metric);
            }
        } else {
            // Reset to the default, consistent with the other options.
            m_evalMetric = DEFAULT_EVAL_METRIC;
        }

        String classValueIndex = Utils.getOption("class-value-index", strArr);
        setClassValueIndex(classValueIndex.length() > 0 ? Integer.parseInt(classValueIndex) : -1);

        String classifierName = Utils.getOption('W', strArr);
        if (classifierName.length() == 0) {
            classifierName = defaultIterativeClassifierString();
        }
        setIterativeClassifier(getIterativeClassifier(classifierName, Utils.partitionOptions(strArr)));
    }

    /** Reads an integer-valued single-character option, returning {@code defaultValue} when absent. */
    private static int intOption(char flag, String[] options, int defaultValue) throws Exception {
        String value = Utils.getOption(flag, options);
        return value.length() != 0 ? Integer.parseInt(value) : defaultValue;
    }

    /**
     * Instantiates the named classifier with the given options.
     *
     * @param str fully-qualified class name
     * @param strArr options for the classifier
     * @return the classifier
     * @throws IllegalArgumentException if the class is not an IterativeClassifier
     * @throws Exception if the classifier cannot be created
     */
    protected IterativeClassifier getIterativeClassifier(String str, String[] strArr) throws Exception {
        Classifier classifier = AbstractClassifier.forName(str, strArr);
        if (classifier instanceof IterativeClassifier) {
            return (IterativeClassifier) classifier;
        }
        throw new IllegalArgumentException(str + " is not an IterativeClassifier.");
    }

    /**
     * Returns the current settings. The base classifier's own options follow a "--" separator.
     *
     * @return the option strings
     */
    @Override
    public String[] getOptions() {
        Vector<String> options = new Vector<String>();
        if (getUseAverage()) {
            options.add("-A");
        }
        options.add("-W");
        options.add(getIterativeClassifier().getClass().getName());
        options.add("-L");
        options.add("" + getLookAheadIterations());
        options.add("-P");
        options.add("" + getPoolSize());
        options.add("-E");
        options.add("" + getNumThreads());
        options.add("-I");
        options.add("" + getStepSize());
        options.add("-F");
        options.add("" + getNumFolds());
        options.add("-R");
        options.add("" + getNumRuns());
        options.add("-metric");
        options.add(getEvaluationMetric().getSelectedTag().getIDStr());
        if (getClassValueIndex() >= 0) {
            options.add("-class-value-index");
            options.add("" + getClassValueIndex());
        }
        Collections.addAll(options, super.getOptions());
        String[] baseOptions = ((OptionHandler) m_IterativeClassifier).getOptions();
        if (baseOptions.length > 0) {
            options.add("--");
            Collections.addAll(options, baseOptions);
        }
        return options.toArray(new String[0]);
    }

    public String evaluationMetricTipText() {
        return "The evaluation metric to use";
    }

    /**
     * Sets the evaluation metric. Tags that do not come from {@link #TAGS_EVAL} are ignored.
     *
     * @param selectedTag the selected metric
     */
    public void setEvaluationMetric(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_EVAL) {
            m_evalMetric = selectedTag.getSelectedTag().getIDStr();
        }
    }

    /**
     * Returns the configured evaluation metric. If the configured name is not a known metric, the
     * default ("rmse") is returned; it is looked up by name rather than by a hard-coded index.
     *
     * @return the selected metric
     */
    public SelectedTag getEvaluationMetric() {
        int defaultIndex = 0;
        for (int i = 0; i < TAGS_EVAL.length; i++) {
            String id = TAGS_EVAL[i].getIDStr();
            if (id.equalsIgnoreCase(m_evalMetric)) {
                return new SelectedTag(i, TAGS_EVAL);
            }
            if (id.equalsIgnoreCase(DEFAULT_EVAL_METRIC)) {
                defaultIndex = i;
            }
        }
        return new SelectedTag(defaultIndex, TAGS_EVAL);
    }

    public String classValueIndexTipText() {
        return "The class value index to use with information retrieval type metrics. A value < 0 indicates to use the class weighted average version of the metric.";
    }

    public void setClassValueIndex(int i) {
        m_classValueIndex = i;
    }

    public int getClassValueIndex() {
        return m_classValueIndex;
    }

    public String iterativeClassifierTipText() {
        return "The iterative classifier to be optimized.";
    }

    /**
     * Returns the base classifier's capabilities (or none if there is no base classifier), with
     * every capability also enabled as a dependency and this object set as the owner.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities capabilities;
        if (getIterativeClassifier() != null) {
            capabilities = getIterativeClassifier().getCapabilities();
        } else {
            capabilities = new Capabilities(this);
            capabilities.disableAll();
        }
        for (Capabilities.Capability capability : Capabilities.Capability.values()) {
            capabilities.enableDependency(capability);
        }
        capabilities.setOwner(this);
        return capabilities;
    }

    public void setIterativeClassifier(IterativeClassifier iterativeClassifier) {
        m_IterativeClassifier = iterativeClassifier;
    }

    public IterativeClassifier getIterativeClassifier() {
        return m_IterativeClassifier;
    }

    /**
     * Returns the base classifier's class name followed by its joined options.
     *
     * @return the classifier specification string
     */
    protected String getIterativeClassifierSpec() {
        IterativeClassifier classifier = getIterativeClassifier();
        return classifier.getClass().getName() + " " + Utils.joinOptions(((OptionHandler) classifier).getOptions());
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 10649 $");
    }

    /** @return the best iteration count found by the last build */
    public double measureBestNumIts() {
        return m_bestNumIts;
    }

    /** @return the best metric value found by the last build */
    public double measureBestVal() {
        return m_bestResult;
    }

    @Override
    public Enumeration<String> enumerateMeasures() {
        Vector<String> measures = new Vector<String>(2);
        measures.addElement("measureBestNumIts");
        measures.addElement("measureBestVal");
        return measures.elements();
    }

    /**
     * Returns the named additional measure (matched case-insensitively).
     *
     * @param str the measure name
     * @return the measure's value
     * @throws IllegalArgumentException if the measure is not supported
     */
    @Override
    public double getMeasure(String str) {
        if (str.compareToIgnoreCase("measureBestNumIts") == 0) {
            return measureBestNumIts();
        }
        if (str.compareToIgnoreCase("measureBestVal") == 0) {
            return measureBestVal();
        }
        throw new IllegalArgumentException(str + " not supported (IterativeClassifierOptimizer)");
    }

    public static void main(String[] strArr) {
        runClassifier(new IterativeClassifierOptimizer(), strArr);
    }

    static {
        List<String> names = EvaluationMetricHelper.getAllMetricNames();
        TAGS_EVAL = new Tag[names.size()];
        for (int i = 0; i < names.size(); i++) {
            TAGS_EVAL[i] = new Tag(i, names.get(i), names.get(i), false);
        }
    }
}
