/*
 * Decompiled with CFR 0.152.
 */
package com.facebook.appevents.ml;

import android.content.Context;
import android.os.AsyncTask;
import android.support.annotation.Nullable;
import com.facebook.FacebookSdk;
import com.facebook.appevents.ml.Operator;
import com.facebook.appevents.ml.Utils;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.net.URLConnection;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.json.JSONArray;
import org.json.JSONObject;

/**
 * On-device classifier that maps a piece of text plus dense features to a suggested app-event name.
 *
 * <p>The model file lives under {@code <filesDir>/facebook_ml/} and has this layout: a 4-byte
 * little-endian int {@code jsonLen}, then {@code jsonLen} bytes of JSON mapping each tensor name to
 * its shape, then the little-endian float32 data of every tensor, concatenated in sorted-name order.
 *
 * <p>Inference runs the following steps:
 * <ol>
 *   <li>an embedding lookup;</li>
 *   <li>three parallel conv branches, each doing conv1D, adding a bias, applying ReLU and max-pooling;</li>
 *   <li>concatenation with the caller's dense features;</li>
 *   <li>three fully-connected layers, with ReLU after the first two and softmax after the last.</li>
 * </ol>
 * The first label whose probability reaches its threshold is returned.
 *
 * <p>NOTE(review): the weights are static, so loading any {@code Model} replaces the weights used
 * by every instance.
 */
final class Model {
    private static final String DIR_NAME = "facebook_ml/";
    /** Output labels, indexed like the model's softmax output and like {@code thresholds}. */
    private static final List<String> SUGGESTED_EVENTS_PREDICTION = Arrays.asList("fb_mobile_add_to_cart", "fb_mobile_complete_registration", "other", "fb_mobile_purchase");
    /** Length the input text is vectorized to. */
    private static final int SEQ_LEN = 128;
    /** Width of one embedding row. */
    private static final int EMBEDDING_SIZE = 64;

    private final String useCase;
    private final File modelFile;
    private final File ruleFile;
    private final int versionID;
    /** Per-label minimum probability; the first label meeting its threshold is predicted. */
    private final float[] thresholds;
    @Nullable
    private final String modelUri;
    @Nullable
    private final String ruleUri;

    // Weights shared by all instances. initializeWeights() assigns them together, and only after a
    // model file has been parsed and validated completely.
    @Nullable
    private static Weight embedding;
    @Nullable
    private static Weight convs_1_weight;
    @Nullable
    private static Weight convs_2_weight;
    @Nullable
    private static Weight convs_3_weight;
    @Nullable
    private static Weight convs_1_bias;
    @Nullable
    private static Weight convs_2_bias;
    @Nullable
    private static Weight convs_3_bias;
    @Nullable
    private static Weight fc1_weight;
    @Nullable
    private static Weight fc2_weight;
    @Nullable
    private static Weight fc3_weight;
    @Nullable
    private static Weight fc1_bias;
    @Nullable
    private static Weight fc2_bias;
    @Nullable
    private static Weight fc3_bias;

    /**
     * Creates a model descriptor; no I/O beyond resolving the files directory happens here.
     *
     * @param useCase    name used to build the local file names
     * @param versionID  model version, also part of the local file names
     * @param modelUri   where to download the model file from if it is not cached
     * @param ruleUri    where to download the rule file from, or null if there is none
     * @param thresholds per-label probability thresholds used by {@link #predict}
     */
    Model(String useCase, int versionID, String modelUri, @Nullable String ruleUri, float[] thresholds) {
        this.useCase = useCase;
        this.versionID = versionID;
        this.thresholds = thresholds;
        this.modelUri = modelUri;
        this.ruleUri = ruleUri;
        String modelFilePath = DIR_NAME + useCase + "_" + versionID;
        String ruleFilePath = DIR_NAME + useCase + "_" + versionID + "_rule";
        File dir = FacebookSdk.getApplicationContext().getFilesDir();
        this.modelFile = new File(dir, modelFilePath);
        this.ruleFile = new File(dir, ruleFilePath);
    }

    /**
     * Ensures the model file is present, loads its weights, then ensures the rule file is present.
     *
     * <p>{@code onModelInitialized} runs only if every step succeeds. It runs either synchronously
     * (when the files are already cached) or from {@link FileDownloadTask#onPostExecute}. On any
     * failure it is never called.
     */
    void initialize(final Runnable onModelInitialized) {
        this.downloadModel(new Runnable() {

            @Override
            public void run() {
                if (Model.this.initializeWeights()) {
                    Model.this.downloadRule(onModelInitialized);
                }
            }
        });
    }

