/*
 * Decompiled with CFR 0.152.
 */
package org.openimaj.ml.annotation.bayes;

import gov.sandia.cognition.learning.algorithm.IncrementalLearner;
import gov.sandia.cognition.learning.algorithm.bayes.VectorNaiveBayesCategorizer;
import gov.sandia.cognition.learning.data.DefaultInputOutputPair;
import gov.sandia.cognition.learning.data.DefaultWeightedValueDiscriminant;
import gov.sandia.cognition.learning.data.InputOutputPair;
import gov.sandia.cognition.math.LogMath;
import gov.sandia.cognition.math.matrix.Vector;
import gov.sandia.cognition.math.matrix.VectorFactory;
import gov.sandia.cognition.math.matrix.Vectorizable;
import gov.sandia.cognition.statistics.SufficientStatistic;
import gov.sandia.cognition.statistics.distribution.UnivariateGaussian;
import gov.sandia.cognition.util.AbstractCloneableSerializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.openimaj.feature.FeatureExtractor;
import org.openimaj.feature.FeatureVector;
import org.openimaj.feature.IdentityFeatureExtractor;
import org.openimaj.ml.annotation.Annotated;
import org.openimaj.ml.annotation.IncrementalAnnotator;
import org.openimaj.ml.annotation.ScoredAnnotation;

/**
 * An {@link IncrementalAnnotator} backed by a Cognitive Foundry
 * {@link VectorNaiveBayesCategorizer} whose per-dimension distributions are
 * univariate Gaussians estimated incrementally.
 * <p>
 * Objects are converted to feature vectors by a {@link FeatureExtractor}. Every
 * training example updates the model once per annotation attached to it. The
 * way predictions are turned into scored annotations is selected by {@link Mode}.
 *
 * @param <OBJECT>
 *            type of object being annotated
 * @param <ANNOTATION>
 *            type of annotation
 */
