/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.hdf5;
import org.deeplearning4j.berkeley.StringUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.IncompatibleKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.ModelConfiguration;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.databind.DeserializationFeature;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static entry points for importing Keras models into DL4J networks.
 *
 * <p>A model can be loaded either from a single HDF5 file that carries both the
 * {@code model_config} attribute and a {@code /model_weights} group, or from a JSON
 * configuration file plus a separate HDF5 weights file.
 *
 * <p>Only Sequential models can currently be imported; every Functional API entry point
 * ends in an {@link UnsupportedOperationException}.
 */
public class Model {
    private static final Logger log = LoggerFactory.getLogger(Model.class);

    /** Matches the {@code :<digits>} suffix that may trail a Keras parameter name. */
    private static final Pattern PARAM_INDEX_SUFFIX = Pattern.compile(":\\d+$");

    /** The config read buffer grows in multiples of this many bytes. */
    private static final int CONFIG_BUFFER_STEP_BYTES = 2000;

    /** Largest buffer multiple tried before giving up on reading the config. */
    private static final int CONFIG_BUFFER_MAX_STEPS = 100;

    /** Object-type code reported for datasets (H5O_TYPE_DATASET in HDF5's H5O_type_t). */
    private static final int HDF5_OBJ_TYPE_DATASET = 1;

    /** Rough upper bound on elements in one weight array (limited by Java array indexing). */
    private static final long MAX_WEIGHT_ELEMENTS = Integer.MAX_VALUE;

    private Model() {
        // Static utility class; not instantiable.
    }

    /**
     * Imports a Keras Sequential model stored in a single HDF5 file.
     *
     * @param modelHdf5Filename path of the HDF5 file holding both configuration and weights
     * @return the initialised network with the stored weights applied
     * @throws IOException if the file or its configuration cannot be read
     * @throws IncompatibleKerasConfigurationException if the stored model cannot be mapped to DL4J
     */
    public static MultiLayerNetwork importSequentialModel(String modelHdf5Filename) throws IOException, IncompatibleKerasConfigurationException, UnsupportedOperationException {
        return Model.<MultiLayerNetwork>importModel(modelHdf5Filename, true);
    }

    /**
     * Imports a Keras Functional API model stored in a single HDF5 file.
     *
     * @param modelHdf5Filename path of the HDF5 file holding both configuration and weights
     * @return never returns normally at present
     * @throws IOException if the file or its configuration cannot be read
     * @throws UnsupportedOperationException always, once the configuration has been read,
     *         because Functional API models are not supported yet
     */
    public static ComputationGraph importFunctionalApiModel(String modelHdf5Filename) throws IOException, IncompatibleKerasConfigurationException, UnsupportedOperationException {
        return Model.<ComputationGraph>importModel(modelHdf5Filename, false);
    }

    /**
     * Opens the HDF5 file, reads the JSON configuration from its {@code model_config}
     * attribute and imports the weights found under {@code /model_weights}.
     *
     * <p>The file handle is released even when the import fails.
     */
    private static <T> T importModel(String modelHdf5Filename, boolean isSequential) throws IOException {
        hdf5.H5File file = new hdf5.H5File(modelHdf5Filename, hdf5.H5F_ACC_RDONLY);
        try {
            String configJson = readModelConfigJson(file);
            hdf5.Group weightsRoot = file.asCommonFG().openGroup("/model_weights");
            return Model.<T>importModel(configJson, weightsRoot, isSequential);
        } finally {
            file.close();
        }
    }