    /** Returns the local path of the rule file (it may not exist yet). */
    @Nullable
    File getRuleFile() {
        return this.ruleFile;
    }

    /**
     * Runs {@code onDownloaded} immediately if the model file is cached. Otherwise it starts a
     * download and runs the callback once the download succeeds. Does nothing if there is no URI.
     */
    private void downloadModel(Runnable onDownloaded) {
        if (this.modelFile.exists()) {
            onDownloaded.run();
            return;
        }
        if (this.modelUri != null) {
            new FileDownloadTask(this.modelUri, this.modelFile, onDownloaded).execute(new String[0]);
        }
    }

    /**
     * Runs {@code onDownloaded} immediately if the rule file is cached or there is no rule URI.
     * Otherwise it downloads the rule file first.
     */
    private void downloadRule(Runnable onDownloaded) {
        if (this.ruleFile.exists() || this.ruleUri == null) {
            onDownloaded.run();
            return;
        }
        new FileDownloadTask(this.ruleUri, this.ruleFile, onDownloaded).execute(new String[0]);
    }

    /**
     * Parses the model file and, only if every required tensor is present and well-formed,
     * replaces the shared static weights.
     *
     * @return true if the weights were loaded; false on any I/O, format or parse error (the
     *     previously loaded weights, if any, are left untouched)
     */
    private boolean initializeWeights() {
        try {
            byte[] allData = readFile(this.modelFile);
            int length = allData.length;
            if (length < 4) {
                return false;
            }
            ByteBuffer header = ByteBuffer.wrap(allData, 0, 4);
            header.order(ByteOrder.LITTLE_ENDIAN);
            int jsonLen = header.getInt();
            // Compare by subtraction so a huge jsonLen cannot overflow; negative lengths are corrupt.
            if (jsonLen < 0 || jsonLen > length - 4) {
                return false;
            }
            JSONObject info = new JSONObject(new String(allData, 4, jsonLen, "UTF-8"));
            JSONArray names = info.names();
            if (names == null) {
                return false;
            }
            String[] keys = new String[names.length()];
            for (int i = 0; i < keys.length; ++i) {
                keys[i] = names.getString(i);
            }
            // Tensor payloads are stored back to back in sorted-name order.
            Arrays.sort(keys);
            int offset = 4 + jsonLen;
            Map<String, Weight> weights = new HashMap<String, Weight>();
            for (String key : keys) {
                JSONArray shapes = info.getJSONArray(key);
                int[] shape = new int[shapes.length()];
                long count = 1;
                for (int i = 0; i < shape.length; ++i) {
                    shape[i] = shapes.getInt(i);
                    if (shape[i] < 0) {
                        return false;
                    }
                    count *= shape[i];
                    // Keeps the running product small enough that it can never overflow a long.
                    if (count > Integer.MAX_VALUE) {
                        return false;
                    }
                }
                long byteCount = count * 4L;
                if (byteCount > length - offset) {
                    return false;
                }
                ByteBuffer bb = ByteBuffer.wrap(allData, offset, (int) byteCount);
                bb.order(ByteOrder.LITTLE_ENDIAN);
                float[] data = new float[(int) count];
                bb.asFloatBuffer().get(data, 0, data.length);
                weights.put(key, new Weight(shape, data));
                offset += (int) byteCount;
            }

            Weight newEmbedding = weights.get("embed.weight");
            Weight newConvs1Weight = weights.get("convs.0.weight");
            Weight newConvs2Weight = weights.get("convs.1.weight");
            Weight newConvs3Weight = weights.get("convs.2.weight");
            Weight newConvs1Bias = weights.get("convs.0.bias");
            Weight newConvs2Bias = weights.get("convs.1.bias");
            Weight newConvs3Bias = weights.get("convs.2.bias");
            Weight newFc1Weight = weights.get("fc1.weight");
            Weight newFc2Weight = weights.get("fc2.weight");
            Weight newFc3Weight = weights.get("fc3.weight");
            Weight newFc1Bias = weights.get("fc1.bias");
            Weight newFc2Bias = weights.get("fc2.bias");
            Weight newFc3Bias = weights.get("fc3.bias");
            if (!allPresent(newEmbedding, newConvs1Weight, newConvs2Weight, newConvs3Weight,
                    newConvs1Bias, newConvs2Bias, newConvs3Bias, newFc1Weight, newFc2Weight,
                    newFc3Weight, newFc1Bias, newFc2Bias, newFc3Bias)) {
                return false;
            }
            // Re-layout the weight matrices for Operator; throws (-> false) on unexpected rank.
            transposeConvWeight(newConvs1Weight);
            transposeConvWeight(newConvs2Weight);
            transposeConvWeight(newConvs3Weight);
            transposeDenseWeight(newFc1Weight);
            transposeDenseWeight(newFc2Weight);
            transposeDenseWeight(newFc3Weight);

            // Publish only after everything above has succeeded.
            embedding = newEmbedding;
            convs_1_weight = newConvs1Weight;
            convs_2_weight = newConvs2Weight;
            convs_3_weight = newConvs3Weight;
            convs_1_bias = newConvs1Bias;
            convs_2_bias = newConvs2Bias;
            convs_3_bias = newConvs3Bias;
            fc1_weight = newFc1Weight;
            fc2_weight = newFc2Weight;
            fc3_weight = newFc3Weight;
            fc1_bias = newFc1Bias;
            fc2_bias = newFc2Bias;
            fc3_bias = newFc3Bias;
            return true;
        }
        catch (Exception e) {
            // Best effort: any I/O, JSON or layout error just means the model is unavailable.
            return false;
        }
    }

