/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.hdf5;
import org.deeplearning4j.berkeley.StringUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasModel;
import org.deeplearning4j.nn.modelimport.keras.KerasSequentialModel;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.core.type.TypeReference;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasModelImport {
    private static final Logger log = LoggerFactory.getLogger(KerasModelImport.class);
    private String modelJson;
    private String trainingJson;
    private String modelClassName;
    private Map<String, Map<String, INDArray>> weights;

    /**
     * Imports a Keras model (configuration and weights) from a stream over an HDF5 archive
     * and returns it as a {@link ComputationGraph}.
     *
     * <p>The archive must declare the class name {@code "Model"}. Note that the stream-based
     * constructor this method delegates to currently always throws
     * {@link UnsupportedOperationException}.
     *
     * @param modelHdf5Stream stream over the contents of the HDF5 archive
     * @return the imported computation graph
     * @throws IOException if reading the archive fails
     * @throws InvalidKerasConfigurationException if the archive's class name is not {@code "Model"}
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static ComputationGraph importKerasModelAndWeights(InputStream modelHdf5Stream) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport importer = new KerasModelImport(modelHdf5Stream);
        String className = importer.getModelClassName();
        if (!className.equals("Model")) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name Model (found " + className + ")");
        }
        return new KerasModel.ModelBuilder()
                .modelJson(importer.getModelJson())
                .trainingJson(importer.getTrainingJson())
                .weights(importer.getWeights())
                .train(false)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from a stream over an HDF5
     * archive and returns it as a {@link MultiLayerNetwork}.
     *
     * <p>The archive must declare the class name {@code "Sequential"}, matching the
     * filename-based overloads of this method. Note that the stream-based constructor this
     * method delegates to currently always throws {@link UnsupportedOperationException}.
     *
     * @param modelHdf5Stream stream over the contents of the HDF5 archive
     * @return the imported multi-layer network
     * @throws IOException if reading the archive fails
     * @throws InvalidKerasConfigurationException if the archive's class name is not {@code "Sequential"}
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(InputStream modelHdf5Stream) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelHdf5Stream);
        // This entry point builds a Sequential model, so that is the class name to require
        // (the previous check for "Model" rejected exactly the archives it is meant to load).
        if (!archive.getModelClassName().equals("Sequential")) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name Sequential (found " + archive.getModelClassName() + ")");
        }
        KerasSequentialModel kerasModel = new KerasModel.ModelBuilder().modelJson(archive.getModelJson()).trainingJson(archive.getTrainingJson()).weights(archive.getWeights()).train(false).buildSequential();
        return kerasModel.getMultiLayerNetwork();
    }

    /**
     * Imports a Keras model (configuration and weights) from a single HDF5 archive on disk
     * and returns it as a {@link ComputationGraph}.
     *
     * @param modelHdf5Filename path of the HDF5 archive holding configuration and weights
     * @return the imported computation graph
     * @throws IOException if reading the archive fails
     * @throws InvalidKerasConfigurationException if the archive's class name is not {@code "Model"}
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static ComputationGraph importKerasModelAndWeights(String modelHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport importer = new KerasModelImport(modelHdf5Filename);
        String className = importer.getModelClassName();
        if (!className.equals("Model")) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name Model (found " + className + ")");
        }
        return new KerasModel.ModelBuilder()
                .modelJson(importer.getModelJson())
                .trainingJson(importer.getTrainingJson())
                .weights(importer.getWeights())
                .train(false)
                .buildModel()
                .getComputationGraph();
    }

    /**
     * Imports a Keras Sequential model (configuration and weights) from a single HDF5 archive
     * on disk and returns it as a {@link MultiLayerNetwork}.
     *
     * @param modelHdf5Filename path of the HDF5 archive holding configuration and weights
     * @return the imported multi-layer network
     * @throws IOException if reading the archive fails
     * @throws InvalidKerasConfigurationException if the archive's class name is not {@code "Sequential"}
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport importer = new KerasModelImport(modelHdf5Filename);
        String className = importer.getModelClassName();
        if (!className.equals("Sequential")) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name Sequential (found " + className + ")");
        }
        return new KerasModel.ModelBuilder()
                .modelJson(importer.getModelJson())
                .trainingJson(importer.getTrainingJson())
                .weights(importer.getWeights())
                .train(false)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Imports a Keras model from a JSON configuration file plus a separate HDF5 weights file
     * and returns it as a {@link ComputationGraph}.
     *
     * <p>No training configuration is passed to the builder on this path.
     *
     * @param modelJsonFilename path of the JSON file holding the model configuration
     * @param weightsHdf5Filename path of the HDF5 file holding the weights
     * @return the imported computation graph
     * @throws IOException if reading either file fails
     * @throws InvalidKerasConfigurationException if the configuration's class name is not {@code "Model"}
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static ComputationGraph importKerasModelAndWeights(String modelJsonFilename, String weightsHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport archive = new KerasModelImport(modelJsonFilename, weightsHdf5Filename);
        // This entry point builds a functional "Model"; the previous check compared against
        // "Sequential", contradicting both the error message and the builder call below.
        if (!archive.getModelClassName().equals("Model")) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name Model (found " + archive.getModelClassName() + ")");
        }
        KerasModel kerasModel = new KerasModel.ModelBuilder().modelJson(archive.getModelJson()).weights(archive.getWeights()).train(false).buildModel();
        return kerasModel.getComputationGraph();
    }

    /**
     * Imports a Keras Sequential model from a JSON configuration file plus a separate HDF5
     * weights file and returns it as a {@link MultiLayerNetwork}.
     *
     * @param modelJsonFilename path of the JSON file holding the model configuration
     * @param weightsHdf5Filename path of the HDF5 file holding the weights
     * @return the imported multi-layer network
     * @throws IOException if reading either file fails
     * @throws InvalidKerasConfigurationException if the configuration's class name is not {@code "Sequential"}
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static MultiLayerNetwork importKerasSequentialModelAndWeights(String modelJsonFilename, String weightsHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        KerasModelImport importer = new KerasModelImport(modelJsonFilename, weightsHdf5Filename);
        String className = importer.getModelClassName();
        if (!className.equals("Sequential")) {
            throw new InvalidKerasConfigurationException("Expected Keras model class name Sequential (found " + className + ")");
        }
        return new KerasModel.ModelBuilder()
                .modelJson(importer.getModelJson())
                .trainingJson(importer.getTrainingJson())
                .weights(importer.getWeights())
                .train(false)
                .buildSequential()
                .getMultiLayerNetwork();
    }

    /**
     * Reads a Keras model configuration from a JSON file and converts it to a
     * {@link ComputationGraphConfiguration} without loading any weights.
     *
     * @param modelJsonFilename path of the JSON file holding the model configuration
     * @return the corresponding computation graph configuration
     * @throws IOException if the file cannot be read
     * @throws InvalidKerasConfigurationException if the configuration is invalid
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static ComputationGraphConfiguration importKerasModelConfiguration(String modelJsonFilename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        byte[] rawJson = Files.readAllBytes(Paths.get(modelJsonFilename));
        // Decode explicitly as UTF-8 (the JSON interchange encoding) rather than relying on
        // whatever the platform default charset happens to be.
        String modelJson = new String(rawJson, StandardCharsets.UTF_8);
        KerasModel kerasModel = new KerasModel.ModelBuilder().modelJson(modelJson).train(false).buildModel();
        return kerasModel.getComputationGraphConfiguration();
    }

    /**
     * Reads a Keras Sequential model configuration from a JSON file and converts it to a
     * {@link MultiLayerConfiguration} without loading any weights.
     *
     * @param modelJsonFilename path of the JSON file holding the model configuration
     * @return the corresponding multi-layer configuration
     * @throws IOException if the file cannot be read
     * @throws InvalidKerasConfigurationException if the configuration is invalid
     * @throws UnsupportedKerasConfigurationException if the model uses unsupported features
     */
    public static MultiLayerConfiguration importKerasSequentialConfiguration(String modelJsonFilename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        byte[] rawJson = Files.readAllBytes(Paths.get(modelJsonFilename));
        // Decode explicitly as UTF-8 (the JSON interchange encoding) rather than relying on
        // whatever the platform default charset happens to be.
        String modelJson = new String(rawJson, StandardCharsets.UTF_8);
        KerasSequentialModel kerasModel = new KerasModel.ModelBuilder().modelJson(modelJson).train(false).buildSequential();
        return kerasModel.getMultiLayerConfiguration();
    }

    public KerasModelImport(InputStream modelHdf5Stream) throws UnsupportedOperationException, IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        log.warn("Importing a Keras model from an InputStream pointing to contents of an HDF5 archive currently not supported.");
        throw new UnsupportedOperationException("Importing a Keras model from an InputStream currently not supported because it is not possible to load an HDF5 file from a memory buffer using the HDF5 C++ API. See: http://stackoverflow.com/questions/18449972/how-can-i-open-hdf5-file-from-memory-buffer-using-hdf5-c-api");
    }

    /**
     * Loads model configuration, training configuration and weights from a single HDF5
     * archive.
     *
     * <p>The model JSON is read from the {@code "model_config"} attribute, the training JSON
     * from the {@code "training_config"} attribute, and the weights from the
     * {@code "/model_weights"} group.
     *
     * @param modelHdf5Filename path of the HDF5 archive
     * @throws IOException if the model JSON cannot be parsed
     * @throws UnsupportedKerasConfigurationException if a weight dataset cannot be imported
     * @throws InvalidKerasConfigurationException if an attribute cannot be read or the class
     *         name is missing from the model JSON
     */
    public KerasModelImport(String modelHdf5Filename) throws IOException, UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        hdf5.H5File file = new hdf5.H5File(modelHdf5Filename, hdf5.H5F_ACC_RDONLY);
        try {
            this.modelJson = KerasModelImport.readJsonStringFromHdf5Attribute(file, "model_config");
            this.modelClassName = KerasModelImport.getModelClassName(this.modelJson);
            this.trainingJson = KerasModelImport.readJsonStringFromHdf5Attribute(file, "training_config");
            this.weights = KerasModelImport.readWeightsFromHdf5(file, "/model_weights");
        } finally {
            // Release the native file handle even when one of the reads above throws;
            // previously a failed import left the file open.
            file.close();
        }
    }

    /**
     * Loads the model configuration from a JSON file and the weights from a separate HDF5
     * file. Weights are read starting at the root group {@code "/"}. The training JSON is
     * not assigned by this constructor and therefore remains {@code null}.
     *
     * @param modelJsonFilename path of the JSON file holding the model configuration
     * @param weightsHdf5Filename path of the HDF5 file holding the weights
     * @throws IOException if the JSON file cannot be read or parsed
     * @throws InvalidKerasConfigurationException if the class name is missing from the model JSON
     * @throws UnsupportedKerasConfigurationException if a weight dataset cannot be imported
     */
    public KerasModelImport(String modelJsonFilename, String weightsHdf5Filename) throws IOException, InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        // Decode explicitly as UTF-8 (the JSON interchange encoding) instead of the platform default.
        this.modelJson = new String(Files.readAllBytes(Paths.get(modelJsonFilename)), StandardCharsets.UTF_8);
        this.modelClassName = KerasModelImport.getModelClassName(this.modelJson);
        hdf5.H5File file = new hdf5.H5File(weightsHdf5Filename, hdf5.H5F_ACC_RDONLY);
        try {
            this.weights = KerasModelImport.readWeightsFromHdf5(file, "/");
        } finally {
            // Release the native file handle even when reading the weights throws;
            // previously a failed import left the file open.
            file.close();
        }
    }

    /**
     * Returns the model configuration JSON held by this importer.
     *
     * @return the model JSON string
     */
    public String getModelJson() {
        return modelJson;
    }

    /**
     * Returns the training configuration JSON held by this importer.
     *
     * @return the training JSON string, or {@code null} if it was never assigned (only the
     *         single-archive constructor assigns it)
     */
    public String getTrainingJson() {
        return trainingJson;
    }

    /**
     * Returns the value of the {@code "class_name"} entry extracted from the model JSON.
     *
     * @return the model class name
     */
    public String getModelClassName() {
        return modelClassName;
    }

    /**
     * Returns the imported weights, keyed first by layer name and then by parameter name.
     * The internal map itself is returned, not a copy.
     *
     * @return the layer-name to (parameter-name to array) map
     */
    public Map<String, Map<String, INDArray>> getWeights() {
        return weights;
    }

    /**
     * Reads every dataset found beneath a group of an open HDF5 file into ND4J arrays.
     *
     * <p>The group hierarchy is walked breadth-first. Each dataset name is split on
     * {@code "_"}: the first two tokens joined by {@code "_"} form the layer name, the
     * remaining tokens form the parameter name, and a trailing {@code ":<digits>"} suffix is
     * removed from the parameter name. Every child object that is not a dataset is opened as
     * a group and queued for traversal.
     *
     * <p>The file handle is left open; the caller that opened it is responsible for closing
     * it. Datasets and groups opened here are closed before this method returns, including
     * when an exception is thrown.
     *
     * @param file open HDF5 file to read from
     * @param weightsGroupName name of the group at which traversal starts
     * @return map from layer name to a map from parameter name to weight array
     * @throws UnsupportedKerasConfigurationException if a dataset name contains no
     *         {@code "_"} separator, a dataset has a rank other than 1, 2 or 4, or a dataset
     *         is too large to be held in a Java array
     */
    private static Map<String, Map<String, INDArray>> readWeightsFromHdf5(hdf5.H5File file, String weightsGroupName) throws UnsupportedKerasConfigurationException {
        // Object-type code reported by childObjType for datasets (H5O_TYPE_DATASET in the
        // HDF5 C API); every other child is treated as a group.
        final int datasetObjectType = 1;
        // Compiled once for the whole traversal instead of once per dataset.
        final Pattern indexSuffix = Pattern.compile(":\\d+$");

        Map<String, Map<String, INDArray>> weightsMap = new HashMap<String, Map<String, INDArray>>();
        ArrayList<hdf5.Group> pendingGroups = new ArrayList<hdf5.Group>();
        pendingGroups.add(file.asCommonFG().openGroup(weightsGroupName));
        try {
            while (!pendingGroups.isEmpty()) {
                hdf5.Group group = pendingGroups.remove(0);
                try {
                    for (long i = 0; i < group.asCommonFG().getNumObjs(); i++) {
                        BytePointer objPtr = group.asCommonFG().getObjnameByIdx(i);
                        if (group.asCommonFG().childObjType(objPtr) != datasetObjectType) {
                            pendingGroups.add(group.asCommonFG().openGroup(objPtr));
                            continue;
                        }

                        String objName = objPtr.getString();
                        String[] tokens = objName.split("_");
                        if (tokens.length < 2) {
                            // Without this guard the copyOfRange calls below fail with an
                            // opaque runtime exception that never mentions the dataset.
                            throw new UnsupportedKerasConfigurationException("Cannot determine layer and parameter name from HDF5 dataset name " + objName);
                        }
                        String layerName = StringUtils.join((Object[]) Arrays.copyOfRange(tokens, 0, 2), "_");
                        String paramName = StringUtils.join((Object[]) Arrays.copyOfRange(tokens, 2, tokens.length), "_");
                        paramName = indexSuffix.matcher(paramName).replaceFirst("");

                        INDArray weights;
                        hdf5.DataSet dataset = group.asCommonFG().openDataSet(objPtr);
                        try {
                            weights = readDataSetAsArray(dataset);
                        } finally {
                            dataset.close();
                        }

                        Map<String, INDArray> layerWeights = weightsMap.get(layerName);
                        if (layerWeights == null) {
                            layerWeights = new HashMap<String, INDArray>();
                            weightsMap.put(layerName, layerWeights);
                        }
                        layerWeights.put(paramName, weights);
                    }
                } finally {
                    group.close();
                }
            }
        } finally {
            // Empty on normal completion; after a failure this releases groups that were
            // opened and queued but never visited.
            for (hdf5.Group unvisited : pendingGroups) {
                unvisited.close();
            }
        }
        return weightsMap;
    }

    /**
     * Reads a rank-1, rank-2 or rank-4 dataset as native floats and copies the values into a
     * newly created array of the same shape. The flat buffer is consumed in order with the
     * last index varying fastest.
     */
    private static INDArray readDataSetAsArray(hdf5.DataSet dataset) throws UnsupportedKerasConfigurationException {
        hdf5.DataSpace space = dataset.getSpace();
        int rank = space.getSimpleExtentNdims();
        if (rank != 1 && rank != 2 && rank != 4) {
            throw new UnsupportedKerasConfigurationException("Cannot import weights with rank " + rank);
        }
        long[] dims = new long[rank];
        space.getSimpleExtentDims(dims);

        // Validate before narrowing to int so that oversized datasets fail with a clear
        // message instead of silently overflowing the array length.
        int[] shape = new int[rank];
        long elementCount = 1L;
        for (int axis = 0; axis < rank; axis++) {
            long extent = dims[axis];
            if (extent < 0 || extent > Integer.MAX_VALUE || elementCount * extent > Integer.MAX_VALUE) {
                throw new UnsupportedKerasConfigurationException("Cannot import weights with more than " + Integer.MAX_VALUE + " elements (dataset dimensions " + Arrays.toString(dims) + ")");
            }
            shape[axis] = (int) extent;
            elementCount *= extent;
        }

        float[] values = new float[(int) elementCount];
        FloatPointer valuesPointer = new FloatPointer(values);
        dataset.read((Pointer) valuesPointer, new hdf5.DataType(hdf5.PredType.NATIVE_FLOAT()));
        valuesPointer.get(values);

        INDArray weights;
        int next = 0;
        if (rank == 4) {
            weights = Nd4j.create(shape);
            for (int a = 0; a < shape[0]; a++) {
                for (int b = 0; b < shape[1]; b++) {
                    for (int c = 0; c < shape[2]; c++) {
                        for (int d = 0; d < shape[3]; d++) {
                            weights.putScalar(a, b, c, d, (double) values[next++]);
                        }
                    }
                }
            }
        } else if (rank == 2) {
            weights = Nd4j.create(shape[0], shape[1]);
            for (int row = 0; row < shape[0]; row++) {
                for (int col = 0; col < shape[1]; col++) {
                    weights.putScalar(row, col, (double) values[next++]);
                }
            }
        } else {
            weights = Nd4j.create(shape[0]);
            for (int idx = 0; idx < shape[0]; idx++) {
                weights.putScalar(idx, values[next++]);
            }
        }
        return weights;
    }

    /**
     * Reads a JSON string stored in an attribute of an HDF5 file.
     *
     * <p>The attribute is read into a buffer of 2000 bytes and the result is test-parsed as
     * JSON. If parsing fails, the read is repeated with a buffer that is 2000 bytes larger,
     * up to 100 attempts in total. The string returned is the whole decoded buffer of the
     * first attempt that parses, so it may be longer than the JSON document itself.
     *
     * @param file open HDF5 file
     * @param attribute name of the attribute to read
     * @return the decoded buffer containing the JSON document
     * @throws InvalidKerasConfigurationException if no attempt yields parseable JSON
     */
    private static String readJsonStringFromHdf5Attribute(hdf5.H5File file, String attribute) throws InvalidKerasConfigurationException {
        final int bufferStep = 2000;
        final int maxAttempts = 100;

        hdf5.Attribute attr = file.openAttribute(attribute);
        hdf5.VarLenType vl = attr.getVarLenType();

        // One configured mapper serves every attempt; it used to be rebuilt on each retry.
        ObjectMapper mapper = new ObjectMapper();
        mapper.enable(DeserializationFeature.FAIL_ON_READING_DUP_TREE_KEY);

        for (int attempt = 1; attempt <= maxAttempts; attempt++) {
            byte[] attrBuffer = new byte[attempt * bufferStep];
            BytePointer attrPointer = new BytePointer(attrBuffer);
            attr.read((hdf5.DataType) vl, attrPointer);
            attrPointer.get(attrBuffer);
            // Decode as UTF-8 explicitly instead of relying on the platform default charset.
            String jsonString = new String(attrBuffer, StandardCharsets.UTF_8);
            try {
                mapper.readTree(jsonString);
                return jsonString;
            } catch (IOException ignored) {
                // A parse failure is the signal that the buffer was too small; the next
                // iteration retries with a larger one.
            }
        }
        throw new InvalidKerasConfigurationException("Could not read abnormally long Keras config. Please file an issue!");
    }

    /**
     * Extracts the value of the top-level {@code "class_name"} entry from a model
     * configuration JSON string.
     *
     * @param modelJson model configuration JSON
     * @return the value stored under {@code "class_name"}
     * @throws IOException if the JSON cannot be parsed
     * @throws InvalidKerasConfigurationException if there is no {@code "class_name"} entry
     */
    private static String getModelClassName(String modelJson) throws IOException, InvalidKerasConfigurationException {
        TypeReference<HashMap<String, Object>> mapType = new TypeReference<HashMap<String, Object>>(){};
        Map<String, Object> config = new ObjectMapper().readValue(modelJson, mapType);
        if (config.containsKey("class_name")) {
            return (String) config.get("class_name");
        }
        throw new InvalidKerasConfigurationException("Unable to determine Keras model class name.");
    }

    // Load the native HDF5 bindings once, when this class is initialized.
    static {
        try {
            Loader.load(hdf5.class);
        }
        catch (Exception e) {
            // Report through the class logger (initialized by the field declaration above
            // this block) rather than printing the stack trace to stderr. The failure is
            // still not rethrown, as before.
            log.error("Unable to load native HDF5 library", e);
        }
    }
}

