/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.graph.models.embeddings;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.graph.models.BinaryTree;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Lookup table that keeps all graph embedding weights in memory.
 *
 * <p>Two weight matrices are held: {@link #vertexVectors} with one row per vertex
 * ({@code nVertices x vectorSize}) and {@link #outWeights} with one row per inner node of the
 * supplied {@link BinaryTree} ({@code (nVertices - 1) x vectorSize}). The probability of a
 * vertex pair is computed as a product of sigmoids along the tree path of the second vertex.
 *
 * <p>No synchronization is performed anywhere in this class.
 */
public class InMemoryGraphLookupTable implements GraphVectorLookupTable {
    /** Number of entries in {@link #expTable}. */
    private static final int EXP_TABLE_SIZE = 1000;

    /** Number of vertices, i.e. number of rows of {@link #vertexVectors}. */
    protected int nVertices;
    /** Length of every vertex / inner-node vector. */
    protected int vectorSize;
    /** Tree supplying the code, code length and inner-node path for each vertex. */
    protected BinaryTree tree;
    /** Vectors of the vertices, one row per vertex. */
    protected INDArray vertexVectors;
    /** Vectors of the tree's inner nodes, one row per inner node. */
    protected INDArray outWeights;
    /** Step size applied by {@link #iterate(int, int)}. */
    protected double learningRate;
    /**
     * Sigmoid values sampled at evenly spaced points of {@code [-MAX_EXP, MAX_EXP)}.
     * The table is filled by the constructor; nothing in this class reads it.
     */
    protected double[] expTable;
    /** Half-width of the input range covered by {@link #expTable}. */
    protected static double MAX_EXP = 6.0;

    /**
     * Creates the table, initializes both weight matrices via {@link #resetWeights()} and then
     * fills {@link #expTable}.
     *
     * @param nVertices    number of vertices in the graph
     * @param vectorSize   length of each embedding vector
     * @param tree         binary tree describing the path to every vertex
     * @param learningRate initial learning rate
     */
    public InMemoryGraphLookupTable(int nVertices, int vectorSize, BinaryTree tree, double learningRate) {
        this.nVertices = nVertices;
        this.vectorSize = vectorSize;
        this.tree = tree;
        this.learningRate = learningRate;
        // Weights are initialized before the sigmoid table is assigned; keep this order.
        resetWeights();
        this.expTable = buildSigmoidTable(EXP_TABLE_SIZE);
    }

    /** Builds a table of {@code e^x / (e^x + 1)} for {@code size} points in [-MAX_EXP, MAX_EXP). */
    private static double[] buildSigmoidTable(int size) {
        double[] table = new double[size];
        for (int idx = 0; idx < table.length; idx++) {
            // Map the index from [0, size) onto [-MAX_EXP, MAX_EXP).
            double x = ((double) idx / table.length * 2.0 - 1.0) * MAX_EXP;
            double ex = FastMath.exp(x);
            table[idx] = ex / (ex + 1.0);
        }
        return table;
    }

    /** @return the matrix of vertex vectors (the internal array, not a copy) */
    public INDArray getVertexVectors() {
        return vertexVectors;
    }

    /** @return the matrix of inner-node vectors (the internal array, not a copy) */
    public INDArray getOutWeights() {
        return outWeights;
    }

    /** @return the length of each embedding vector */
    @Override
    public int vectorSize() {
        return vectorSize;
    }

    /**
     * Replaces both weight matrices with freshly generated random values: the output of
     * {@code Nd4j.rand} shifted by -0.5 and divided by {@code vectorSize}.
     */
    @Override
    public void resetWeights() {
        // Generation order (vertex vectors first) determines which random numbers go where.
        vertexVectors = randomMatrix(nVertices);
        outWeights = randomMatrix(nVertices - 1);
    }

    /** Creates a {@code rows x vectorSize} matrix of {@code (rand - 0.5) / vectorSize} values. */
    private INDArray randomMatrix(int rows) {
        INDArray values = Nd4j.rand(new int[] {rows, vectorSize});
        return values.subi(0.5).divi(vectorSize);
    }

    /**
     * Performs one update step for the vertex pair: every vector returned by
     * {@link #vectorsAndGradients(int, int)} receives {@code -learningRate * gradient} through a
     * BLAS axpy call.
     *
     * @param first  index of the input vertex
     * @param second index of the vertex whose tree path is used
     */
    @Override
    public void iterate(int first, int second) {
        INDArray[][] pairs = vectorsAndGradients(first, second);
        Level1 blas = Nd4j.getBlasWrapper().level1();
        INDArray[] targets = pairs[0];
        for (int k = 0; k < targets.length; k++) {
            INDArray target = targets[k];
            // In-place update; presumably reaches the weight matrices because getRow yields views - TODO confirm.
            blas.axpy(target.length(), -learningRate, pairs[1][k], target);
        }
    }

    /**
     * Collects the vectors involved in the pair {@code (first, second)} together with their
     * gradients.
     *
     * @param first  index of the input vertex
     * @param second index of the vertex whose tree path is used
     * @return a two-row array: row 0 holds the vectors, row 1 the matching gradients. Index 0 of
     *         each row refers to the vector of {@code first}; index {@code i + 1} refers to the
     *         i-th inner node on the path of {@code second}
     */
    public INDArray[][] vectorsAndGradients(int first, int second) {
        INDArray inputVector = vertexVectors.getRow(first);
        int pathLength = tree.getCodeLength(second);
        long pathCode = tree.getCode(second);
        int[] pathNodes = tree.getPathInnerNodes(second);

        INDArray[] vectors = new INDArray[pathNodes.length + 1];
        INDArray[] gradients = new INDArray[pathNodes.length + 1];
        Level1 blas = Nd4j.getBlasWrapper().level1();
        INDArray inputGradient = Nd4j.create(inputVector.shape());

        for (int depth = 0; depth < pathLength; depth++) {
            int nodeIdx = pathNodes[depth];
            boolean bitSet = getBit(pathCode, depth);
            INDArray nodeVector = outWeights.getRow(nodeIdx);
            double activation = sigmoid(Nd4j.getBlasWrapper().dot(nodeVector, inputVector));

            // Error coefficient: (sigmoid - 1) when the code bit is set, sigmoid otherwise.
            double errorTerm = bitSet ? activation - 1.0 : activation;

            // The inner node's gradient scales the input vector; the input vertex's gradient
            // accumulates the inner-node vectors scaled by the same coefficient.
            INDArray nodeGradient = inputVector.mul(errorTerm);
            blas.axpy(inputVector.length(), errorTerm, nodeVector, inputGradient);

            vectors[depth + 1] = nodeVector;
            gradients[depth + 1] = nodeGradient;
        }

        vectors[0] = inputVector;
        gradients[0] = inputGradient;
        return new INDArray[][] {vectors, gradients};
    }

    /**
     * Computes the probability assigned to {@code second} given {@code first}: the product, over
     * the inner nodes on the path of {@code second}, of {@code sigmoid(dot)} when the matching
     * code bit is set and {@code sigmoid(-dot)} otherwise.
     *
     * @param first  index of the input vertex
     * @param second index of the vertex whose tree path is used
     * @return the product of the per-node probabilities (1.0 for an empty path)
     */
    public double calculateProb(int first, int second) {
        INDArray inputVector = vertexVectors.getRow(first);
        int pathLength = tree.getCodeLength(second);
        long pathCode = tree.getCode(second);
        int[] pathNodes = tree.getPathInnerNodes(second);

        double product = 1.0;
        for (int depth = 0; depth < pathLength; depth++) {
            boolean bitSet = getBit(pathCode, depth);
            INDArray nodeVector = outWeights.getRow(pathNodes[depth]);
            double dot = Nd4j.getBlasWrapper().dot(nodeVector, inputVector);
            product *= sigmoid(bitSet ? dot : -dot);
        }
        return product;
    }

    /**
     * @param first  index of the input vertex
     * @param second index of the vertex whose tree path is used
     * @return the negative natural logarithm of {@link #calculateProb(int, int)}
     */
    public double calculateScore(int first, int second) {
        return -FastMath.log(calculateProb(first, second));
    }

    /** @return the binary tree this table was constructed with */
    public BinaryTree getTree() {
        return tree;
    }

    /**
     * @param innerNode row index of the inner node in the inner-node weight matrix
     * @return the row of {@link #outWeights} for that inner node
     */
    public INDArray getInnerNodeVector(int innerNode) {
        return outWeights.getRow(innerNode);
    }

    /**
     * @param idx index of the vertex
     * @return the row of {@link #vertexVectors} for that vertex
     */
    @Override
    public INDArray getVector(int idx) {
        return vertexVectors.getRow(idx);
    }

    /** @param learningRate step size used by subsequent {@link #iterate(int, int)} calls */
    @Override
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** @return the number of vertices this table was created for */
    @Override
    public int getNumVertices() {
        return nVertices;
    }

    /** Logistic function {@code 1 / (1 + e^(-in))}. */
    private static double sigmoid(double in) {
        double expNeg = FastMath.exp(-in);
        return 1.0 / (1.0 + expNeg);
    }

    /** Tests whether bit {@code bitNum} (0 = least significant) of {@code in} is set. */
    private static boolean getBit(long in, int bitNum) {
        return ((in >>> bitNum) & 1L) != 0L;
    }

    /**
     * Replaces the vertex vector matrix; the array is stored as given, without copying or
     * shape checks.
     *
     * @param vertexVectors new matrix of vertex vectors
     */
    public void setVertexVectors(INDArray vertexVectors) {
        this.vertexVectors = vertexVectors;
    }
}