    /**
     * Runs the network on {@code text} and {@code dense} and returns the predicted event name.
     *
     * @param dense extra features appended after the pooled convolution outputs
     * @param text  raw text to classify
     * @return the first label whose probability meets its threshold, {@code "other"} if none does,
     *     or null if no weights have been loaded yet
     */
    @Nullable
    String predict(float[] dense, String text) {
        if (!allPresent(embedding, convs_1_weight, convs_2_weight, convs_3_weight, convs_1_bias,
                convs_2_bias, convs_3_bias, fc1_weight, fc2_weight, fc3_weight, fc1_bias, fc2_bias,
                fc3_bias)) {
            return null;
        }
        int[] x = Utils.vectorize(text, SEQ_LEN);
        float[] embedX = Operator.embedding(x, embedding.data, 1, SEQ_LEN, EMBEDDING_SIZE);
        float[] ca = convBranch(embedX, convs_1_weight, convs_1_bias);
        float[] cb = convBranch(embedX, convs_2_weight, convs_2_bias);
        float[] cc = convBranch(embedX, convs_3_weight, convs_3_bias);
        float[] concat = Operator.concatenate(Operator.concatenate(Operator.concatenate(ca, cb), cc), dense);
        float[] dense1X = denseLayer(concat, fc1_weight, fc1_bias);
        Operator.relu(dense1X, fc1_bias.shape[0]);
        float[] dense2X = denseLayer(dense1X, fc2_weight, fc2_bias);
        Operator.relu(dense2X, fc2_bias.shape[0]);
        float[] predictedRaw = denseLayer(dense2X, fc3_weight, fc3_bias);
        Operator.softmax(predictedRaw, fc3_bias.shape[0]);
        // Never index past the thresholds, the network output, or the label list.
        int candidates = Math.min(this.thresholds.length,
                Math.min(predictedRaw.length, SUGGESTED_EVENTS_PREDICTION.size()));
        for (int i = 0; i < candidates; ++i) {
            if (predictedRaw[i] >= this.thresholds[i]) {
                return SUGGESTED_EVENTS_PREDICTION.get(i);
            }
        }
        return "other";
    }

    /**
     * Runs one convolution branch: conv1D, add bias, ReLU, then max-pool.
     *
     * <p>{@code weight.shape[0]} is used as the channel count. {@code weight.shape[2]} is used as
     * the kernel width, giving {@code SEQ_LEN - shape[2] + 1} output positions. The pool window
     * equals that full length, so this is presumably a global max-pool.
     */
    private static float[] convBranch(float[] embedX, Weight weight, Weight bias) {
        int kernelSize = weight.shape[2];
        int channels = weight.shape[0];
        int outLen = SEQ_LEN - kernelSize + 1;
        float[] conv = Operator.conv1D(embedX, weight.data, 1, SEQ_LEN, EMBEDDING_SIZE, kernelSize, channels);
        Operator.add(conv, bias.data, 1, outLen, channels);
        Operator.relu(conv, outLen * channels);
        return Operator.maxPool1D(conv, outLen, channels, outLen);
    }

