/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.classification.xgboost;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import org.tribuo.Example;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Prediction;
import org.tribuo.classification.Label;
import org.tribuo.common.xgboost.XGBoostOutputConverter;

/**
 * Converts raw XGBoost classification output into {@link Prediction}s over {@link Label}s.
 *
 * <p>Exactly one model output is expected. Within a row of scores, the value at index
 * {@code i} is attached to the label returned by {@code info.getOutput(i)}.
 */
public final class XGBoostClassificationConverter
implements XGBoostOutputConverter<Label> {
    private static final long serialVersionUID = 1L;

    /** Message used when the model output list does not contain exactly one entry. */
    private static final String SINGLE_OUTPUT_MESSAGE = "XGBoostClassificationConverter only expects a single model output.";

    /**
     * Reports whether this converter produces probabilities.
     *
     * @return always {@code true}.
     */
    public boolean generatesProbabilities() {
        return true;
    }

    /**
     * Converts the scores for a single example into a prediction.
     *
     * @param info the output domain used to map score indices to label names.
     * @param probabilitiesList the model outputs; must contain exactly one score array.
     * @param numValidFeatures the value forwarded to the prediction as its valid feature count.
     * @param example the example the scores were computed for.
     * @return a prediction holding one scored label per entry of the score array.
     * @throws IllegalArgumentException if {@code probabilitiesList} does not contain exactly one element.
     */
    public Prediction<Label> convertOutput(ImmutableOutputInfo<Label> info, List<float[]> probabilitiesList, int numValidFeatures, Example<Label> example) {
        if (probabilitiesList.size() != 1) {
            throw new IllegalArgumentException(SINGLE_OUTPUT_MESSAGE);
        }
        return toPrediction(info, probabilitiesList.get(0), numValidFeatures, example);
    }

    /**
     * Converts a batch of score rows into predictions, one per row and in row order.
     *
     * <p>Row {@code i} is paired with {@code numValidFeatures[i]} and {@code examples[i]}, so both
     * arrays must be at least as long as the score matrix has rows.
     *
     * @param info the output domain used to map score indices to label names.
     * @param probabilitiesList the model outputs; must contain exactly one score matrix.
     * @param numValidFeatures the per-row values forwarded to each prediction as its valid feature count.
     * @param examples the per-row examples the scores were computed for.
     * @return the predictions, one for each row of the score matrix.
     * @throws IllegalArgumentException if {@code probabilitiesList} does not contain exactly one element.
     */
    public List<Prediction<Label>> convertBatchOutput(ImmutableOutputInfo<Label> info, List<float[][]> probabilitiesList, int[] numValidFeatures, Example<Label>[] examples) {
        if (probabilitiesList.size() != 1) {
            throw new IllegalArgumentException(SINGLE_OUTPUT_MESSAGE);
        }
        float[][] probabilities = probabilitiesList.get(0);
        // The number of predictions is known up front, so size the list once.
        List<Prediction<Label>> predictions = new ArrayList<>(probabilities.length);
        for (int i = 0; i < probabilities.length; i++) {
            predictions.add(toPrediction(info, probabilities[i], numValidFeatures[i], examples[i]));
        }
        return predictions;
    }

    /**
     * Builds a prediction from one row of scores; shared by the single and batch conversions.
     *
     * <p>The top label is the first one with the strictly greatest score, so ties keep the
     * lowest index. If no score compares greater than negative infinity (an empty row, or
     * one containing only NaN values), the top label handed to the prediction is {@code null}.
     */
    private static Prediction<Label> toPrediction(ImmutableOutputInfo<Label> info, float[] scores, int numValidFeatures, Example<Label> example) {
        double maxScore = Double.NEGATIVE_INFINITY;
        Label maxLabel = null;
        // Insertion-ordered so the score map iterates in output index order.
        LinkedHashMap<String, Label> probMap = new LinkedHashMap<>();
        for (int i = 0; i < scores.length; i++) {
            String name = info.getOutput(i).getLabel();
            Label label = new Label(name, scores[i]);
            probMap.put(name, label);
            if (label.getScore() > maxScore) {
                maxScore = label.getScore();
                maxLabel = label;
            }
        }
        // NOTE(review): the trailing 'true' presumably flags the scores as probabilities,
        // matching generatesProbabilities() - confirm against the Prediction constructor.
        return new Prediction<>(maxLabel, probMap, numValidFeatures, example, true);
    }
}