    /**
     * Reads the variable-length {@code model_config} attribute as a JSON string.
     *
     * <p>The length of the attribute is not queried up front. Instead the attribute is read
     * into progressively larger buffers until Jackson accepts the content as JSON, giving
     * up after {@link #CONFIG_BUFFER_MAX_STEPS} attempts. The returned string is the whole
     * decoded buffer, exactly as it was handed to the parser.
     */
    private static String readModelConfigJson(hdf5.H5File file) throws IOException {
        hdf5.Attribute attr = file.openAttribute("model_config");
        try {
            hdf5.VarLenType vl = attr.getVarLenType();
            // One mapper serves every retry; it is only used to validate the JSON.
            ObjectMapper mapper = new ObjectMapper();
            mapper.enable(DeserializationFeature.FAIL_ON_READING_DUP_TREE_KEY);
            IOException lastFailure = null;
            for (int step = 1; step <= CONFIG_BUFFER_MAX_STEPS; step++) {
                byte[] attrBuffer = new byte[step * CONFIG_BUFFER_STEP_BYTES];
                BytePointer attrPointer = new BytePointer(attrBuffer);
                attr.read((hdf5.DataType) vl, attrPointer);
                attrPointer.get(attrBuffer);
                String candidate = new String(attrBuffer, StandardCharsets.UTF_8);
                try {
                    mapper.readTree(candidate);
                    return candidate;
                } catch (IOException e) {
                    // Treated as "buffer too small": retry with the next size.
                    lastFailure = e;
                }
            }
            if (lastFailure != null) {
                // The exception type below takes only a message, so keep the cause in the log.
                log.warn("Last JSON parse failure while reading Keras model_config", lastFailure);
            }
            throw new IncompatibleKerasConfigurationException("Could not read abnormally long Keras config. Please file an issue!");
        } finally {
            attr.close();
        }
    }

    /**
     * Imports a Keras Sequential model from a JSON configuration file and an HDF5 weights file.
     *
     * @param configJsonFilename path of the JSON model configuration
     * @param weightsHdf5Filename path of the HDF5 file whose root group holds the weights
     * @return the initialised network with the stored weights applied
     * @throws IOException if either file cannot be read
     * @throws IncompatibleKerasConfigurationException if the stored model cannot be mapped to DL4J
     */
    public static MultiLayerNetwork importSequentialModel(String configJsonFilename, String weightsHdf5Filename) throws IOException, IncompatibleKerasConfigurationException, UnsupportedOperationException {
        return Model.<MultiLayerNetwork>importModel(configJsonFilename, weightsHdf5Filename, true);
    }

    /**
     * Imports a Keras Functional API model from a JSON configuration file and an HDF5 weights file.
     *
     * @param configJsonFilename path of the JSON model configuration
     * @param weightsHdf5Filename path of the HDF5 file whose root group holds the weights
     * @return never returns normally at present
     * @throws IOException if either file cannot be read
     * @throws UnsupportedOperationException always, once both files have been opened,
     *         because Functional API models are not supported yet
     */
    public static ComputationGraph importModel(String configJsonFilename, String weightsHdf5Filename) throws IOException, IncompatibleKerasConfigurationException, UnsupportedOperationException {
        return Model.<ComputationGraph>importModel(configJsonFilename, weightsHdf5Filename, false);
    }

    /**
     * Reads the JSON configuration from disk, opens the weights file read-only and imports
     * the weights starting at its root group.
     *
     * <p>The file handle is released even when the import fails.
     */
    private static <T> T importModel(String configJsonFilename, String weightsHdf5Filename, boolean isSequential) throws IOException, IncompatibleKerasConfigurationException, UnsupportedOperationException {
        String configJson = new String(Files.readAllBytes(Paths.get(configJsonFilename)), StandardCharsets.UTF_8);
        hdf5.H5File file = new hdf5.H5File();
        file.openFile(weightsHdf5Filename, hdf5.H5F_ACC_RDONLY);
        try {
            hdf5.Group weightsRoot = file.asCommonFG().openGroup("/");
            return Model.<T>importModel(configJson, weightsRoot, isSequential);
        } finally {
            file.close();
        }
    }

    /**
     * Builds the network described by {@code configJson} and copies the weights found
     * under {@code weightsGroup} into it.
     *
     * @param configJson Keras model configuration as JSON text
     * @param weightsGroup HDF5 group to search for weight datasets; closed while reading
     * @param isSequential {@code true} for Sequential models; {@code false} is rejected
     * @throws UnsupportedOperationException if {@code isSequential} is {@code false}
     */
    @SuppressWarnings("unchecked") // T is MultiLayerNetwork on the only path that returns.
    private static <T> T importModel(String configJson, hdf5.Group weightsGroup, boolean isSequential) throws IOException, UnsupportedOperationException {
        if (!isSequential) {
            // Fail before doing any work: a graph built here could never be returned.
            throw new UnsupportedOperationException("Keras Functional API models not supported.");
        }
        MultiLayerConfiguration config = ModelConfiguration.importSequentialModelConfig(configJson);
        MultiLayerNetwork network = new MultiLayerNetwork(config);
        network.init();
        Map<String, Object> weightsMetadata = ModelConfiguration.extractWeightsMetadataFromConfig(configJson);
        Map<String, Map<String, INDArray>> weights = Model.readWeightsFromHdf5(weightsGroup);
        Model.importWeights(network, weights, weightsMetadata, isSequential);
        return (T) network;
    }