    /** Applies a fully-connected layer (without activation) to a single input row. */
    private static float[] denseLayer(float[] input, Weight weight, Weight bias) {
        return Operator.dense(input, weight.data, bias.data, 1, weight.shape[1], weight.shape[0]);
    }

    /** Transposes a rank-3 convolution weight in place for Operator.conv1D. */
    private static void transposeConvWeight(Weight w) {
        w.data = Operator.transpose3D(w.data, w.shape[0], w.shape[1], w.shape[2]);
    }

    /** Transposes a rank-2 fully-connected weight in place for Operator.dense. */
    private static void transposeDenseWeight(Weight w) {
        w.data = Operator.transpose2D(w.data, w.shape[0], w.shape[1]);
    }

    /** Returns true if none of the given weights is null. */
    private static boolean allPresent(Weight... weights) {
        for (Weight w : weights) {
            if (w == null) {
                return false;
            }
        }
        return true;
    }

    /** Reads all of {@code file} into memory, always closing the stream. */
    private static byte[] readFile(File file) throws IOException {
        long size = file.length();
        if (size > Integer.MAX_VALUE) {
            throw new IOException("File too large: " + file);
        }
        DataInputStream in = new DataInputStream(new FileInputStream(file));
        try {
            byte[] data = new byte[(int) size];
            in.readFully(data);
            return data;
        } finally {
            in.close();
        }
    }

    /** Closes {@code closeable} if non-null, ignoring failures (used on cleanup paths only). */
    private static void closeQuietly(@Nullable Closeable closeable) {
        if (closeable == null) {
            return;
        }
        try {
            closeable.close();
        } catch (IOException ignored) {
            // Nothing useful to do if a close fails during cleanup.
        }
    }

    /** One named tensor: its shape and its flattened float data (transposed in place on load). */
    private static final class Weight {
        public final int[] shape;
        public float[] data;

        Weight(int[] shape, float[] data) {
            this.shape = shape;
            this.data = data;
        }
    }

    /**
     * Downloads {@code uriStr} into {@code destFile} in the background. If the download succeeds,
     * {@code onSuccess} runs from {@link #onPostExecute}.
     *
     * <p>The body is written to a sibling {@code .tmp} file and renamed into place only once it is
     * complete. This means {@code destFile} never exists half-written; otherwise a failed download
     * would be mistaken for a cached file forever.
     */
    static class FileDownloadTask
    extends AsyncTask<String, Void, Boolean> {
        Runnable onSuccess;
        File destFile;
        String uriStr;

        FileDownloadTask(String uriStr, File destFile, Runnable onSuccess) {
            this.uriStr = uriStr;
            this.destFile = destFile;
            this.onSuccess = onSuccess;
        }

        protected Boolean doInBackground(String ... args) {
            File tempFile = null;
            InputStream in = null;
            FileOutputStream out = null;
            boolean success = false;
            try {
                Context context = FacebookSdk.getApplicationContext();
                File dir = new File(context.getFilesDir(), Model.DIR_NAME);
                if (!dir.exists()) {
                    dir.mkdirs();
                }
                // One connection serves both the length header and the body.
                URLConnection conn = new URL(this.uriStr).openConnection();
                int expectedLength = conn.getContentLength();
                in = conn.getInputStream();
                tempFile = new File(this.destFile.getPath() + ".tmp");
                out = new FileOutputStream(tempFile);
                byte[] buffer = new byte[8192];
                long written = 0;
                int read;
                while ((read = in.read(buffer)) != -1) {
                    out.write(buffer, 0, read);
                    written += read;
                }
                out.close();
                out = null;
                // A length of -1 means "unknown"; otherwise reject a truncated body.
                if (expectedLength >= 0 && written != expectedLength) {
                    return false;
                }
                success = tempFile.renameTo(this.destFile);
                return success;
            }
            catch (Exception exception) {
                return false;
            }
            finally {
                Model.closeQuietly(in);
                Model.closeQuietly(out);
                if (!success && tempFile != null) {
                    tempFile.delete();
                }
            }
        }

        protected void onPostExecute(Boolean isSuccess) {
            if (Boolean.TRUE.equals(isSuccess)) {
                this.onSuccess.run();
            }
        }
    }
}

