/*
 * Decompiled with CFR 0.152.
 */
package org.apache.tika.dl.imagerec;

import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.tika.config.Field;
import org.apache.tika.config.InitializableProblemHandler;
import org.apache.tika.config.Param;
import org.apache.tika.exception.TikaConfigException;
import org.apache.tika.exception.TikaException;
import org.apache.tika.metadata.Metadata;
import org.apache.tika.mime.MediaType;
import org.apache.tika.parser.ParseContext;
import org.apache.tika.parser.recognition.ObjectRecogniser;
import org.apache.tika.parser.recognition.RecognisedObject;
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.util.ModelSerializer;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.model.VGG16;
import org.deeplearning4j.zoo.util.imagenet.ImageNetLabels;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.VGG16ImagePreProcessor;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.xml.sax.ContentHandler;
import org.xml.sax.SAXException;

/**
 * {@link ObjectRecogniser} that classifies JPEG images with the VGG-16 network from the
 * DeepLearning4J model zoo, using ImageNet-pretrained weights and ImageNet class labels.
 *
 * <p>The network is loaded in {@link #initialize(Map)}. When {@code serialize} is {@code true}
 * (the default), a serialized copy of the network is restored from {@code cacheDir} if that file
 * exists. Otherwise the pretrained zoo model is loaded and then written to {@code cacheDir} for
 * later runs. When {@code serialize} is {@code false}, the zoo model is loaded on every
 * initialization and nothing is written to disk.
 *
 * <p>{@code topN} controls how many labels are reported per image. It is not initialised here,
 * so unless it is configured it stays {@code 0} and no objects are reported.
 */