    /**
     * Walks the HDF5 hierarchy breadth-first from {@code weightsGroup} and collects every
     * dataset it finds.
     *
     * <p>A dataset name is split on {@code '_'}: the last token becomes the parameter name
     * and the remaining tokens, re-joined with {@code '_'}, become the layer name. Any
     * child that is not a dataset is opened as a group and visited later. Each visited
     * group, including {@code weightsGroup} itself, is closed.
     *
     * @return layer name → (parameter name → values)
     * @throws IncompatibleKerasConfigurationException if a dataset has a rank other than 1, 2 or 4
     */
    private static Map<String, Map<String, INDArray>> readWeightsFromHdf5(hdf5.Group weightsGroup) {
        Map<String, Map<String, INDArray>> weightsMap = new HashMap<String, Map<String, INDArray>>();
        Deque<hdf5.Group> pending = new ArrayDeque<hdf5.Group>();
        pending.add(weightsGroup);
        try {
            while (!pending.isEmpty()) {
                hdf5.Group group = pending.remove();
                try {
                    long childCount = group.asCommonFG().getNumObjs();
                    for (long idx = 0; idx < childCount; idx++) {
                        BytePointer objPtr = group.asCommonFG().getObjnameByIdx(idx);
                        String objName = objPtr.getString();
                        int objType = group.asCommonFG().childObjType(objPtr);
                        if (objType != HDF5_OBJ_TYPE_DATASET) {
                            pending.add(group.asCommonFG().openGroup(objPtr));
                            continue;
                        }
                        INDArray values;
                        hdf5.DataSet dataSet = group.asCommonFG().openDataSet(objPtr);
                        try {
                            values = readDataSet(dataSet);
                        } finally {
                            dataSet.close();
                        }
                        String[] tokens = objName.split("_");
                        String layerName = StringUtils.join((Object[]) Arrays.copyOfRange(tokens, 0, tokens.length - 1), "_");
                        String paramName = tokens[tokens.length - 1];
                        Map<String, INDArray> layerParams = weightsMap.get(layerName);
                        if (layerParams == null) {
                            layerParams = new HashMap<String, INDArray>();
                            weightsMap.put(layerName, layerParams);
                        }
                        layerParams.put(paramName, values);
                    }
                } finally {
                    group.close();
                }
            }
        } finally {
            // Empty on success; after a failure this releases groups that were queued but never visited.
            for (hdf5.Group unvisited : pending) {
                unvisited.close();
            }
        }
        return weightsMap;
    }