public class NaiveBayesAnnotator<OBJECT, ANNOTATION>
extends IncrementalAnnotator<OBJECT, ANNOTATION> {
    /** The model being learned; recreated by {@link #reset()}. */
    private VectorNaiveBayesCategorizer<ANNOTATION, PDF> categorizer;
    /** Online learner that updates {@link #categorizer}; recreated by {@link #reset()}. */
    private VectorNaiveBayesCategorizer.OnlineLearner<ANNOTATION, PDF> learner;
    /** Strategy used by {@link #annotate(Object)} to produce scored annotations. */
    private final Mode mode;
    /** Converts objects into the feature vectors consumed by the categorizer. */
    private final FeatureExtractor<? extends FeatureVector, OBJECT> extractor;

    /**
     * Construct an untrained annotator.
     *
     * @param extractor
     *            the feature extractor used for both training and annotation
     * @param mode
     *            how {@link #annotate(Object)} reports its results
     */
    public NaiveBayesAnnotator(FeatureExtractor<? extends FeatureVector, OBJECT> extractor, Mode mode) {
        this.extractor = extractor;
        this.mode = mode;
        this.reset();
    }

    /**
     * Create an annotator for objects that are already feature vectors, using an
     * identity feature extractor.
     *
     * @param mode
     *            how {@link #annotate(Object)} reports its results
     * @return a new, untrained annotator
     */
    public static <OBJECT extends FeatureVector, ANNOTATION> NaiveBayesAnnotator<OBJECT, ANNOTATION> create(Mode mode) {
        return new NaiveBayesAnnotator<OBJECT, ANNOTATION>(new IdentityFeatureExtractor<OBJECT>(), mode);
    }

    /**
     * Update the model with one annotated example. The extracted feature vector
     * is fed to the learner once for every annotation of the example.
     *
     * @param annotated
     *            the object and its annotations
     */
    @Override
    public void train(Annotated<OBJECT, ANNOTATION> annotated) {
        final Vector vec = this.extractVector(annotated.getObject());
        for (final ANNOTATION ann : annotated.getAnnotations()) {
            this.learner.update(this.categorizer, new DefaultInputOutputPair<Vector, ANNOTATION>(vec, ann));
        }
    }

    /**
     * Discard everything learned so far by creating a fresh learner (with a
     * Gaussian {@link PDFLearner} per dimension) and a fresh, empty categorizer.
     */
    @Override
    public void reset() {
        this.learner = new VectorNaiveBayesCategorizer.OnlineLearner<ANNOTATION, PDF>();
        this.learner.setDistributionLearner(new PDFLearner());
        this.categorizer = this.learner.createInitialLearnedObject();
    }

    /**
     * @return the categories known to the underlying categorizer (the set is
     *         returned directly, not copied)
     */
    @Override
    public Set<ANNOTATION> getAnnotations() {
        return this.categorizer.getCategories();
    }

    /**
     * Annotate an object according to this annotator's {@link Mode}.
     *
     * @param object
     *            the object to annotate
     * @return the scored annotations produced by the mode
     */
    @Override
    public List<ScoredAnnotation<ANNOTATION>> annotate(OBJECT object) {
        return this.mode.getAnnotations(this.categorizer, this.extractVector(object));
    }

    /**
     * Extract the feature of an object and copy it into a Foundry {@link Vector}.
     * Shared by training and annotation so both see identical inputs.
     */
    private Vector extractVector(OBJECT object) {
        final FeatureVector feature = this.extractor.extractFeature(object);
        return VectorFactory.getDefault().copyArray(feature.asDoubleVector());
    }

    /**
     * Strategies for converting the categorizer's output into scored annotations.
     */
    public static enum Mode {
        /**
         * Score every known category by its normalized posterior probability. The
         * confidences sum to one (up to rounding). The list is sorted with
         * {@code Collections.reverseOrder()} over {@link ScoredAnnotation}'s natural
         * ordering (presumably by confidence, so the most likely annotation comes
         * first; confirm against {@code ScoredAnnotation}).
         */
        ALL{

            @Override
            protected <ANNOTATION> List<ScoredAnnotation<ANNOTATION>> getAnnotations(VectorNaiveBayesCategorizer<ANNOTATION, PDF> categorizer, Vector vec) {
                final Set<ANNOTATION> categories = categorizer.getCategories();
                final List<ANNOTATION> labels = new ArrayList<ANNOTATION>(categories.size());
                final double[] logPosteriors = new double[categories.size()];

                // Accumulate log(sum of posteriors) with LogMath.add to avoid underflow.
                double logDenominator = Double.NEGATIVE_INFINITY;
                int i = 0;
                for (final ANNOTATION category : categories) {
                    final double logPosterior = categorizer.computeLogPosterior(vec, category);
                    logDenominator = LogMath.add(logDenominator, logPosterior);
                    labels.add(category);
                    logPosteriors[i++] = logPosterior;
                }

                // Normalize in double precision and only then narrow to float: the raw
                // log-posteriors can have large magnitude, and rounding them to float
                // before subtracting the denominator distorts the exponentiated result.
                final List<ScoredAnnotation<ANNOTATION>> results = new ArrayList<ScoredAnnotation<ANNOTATION>>(labels.size());
                for (int j = 0; j < labels.size(); j++) {
                    final float confidence = (float) Math.exp(logPosteriors[j] - logDenominator);
                    results.add(new ScoredAnnotation<ANNOTATION>(labels.get(j), confidence));
                }
                Collections.sort(results, Collections.reverseOrder());
                return results;
            }
        }
        ,
        /**
         * Return only the single best category reported by
         * {@code evaluateWithDiscriminant}. Its confidence is {@code exp} of the
         * discriminant weight. NOTE(review): the weight appears to be an
         * unnormalized log-posterior, so this value is not necessarily a
         * probability — confirm against the Foundry documentation.
         */
        MAXIMUM_LIKELIHOOD{

            @Override
            protected <ANNOTATION> List<ScoredAnnotation<ANNOTATION>> getAnnotations(VectorNaiveBayesCategorizer<ANNOTATION, PDF> categorizer, Vector vec) {
                final List<ScoredAnnotation<ANNOTATION>> results = new ArrayList<ScoredAnnotation<ANNOTATION>>(1);
                final DefaultWeightedValueDiscriminant<ANNOTATION> best = categorizer.evaluateWithDiscriminant(vec);
                results.add(new ScoredAnnotation<ANNOTATION>(best.getValue(), (float) Math.exp(best.getWeight())));
                return results;
            }
        };


        /**
         * Produce scored annotations for a feature vector.
         *
         * @param categorizer
         *            the trained categorizer
         * @param vec
         *            the feature vector to classify
         * @return the scored annotations
         */
        protected abstract <ANNOTATION> List<ScoredAnnotation<ANNOTATION>> getAnnotations(VectorNaiveBayesCategorizer<ANNOTATION, PDF> categorizer, Vector vec);
    }

    /**
     * Incremental learner for a single dimension's distribution. It feeds values to a
     * {@link UnivariateGaussian.IncrementalEstimator} and copies the resulting
     * mean and variance into the {@link PDF} used by the categorizer.
     */
    private static class PDFLearner
    extends AbstractCloneableSerializable
    implements IncrementalLearner<Double, PDF> {
        private static final long serialVersionUID = 1L;
        private final UnivariateGaussian.IncrementalEstimator distrLearner = new UnivariateGaussian.IncrementalEstimator();

        private PDFLearner() {
        }

        /**
         * @return a new PDF backed by a fresh sufficient statistic
         */
        @Override
        public PDF createInitialLearnedObject() {
            final PDF pdf = new PDF();
            pdf.target = this.distrLearner.createInitialLearnedObject();
            return pdf;
        }

        /**
         * Add one observation and refresh the PDF's parameters.
         */
        @Override
        public void update(PDF pdf, Double data) {
            this.distrLearner.update(pdf.target, data);
            syncParameters(pdf);
        }

        /**
         * Add a batch of observations and refresh the PDF's parameters.
         */
        @Override
        public void update(PDF pdf, Iterable<? extends Double> data) {
            this.distrLearner.update(pdf.target, data);
            syncParameters(pdf);
        }

        /** Copy the current mean and variance of the sufficient statistic into the PDF. */
        private static void syncParameters(PDF pdf) {
            pdf.setMean(pdf.target.getMean());
            pdf.setVariance(pdf.target.getVariance());
        }
    }

    /**
     * A univariate Gaussian PDF that also carries the sufficient statistic its
     * parameters are derived from, so it can be updated incrementally.
     */
    private static class PDF
    extends UnivariateGaussian.PDF {
        private static final long serialVersionUID = 1L;
        /** Running statistic from which this PDF's mean and variance are refreshed. */
        private UnivariateGaussian.SufficientStatistic target;

        private PDF() {
        }
    }
}