public class DL4JVGG16Net
implements ObjectRecogniser {
    public static final Set<MediaType> SUPPORTED_MIMES = Collections.singleton(MediaType.image("jpeg"));
    private static final Logger LOG = LoggerFactory.getLogger(DL4JVGG16Net.class);
    private static final String BASE_DIR = System.getProperty("user.home") + File.separator + ".tika-dl" + File.separator + "models" + File.separator + "dl4j";
    private static final String MODEL_DIR = BASE_DIR + File.separator + "vgg-16";

    /** File holding the serialized model; read/written only when {@code serialize} is enabled. */
    @Field
    private File cacheDir = new File(MODEL_DIR + File.separator + "vgg16.zip");
    /** Whether to restore the model from, and save it to, {@code cacheDir}. */
    @Field
    private boolean serialize = true;
    /** Number of highest-scoring labels reported per image. */
    @Field
    private int topN;

    // Constructed with (224, 224, 3) -- the input size this network is fed with.
    private final NativeImageLoader imageLoader = new NativeImageLoader(224L, 224L, 3L);
    private final DataNormalization preProcessor = new VGG16ImagePreProcessor();
    /** Set to {@code true} only after {@link #initialize(Map)} loads both model and labels. */
    private boolean available = false;
    private ComputationGraph model;
    private ImageNetLabels imageNetLabels;

    /**
     * Returns the media types this recogniser accepts.
     *
     * @return a singleton set containing {@code image/jpeg}
     */
    public Set<MediaType> getSupportedMimes() {
        return SUPPORTED_MIMES;
    }

    /**
     * Reports whether the model and labels were loaded successfully.
     *
     * @return {@code true} once {@link #initialize(Map)} has completed without error
     */
    public boolean isAvailable() {
        return this.available;
    }

    /**
     * No initialization checks are performed.
     *
     * @param problemHandler ignored
     */
    public void checkInitialization(InitializableProblemHandler problemHandler) throws TikaConfigException {
    }

    /**
     * Loads the VGG-16 network and the ImageNet labels.
     *
     * @param params configuration parameters (not read by this method)
     * @throws TikaConfigException if loading, saving, or label initialisation fails; the
     *         recogniser is marked unavailable in that case
     */
    public void initialize(Map<String, Param> params) throws TikaConfigException {
        try {
            if (this.serialize) {
                if (this.cacheDir.exists()) {
                    this.model = ModelSerializer.restoreComputationGraph(this.cacheDir);
                    LOG.info("Preprocessed Model Loaded from {}", this.cacheDir);
                } else {
                    LOG.warn("Preprocessed Model doesn't exist at {}", this.cacheDir);
                    File parent = this.cacheDir.getParentFile();
                    // A bare relative file name has no parent directory to create.
                    if (parent != null) {
                        parent.mkdirs();
                    }
                    this.model = loadPretrainedModel();
                    LOG.info("Saving the Loaded model for future use. Saved models are more optimised to consume less resources.");
                    ModelSerializer.writeModel(this.model, this.cacheDir, true);
                }
            } else {
                LOG.info("Weight graph model loaded via dl4j Helper functions");
                this.model = loadPretrainedModel();
            }
            this.imageNetLabels = new ImageNetLabels();
            this.available = true;
        }
        catch (Exception e) {
            this.available = false;
            LOG.warn(e.getMessage(), e);
            throw new TikaConfigException(e.getMessage(), e);
        }
    }

    /**
     * Loads VGG-16 with ImageNet-pretrained weights from the DL4J model zoo.
     *
     * @return the pretrained network
     * @throws IOException if the pretrained weights cannot be obtained
     */
    private static ComputationGraph loadPretrainedModel() throws IOException {
        VGG16 zooModel = VGG16.builder().build();
        return (ComputationGraph) zooModel.initPretrained(PretrainedType.IMAGENET);
    }

    /**
     * Classifies an image and returns its {@code topN} highest-scoring ImageNet labels.
     *
     * @param stream   the image bytes
     * @param handler  unused
     * @param metadata unused
     * @param context  unused
     * @return recognised objects, highest score first within each output row
     * @throws IOException   if the image cannot be decoded
     * @throws TikaException if the model has not been successfully initialized
     */
    public List<RecognisedObject> recognise(InputStream stream, ContentHandler handler, Metadata metadata, ParseContext context) throws IOException, SAXException, TikaException {
        if (!this.available) {
            // Fail with a descriptive error instead of an NPE on the missing model.
            throw new TikaException("DL4J VGG16 model is not available; initialize() did not complete successfully");
        }
        INDArray image = this.imageLoader.asMatrix(stream);
        this.preProcessor.transform(image);
        INDArray[] output = this.model.output(false, image);
        return this.predict(output[0]);
    }

    /**
     * Converts network output scores into labelled objects.
     *
     * <p>For each row of {@code predictions}, the {@code topN} largest scores are picked in
     * descending order. Each selected score is zeroed in a private copy of the row so the next
     * argmax finds the following class. The count is capped at the row length so that a large
     * {@code topN} cannot re-select already-zeroed entries.
     *
     * @param predictions network output, one row per image (assumed rank 2 -- TODO confirm)
     * @return one {@link RecognisedObject} per selected score
     */
    private List<RecognisedObject> predict(INDArray predictions) {
        List<RecognisedObject> objects = new ArrayList<>();
        long rows = predictions.size(0);
        for (long batch = 0; batch < rows; batch++) {
            // Copy so that zeroing out selected classes does not mutate the network output.
            INDArray scores = predictions.getRow(batch).dup();
            long limit = Math.min(this.topN, scores.length());
            for (long rank = 0; rank < limit; rank++) {
                int classIdx = Nd4j.argMax(scores, 1).getInt(0);
                // 'scores' is a single-row copy, so the row index is always 0 (not 'batch').
                float prob = scores.getFloat(0L, (long) classIdx);
                scores.putScalar(0L, (long) classIdx, 0.0);
                String label = this.imageNetLabels.getLabel(classIdx);
                objects.add(new RecognisedObject(label, "eng", label, prob));
            }
        }
        return objects;
    }
}