    /**
     * Reads a dataset as native floats and copies the values, in the order they were
     * read, into a new array with the dataset's dimensions.
     *
     * @throws IncompatibleKerasConfigurationException if the rank is not 1, 2 or 4, or the
     *         dataset holds more elements than a Java array can index
     */
    private static INDArray readDataSet(hdf5.DataSet dataSet) {
        hdf5.DataSpace space = dataSet.getSpace();
        int rank = space.getSimpleExtentNdims();
        long[] dims = new long[rank];
        space.getSimpleExtentDims(dims);
        if (rank != 1 && rank != 2 && rank != 4) {
            throw new IncompatibleKerasConfigurationException("Cannot import weights with rank " + rank);
        }

        int[] shape = new int[rank];
        long elementCount = 1L;
        for (int axis = 0; axis < rank; axis++) {
            elementCount *= dims[axis];
            // Checked per axis so the running product cannot wrap around unnoticed.
            if (dims[axis] > MAX_WEIGHT_ELEMENTS || elementCount > MAX_WEIGHT_ELEMENTS) {
                throw new IncompatibleKerasConfigurationException("Cannot import weights with more than " + MAX_WEIGHT_ELEMENTS + " elements");
            }
            shape[axis] = (int) dims[axis];
        }

        float[] values = new float[(int) elementCount];
        FloatPointer valuesPointer = new FloatPointer(values);
        dataSet.read((Pointer) valuesPointer, new hdf5.DataType(hdf5.PredType.NATIVE_FLOAT()));
        valuesPointer.get(values);

        int next = 0;
        INDArray result;
        if (rank == 4) {
            result = Nd4j.create(new int[]{shape[0], shape[1], shape[2], shape[3]});
            for (int a = 0; a < shape[0]; a++) {
                for (int b = 0; b < shape[1]; b++) {
                    for (int c = 0; c < shape[2]; c++) {
                        for (int d = 0; d < shape[3]; d++) {
                            result.putScalar(a, b, c, d, (double) values[next++]);
                        }
                    }
                }
            }
        } else if (rank == 2) {
            result = Nd4j.create(shape[0], shape[1]);
            for (int row = 0; row < shape[0]; row++) {
                for (int col = 0; col < shape[1]; col++) {
                    result.putScalar(row, col, (double) values[next++]);
                }
            }
        } else {
            result = Nd4j.create(shape[0]);
            for (int pos = 0; pos < shape[0]; pos++) {
                result.putScalar(pos, values[next++]);
            }
        }
        return result;
    }

    /**
     * Copies the collected weights into the matching layers of {@code model}.
     *
     * <p>A trailing {@code :<digits>} is stripped from each Keras parameter name before it
     * is used as the DL4J parameter key. The {@code "W"} parameter of a convolution layer
     * is permuted first; the axis order depends on the {@code keras_backend} metadata entry
     * ({@code "tf"} or {@code "th"}). NOTE(review): the permutations presumably convert the
     * Keras kernel layout to the one DL4J expects — confirm against the layer implementation.
     *
     * @param model a {@link MultiLayerNetwork} when {@code isSequential}, otherwise a {@link ComputationGraph}
     * @param weights layer name → (Keras parameter name → values)
     * @param weightsMetadata metadata extracted from the model configuration
     * @param isSequential selects which network type {@code model} is cast to
     * @return {@code model}, for chaining
     * @throws IncompatibleKerasConfigurationException if a layer named in {@code weights} does
     *         not exist in the model, or a convolution kernel is met with an unknown backend
     */
    private static <T> T importWeights(T model, Map<String, Map<String, INDArray>> weights, Map<String, Object> weightsMetadata, boolean isSequential) throws IncompatibleKerasConfigurationException {
        String kerasBackend = weightsMetadata.containsKey("keras_backend") ? (String) weightsMetadata.get("keras_backend") : "none";
        for (Map.Entry<String, Map<String, INDArray>> layerEntry : weights.entrySet()) {
            String layerName = layerEntry.getKey();
            Layer layer = isSequential
                    ? ((MultiLayerNetwork) model).getLayer(layerName)
                    : ((ComputationGraph) model).getLayer(layerName);
            if (layer == null) {
                // Report the offending name instead of failing later with a bare NullPointerException.
                throw new IncompatibleKerasConfigurationException("No layer named " + layerName + " found in model for stored weights");
            }
            for (Map.Entry<String, INDArray> paramEntry : layerEntry.getValue().entrySet()) {
                // replaceFirst returns the name unchanged when there is no suffix to strip.
                String paramName = PARAM_INDEX_SUFFIX.matcher(paramEntry.getKey()).replaceFirst("");
                INDArray values = paramEntry.getValue();
                if (layer instanceof ConvolutionLayer && "W".equals(paramName)) {
                    if ("tf".equals(kerasBackend)) {
                        values = values.permute(new int[]{3, 2, 0, 1});
                    } else if ("th".equals(kerasBackend)) {
                        values = values.permute(new int[]{3, 0, 1, 2});
                    } else {
                        throw new IncompatibleKerasConfigurationException("Unknown keras backend " + kerasBackend);
                    }
                }
                layer.setParam(paramName, values);
            }
        }
        return model;
    }
}

