/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.rntn;

import akka.actor.ActorSystem;
import akka.dispatch.Futures;
import akka.dispatch.OnComplete;
import com.google.common.util.concurrent.AtomicDouble;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.Tree;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.util.MultiDimensionalMap;
import org.deeplearning4j.util.MultiDimensionalSet;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Function1;
import scala.concurrent.ExecutionContext;
import scala.concurrent.Future;

public class RNTN
implements Layer {
    protected NeuralNetConfiguration conf;
    /** Listeners returned by {@link #getIterationListeners()}. */
    protected Collection<IterationListener> iterationListeners = new ArrayList<IterationListener>();
    protected double value = 0.0;
    /** Number of output classes (row count of every classification matrix). */
    private int numOuts = 3;
    /** Dimensionality of node and word vectors. */
    private int numHidden = 25;
    /** Random source for weight initialization; falls back to {@code Nd4j.getRandom()} when null. */
    private Random rng;
    /** When true, binary nodes also use a bilinear tensor in addition to the transform matrix. */
    private boolean useDoubleTensors = true;
    /** When true, binary nodes reuse the unary classification matrix stored under key "". */
    private boolean combineClassification = true;
    /** Only the simplified model (every category collapses to "") is supported. */
    private boolean simplifiedModel = true;
    private boolean randomFeatureVectors = true;
    /** Multiplier applied to freshly initialized random weights. */
    private double scalingForInit = 1.0;
    /** When true, words are lower-cased before vocabulary lookup. */
    private boolean lowerCasefeatureNames;
    /** Name of the op-factory transform used as the hidden-layer activation. */
    protected String activationFunction = "tanh";
    /** Name of the op-factory transform applied to classification outputs. */
    protected String outputActivation = "softmax";
    protected AdaGrad paramAdaGrad;
    /** Cached parameter count; negative until computed by getNumParameters(). */
    protected int numParameters = -1;
    // Regularization strengths. Written as double literals: the former float literals
    // (e.g. 0.001f) widened to slightly different doubles such as 0.0010000000474974513.
    private double regTransformMatrix = 0.001;
    private double regClassification = 1.0E-4;
    private double regWordVector = 1.0E-4;
    private int adagradResetFrequency = 1;
    private double regTransformINDArray = 0.001;
    /** Binary transform matrices keyed by (left, right) basic category. */
    private MultiDimensionalMap<String, String, INDArray> binaryTransform;
    /** Binary bilinear tensors keyed by (left, right) basic category (only when useDoubleTensors). */
    private MultiDimensionalMap<String, String, INDArray> binaryTensors;
    /** Unary classification matrices keyed by basic category. */
    private Map<String, INDArray> unaryClassification;
    /** Word embedding lookup table. */
    private WeightLookupTable featureVectors;
    /** Vocabulary used to map unknown words to "UNK". */
    private VocabCache vocabCache;
    /** Binary classification matrices keyed by (left, right) basic category (only when !combineClassification). */
    private MultiDimensionalMap<String, String, INDArray> binaryClassification;
    private int numBinaryMatrices;
    private int binaryTransformSize;
    /** Number of parameters in one binary tensor (0 when tensors are disabled). */
    private int binaryINd4jize;
    private int binaryClassificationSize;
    private int numUnaryMatrices;
    private int unaryClassificationSize;
    /** numHidden x numHidden identity added to each random transform block. */
    private INDArray identity;
    /** Per-gold-class multipliers for node error; classes without an entry weigh 1.0. */
    private Map<Integer, Double> classWeights;
    private static final Logger log = LoggerFactory.getLogger(RNTN.class);
    // NOTE(review): one ActorSystem is created per instance and is never shut down in the
    // visible code; confirm it is terminated elsewhere, otherwise its threads outlive the model.
    private transient ActorSystem rnTnActorSystem = ActorSystem.create((String)"RNTN");

    /**
     * Creates a network from explicit settings, then builds its randomly initialized
     * parameter maps via {@code init()}. The parameter-block sizes passed in here are
     * recomputed by {@code init()}.
     */
    private RNTN(int numHidden, Random rng, boolean useDoubleTensors, boolean combineClassification, boolean simplifiedModel, boolean randomFeatureVectors, double scalingForInit, boolean lowerCasefeatureNames, String activationFunction, int adagradResetFrequency, double regTransformINDArray, WeightLookupTable featureVectors, VocabCache vocabCache, int numBinaryMatrices, int binaryTransformSize, int binaryINd4jize, int binaryClassificationSize, int numUnaryMatrices, int unaryClassificationSize, Map<Integer, Double> classWeights) {
        // Topology and behavioural switches.
        this.numHidden = numHidden;
        this.useDoubleTensors = useDoubleTensors;
        this.combineClassification = combineClassification;
        this.simplifiedModel = simplifiedModel;
        this.randomFeatureVectors = randomFeatureVectors;
        this.lowerCasefeatureNames = lowerCasefeatureNames;
        this.activationFunction = activationFunction;
        // Initialization and training hyper-parameters.
        this.rng = rng;
        this.scalingForInit = scalingForInit;
        this.adagradResetFrequency = adagradResetFrequency;
        this.regTransformINDArray = regTransformINDArray;
        this.classWeights = classWeights;
        // Word embeddings and vocabulary.
        this.featureVectors = featureVectors;
        this.vocabCache = vocabCache;
        // Parameter block sizes.
        this.numBinaryMatrices = numBinaryMatrices;
        this.binaryTransformSize = binaryTransformSize;
        this.binaryINd4jize = binaryINd4jize;
        this.binaryClassificationSize = binaryClassificationSize;
        this.numUnaryMatrices = numUnaryMatrices;
        this.unaryClassificationSize = unaryClassificationSize;
        init();
    }

    /**
     * Builds the randomly initialized parameter maps and derives the parameter-block sizes.
     *
     * <p>Only the simplified model is supported: every production collapses to the single
     * category {@code ""}, so one binary transform (plus a tensor and/or binary classifier when
     * enabled) and one unary classifier are created.
     *
     * <p>Class weights supplied through the constructor are preserved; an empty map
     * (every class weighted 1.0) is installed only when none were given.
     *
     * @throws UnsupportedOperationException if {@code simplifiedModel} is false
     */
    private void init() {
        if (this.rng == null) {
            this.rng = Nd4j.getRandom();
        }
        if (!this.simplifiedModel) {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        MultiDimensionalSet binaryProductions = MultiDimensionalSet.hashSet();
        binaryProductions.add((Object)"", (Object)"");
        HashSet<String> unaryProductions = new HashSet<String>();
        unaryProductions.add("");
        this.identity = Nd4j.eye((int)this.numHidden);
        this.binaryTransform = MultiDimensionalMap.newTreeBackedMap();
        this.binaryTensors = MultiDimensionalMap.newTreeBackedMap();
        this.binaryClassification = MultiDimensionalMap.newTreeBackedMap();
        for (Pair binary : binaryProductions) {
            String left = this.basicCategory((String)binary.getFirst());
            String right = this.basicCategory((String)binary.getSecond());
            if (this.binaryTransform.contains((Object)left, (Object)right)) {
                continue;
            }
            this.binaryTransform.put((Object)left, (Object)right, (Object)this.randomTransformMatrix());
            if (this.useDoubleTensors) {
                this.binaryTensors.put((Object)left, (Object)right, (Object)this.randomBinaryINDArray());
            }
            if (!this.combineClassification) {
                this.binaryClassification.put((Object)left, (Object)right, (Object)this.randomClassificationMatrix());
            }
        }
        this.numBinaryMatrices = this.binaryTransform.size();
        // W is numHidden x (2 * numHidden + 1): left block, right block and a bias column.
        this.binaryTransformSize = this.numHidden * (2 * this.numHidden + 1);
        // Tensor is numHidden x 2numHidden x 2numHidden = 4 * numHidden^3 entries.
        this.binaryINd4jize = this.useDoubleTensors ? this.numHidden * this.numHidden * this.numHidden * 4 : 0;
        this.binaryClassificationSize = this.combineClassification ? 0 : this.numOuts * (this.numHidden + 1);
        this.unaryClassification = new TreeMap<String, INDArray>();
        for (String production : unaryProductions) {
            String unary = this.basicCategory(production);
            if (!this.unaryClassification.containsKey(unary)) {
                this.unaryClassification.put(unary, this.randomClassificationMatrix());
            }
        }
        this.numUnaryMatrices = this.unaryClassification.size();
        this.unaryClassificationSize = this.numOuts * (this.numHidden + 1);
        // Previously this unconditionally discarded weights passed to the constructor.
        if (this.classWeights == null) {
            this.classWeights = new HashMap<Integer, Double>();
        }
    }

    /**
     * Returns the live (not copied) collection of iteration listeners registered on this model.
     *
     * @return the registered listeners
     */
    public Collection<IterationListener> getIterationListeners() {
        Collection<IterationListener> registered = this.iterationListeners;
        return registered;
    }

    /**
     * Replaces the registered iteration listeners.
     *
     * @param listeners new listeners; {@code null} installs an empty list so the field is never null
     */
    public void setIterationListeners(Collection<IterationListener> listeners) {
        if (listeners == null) {
            this.iterationListeners = new ArrayList<IterationListener>();
        } else {
            this.iterationListeners = listeners;
        }
    }

    /**
     * Creates a random bilinear tensor of shape {@code [numHidden, 2*numHidden, 2*numHidden]}
     * with entries uniform in {@code [-1/(4*numHidden), 1/(4*numHidden)]}, scaled by
     * {@code scalingForInit}.
     */
    INDArray randomBinaryINDArray() {
        // The half-width is computed in float arithmetic, then widened.
        float halfWidth = 1.0f / (4.0f * (float)this.numHidden);
        int[] shape = new int[]{this.numHidden, 2 * this.numHidden, 2 * this.numHidden};
        INDArray tensor = Nd4j.rand((int[])shape, (double)(-halfWidth), (double)halfWidth, (Random)this.rng);
        return tensor.muli((Number)this.scalingForInit);
    }

    /**
     * Creates a random binary transform matrix of shape {@code numHidden x (2*numHidden + 1)}.
     * The left and right {@code numHidden x numHidden} blocks are independent random blocks
     * (noise plus identity); the trailing bias column stays as created by {@code Nd4j.create}.
     * The whole matrix is then scaled by {@code scalingForInit}.
     */
    public INDArray randomTransformMatrix() {
        INDArray transform = Nd4j.create((int)this.numHidden, (int)(this.numHidden * 2 + 1));
        INDArray leftBlock = this.randomTransformBlock();
        int blockRows = leftBlock.rows();
        int blockCols = leftBlock.columns();
        transform.put(new NDArrayIndex[]{NDArrayIndex.interval(0, blockRows), NDArrayIndex.interval(0, blockCols)}, leftBlock);
        INDArray rightBlock = this.randomTransformBlock();
        transform.put(new NDArrayIndex[]{NDArrayIndex.interval(0, blockRows), NDArrayIndex.interval(this.numHidden, this.numHidden + blockCols)}, rightBlock);
        boolean isDouble = transform.data().dataType() == DataBuffer.Type.DOUBLE;
        return isDouble
                ? Nd4j.getBlasWrapper().scal(this.scalingForInit, transform)
                : Nd4j.getBlasWrapper().scal((float)this.scalingForInit, transform);
    }

    /**
     * Creates a {@code numHidden x numHidden} block: uniform noise in
     * {@code [-1/(2*sqrt(numHidden)), 1/(2*sqrt(numHidden))]} plus the identity matrix.
     */
    public INDArray randomTransformBlock() {
        double halfWidth = 1.0 / (2.0 * Math.sqrt(this.numHidden));
        INDArray noise = Nd4j.rand((int)this.numHidden, (int)this.numHidden, -halfWidth, halfWidth, (Random)this.rng);
        return noise.addi(this.identity);
    }

    /**
     * Creates a {@code numOuts x (numHidden + 1)} classification matrix. The weight columns are
     * uniform in {@code [-1/sqrt(numHidden), 1/sqrt(numHidden)]}, the bias column is zero, and
     * the result is scaled by {@code scalingForInit}.
     */
    INDArray randomClassificationMatrix() {
        double halfWidth = 1.0 / Math.sqrt(this.numHidden);
        INDArray matrix = Nd4j.zeros((int)this.numOuts, (int)(this.numHidden + 1));
        INDArray weights = Nd4j.rand((int)this.numOuts, (int)this.numHidden, -halfWidth, halfWidth, (Random)this.rng);
        NDArrayIndex[] weightRegion = new NDArrayIndex[]{NDArrayIndex.interval(0, this.numOuts), NDArrayIndex.interval(0, this.numHidden)};
        matrix.put(weightRegion, weights);
        if (matrix.data().dataType() != DataBuffer.Type.DOUBLE) {
            return Nd4j.getBlasWrapper().scal((float)this.scalingForInit, matrix);
        }
        return Nd4j.getBlasWrapper().scal(this.scalingForInit, matrix);
    }

    /**
     * Trains on every tree of {@code trainingBatch} via {@link #fitAsync(List)} and blocks until
     * all scheduled steps have completed. Failed steps are logged, not rethrown. If the waiting
     * thread is interrupted, the interrupt flag is restored and the method returns early.
     *
     * @param trainingBatch trees to train on
     */
    public void fit(List<Tree> trainingBatch) {
        final CountDownLatch outstanding = new CountDownLatch(trainingBatch.size());
        List<Future<Object>> steps = this.fitAsync(trainingBatch);
        ExecutionContext dispatcher = (ExecutionContext)this.rnTnActorSystem.dispatcher();
        for (Future<Object> step : steps) {
            OnComplete<Object> onDone = new OnComplete<Object>(){

                public void onComplete(Throwable failure, Object ignored) throws Throwable {
                    if (failure != null) {
                        log.warn("Error occurred training batch", failure);
                    }
                    outstanding.countDown();
                }
            };
            step.onComplete((Function1)onDone, dispatcher);
        }
        try {
            outstanding.await();
        }
        catch (InterruptedException interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Schedules one asynchronous training step per tree on the actor system's dispatcher.
     *
     * <p>Each step forward-propagates its tree, then applies a plain gradient step
     * {@code params - gradient}, where the gradient is requested for the whole
     * {@code trainingBatch}. A {@link NegativeArraySizeException} during the update is logged
     * (with its stack trace) and the step completes normally.
     *
     * <p>NOTE(review): steps presumably run concurrently on the dispatcher, and each performs an
     * unsynchronized read-modify-write of the parameters (getParameters / setParams). Concurrent
     * updates may therefore be lost; confirm whether this is intended.
     *
     * @param trainingBatch trees to train on
     * @return one future per tree; each completes with {@code null}
     */
    public List<Future<Object>> fitAsync(final List<Tree> trainingBatch) {
        ArrayList<Future<Object>> futureBatch = new ArrayList<Future<Object>>(trainingBatch.size());
        int count = 0;
        for (final Tree t : trainingBatch) {
            log.info("Working mini batch {}", count++);
            futureBatch.add((Future<Object>)Futures.future((Callable)new Callable<Object>(){

                @Override
                public Object call() throws Exception {
                    RNTN.this.forwardPropagateTree(t);
                    try {
                        INDArray params = RNTN.this.getParameters();
                        INDArray gradient = RNTN.this.getValueGradient(trainingBatch);
                        if (params.length() != gradient.length()) {
                            throw new IllegalStateException("Params not equal to gradient!");
                        }
                        RNTN.this.setParams(params.subi(gradient));
                    }
                    catch (NegativeArraySizeException e) {
                        // Keep the cause so the failing shape computation can be diagnosed.
                        log.warn("Couldnt compute parameters due to negative array size...for trees {}", t, e);
                    }
                    return null;
                }
            }, (ExecutionContext)this.rnTnActorSystem.dispatcher()));
        }
        return futureBatch;
    }

    /**
     * Returns the binary transform matrix for a node with two children, looked up by the
     * children's basic categories.
     *
     * @throws AssertionError if the node has one child (there are no unary transforms) or any
     *                        other number of children than two
     */
    public INDArray getWForNode(Tree node) {
        int childCount = node.children().size();
        switch (childCount) {
            case 2: {
                String leftKey = this.basicCategory(((Tree)node.children().get(0)).value());
                String rightKey = this.basicCategory(((Tree)node.children().get(1)).value());
                return (INDArray)this.binaryTransform.get((Object)leftKey, (Object)rightKey);
            }
            case 1:
                throw new AssertionError((Object)"No unary applyTransformToOrigin matrices, only unary classification");
            default:
                throw new AssertionError((Object)("Unexpected tree children size of " + childCount));
        }
    }

    /**
     * Returns the binary bilinear tensor for a node with two children, looked up by the
     * children's basic categories.
     *
     * @throws AssertionError if tensors are disabled, or the node does not have exactly two children
     */
    public INDArray getINDArrayForNode(Tree node) {
        if (!this.useDoubleTensors) {
            throw new AssertionError((Object)"Not using tensors");
        }
        int childCount = node.children().size();
        switch (childCount) {
            case 2: {
                String leftKey = this.basicCategory(((Tree)node.children().get(0)).value());
                String rightKey = this.basicCategory(((Tree)node.children().get(1)).value());
                return (INDArray)this.binaryTensors.get((Object)leftKey, (Object)rightKey);
            }
            case 1:
                throw new AssertionError((Object)"No unary transform matrices, only unary classification");
            default:
                throw new AssertionError((Object)("Unexpected tree children size of " + childCount));
        }
    }

    /**
     * Returns the classification matrix for {@code node}. With combined classification this is
     * always the unary matrix under key {@code ""}. Otherwise it is the binary matrix of the
     * children's categories (two children) or the unary matrix of the child's category (one child).
     *
     * @throws AssertionError for any other number of children
     */
    public INDArray getClassWForNode(Tree node) {
        if (this.combineClassification) {
            return this.unaryClassification.get("");
        }
        int childCount = node.children().size();
        switch (childCount) {
            case 2: {
                String leftKey = this.basicCategory(((Tree)node.children().get(0)).value());
                String rightKey = this.basicCategory(((Tree)node.children().get(1)).value());
                return (INDArray)this.binaryClassification.get((Object)leftKey, (Object)rightKey);
            }
            case 1: {
                String childKey = this.basicCategory(((Tree)node.children().get(0)).value());
                return this.unaryClassification.get(childKey);
            }
            default:
                throw new AssertionError((Object)("Unexpected tree children size of " + childCount));
        }
    }

    /**
     * Computes the gradient of the bilinear tensor for one binary node. Slice {@code i} is
     * {@code deltaFull[i] * v * v^T}, where {@code v} is the concatenation of the left and
     * right child vectors.
     *
     * @param deltaFull   error at the node (length {@code size})
     * @param leftVector  left child vector
     * @param rightVector right child vector
     * @return tensor of shape {@code [size, 2*size, 2*size]}
     */
    private INDArray getINDArrayGradient(INDArray deltaFull, INDArray leftVector, INDArray rightVector) {
        int size = deltaFull.length();
        INDArray Wt_df = Nd4j.create((int[])new int[]{size, size * 2, size * 2});
        INDArray fullVector = Nd4j.concat((int)0, (INDArray[])new INDArray[]{leftVector, rightVector});
        INDArray fullVectorT = fullVector.transpose();
        for (int slice = 0; slice < size; ++slice) {
            // mul() allocates a new array. BLAS scal scales its argument in place, which made the
            // scaling of fullVector accumulate across slices.
            Wt_df.putSlice(slice, fullVector.mul(deltaFull.getDouble(slice)).mmul(fullVectorT));
        }
        return Wt_df;
    }

    /**
     * Looks up the embedding for {@code word} (after vocabulary normalization via
     * {@link #getVocabWord(String)}) and returns it as a column vector.
     *
     * @param word raw word
     * @return the embedding, transposed to a column if stored as a row. A {@code null} lookup
     *         result is passed through, so callers such as forwardPropagateTree can apply their
     *         "UNK" fallback instead of failing here with a NullPointerException.
     */
    public INDArray getFeatureVector(String word) {
        INDArray ret = this.featureVectors.vector(this.getVocabWord(word));
        if (ret != null && ret.isRowVector()) {
            ret = ret.transpose();
        }
        return ret;
    }

    /**
     * Maps a raw word to its vocabulary key. The word is lower-cased first when
     * {@code lowerCasefeatureNames} is set; words missing from the vocabulary map to {@code "UNK"}.
     */
    public String getVocabWord(String word) {
        String key = this.lowerCasefeatureNames ? word.toLowerCase() : word;
        return this.vocabCache.containsWord(key) ? key : "UNK";
    }

    /**
     * Reduces a category label to its basic form. In the simplified model every category
     * collapses to {@code ""}.
     *
     * @throws IllegalStateException when the simplified model is disabled
     */
    public String basicCategory(String category) {
        if (!this.simplifiedModel) {
            throw new IllegalStateException("Only simplified model enabled");
        }
        return "";
    }

    /** Returns the unary classification matrix for the basic form of {@code category}. */
    public INDArray getUnaryClassification(String category) {
        String key = this.basicCategory(category);
        return this.unaryClassification.get(key);
    }

    /**
     * Returns the classification matrix for a binary node. With combined classification this is
     * the shared unary matrix under key {@code ""}; otherwise it is the binary matrix of the
     * children's basic categories.
     */
    public INDArray getBinaryClassification(String left, String right) {
        if (!this.combineClassification) {
            String leftKey = this.basicCategory(left);
            String rightKey = this.basicCategory(right);
            return (INDArray)this.binaryClassification.get((Object)leftKey, (Object)rightKey);
        }
        return this.unaryClassification.get("");
    }

    /** Returns the binary transform matrix for the children's basic categories. */
    public INDArray getBinaryTransform(String left, String right) {
        return (INDArray)this.binaryTransform.get((Object)this.basicCategory(left), (Object)this.basicCategory(right));
    }

    /** Returns the binary bilinear tensor for the children's basic categories. */
    public INDArray getBinaryINDArray(String left, String right) {
        return (INDArray)this.binaryTensors.get((Object)this.basicCategory(left), (Object)this.basicCategory(right));
    }

    /**
     * Scales each derivative by {@code scale} and adds the L2 term {@code regCost * W}. Returns
     * the regularization cost {@code sum(W^2) * regCost / 2} over all matrices.
     *
     * <p>The current matrices are read only. The previous in-place BLAS {@code scal} overwrote
     * each weight matrix with {@code regCost * W}.
     *
     * @param derivatives     accumulated derivatives; each entry is replaced with the scaled result
     * @param currentMatrices current parameter matrices (not modified)
     * @param scale           multiplier for the raw derivatives
     * @param regCost         L2 regularization strength
     * @return the regularization cost
     */
    double scaleAndRegularize(MultiDimensionalMap<String, String, INDArray> derivatives, MultiDimensionalMap<String, String, INDArray> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (MultiDimensionalMap.Entry entry : currentMatrices.entrySet()) {
            INDArray current = (INDArray)entry.getValue();
            INDArray D = (INDArray)derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            D = D.mul(scale).addi(current.mul(regCost));
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), (Object)D);
            cost += current.mul(current).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Scales each derivative by {@code scale} and adds the L2 term {@code regCost * W}. Returns
     * the regularization cost {@code sum(W^2) * regCost / 2} over all matrices.
     *
     * <p>The current matrices are read only. The previous in-place BLAS {@code scal} overwrote
     * each weight matrix with {@code regCost * W}.
     *
     * @param derivatives     accumulated derivatives keyed like {@code currentMatrices}; entries are replaced
     * @param currentMatrices current parameter matrices (not modified)
     * @param scale           multiplier for the raw derivatives
     * @param regCost         L2 regularization strength
     * @return the regularization cost
     */
    double scaleAndRegularize(Map<String, INDArray> derivatives, Map<String, INDArray> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (String s : currentMatrices.keySet()) {
            INDArray D = derivatives.get(s);
            INDArray vector = currentMatrices.get(s);
            D = D.mul(scale).addi(vector.mul(regCost));
            derivatives.put(s, D);
            cost += vector.mul(vector).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Scales each word-vector derivative by {@code scale} and adds the L2 term
     * {@code regCost * v}. Iterates over every word in the vocabulary and returns the
     * regularization cost {@code sum(v^2) * regCost / 2}.
     *
     * <p>The lookup-table vectors are read only. The previous in-place BLAS {@code scal}
     * overwrote each embedding with {@code regCost * v}.
     *
     * NOTE(review): a vocabulary word without an entry in {@code derivatives} still causes a
     * NullPointerException; confirm callers pre-populate every word.
     *
     * @param derivatives     word-vector derivatives keyed by vocabulary word; entries are replaced
     * @param currentMatrices embedding table (not modified)
     * @param scale           multiplier for the raw derivatives
     * @param regCost         L2 regularization strength
     * @return the regularization cost
     */
    double scaleAndRegularize(Map<String, INDArray> derivatives, WeightLookupTable currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (String s : this.vocabCache.words()) {
            INDArray D = derivatives.get(s);
            INDArray vector = currentMatrices.vector(s);
            D = D.mul(scale).addi(vector.mul(regCost));
            derivatives.put(s, D);
            cost += vector.mul(vector).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Tensor variant of scaleAndRegularize. It scales each derivative by {@code scale}, adds
     * {@code regCost * W}, and returns the regularization cost {@code sum(W^2) * regCost / 2}.
     *
     * <p>The current tensors are read only. The previous {@code muli(regCost)} multiplied the
     * live weights in place, and the cost was then computed on those corrupted weights.
     *
     * @param derivatives     accumulated tensor derivatives; entries are replaced
     * @param currentMatrices current tensors (not modified)
     * @param scale           multiplier for the raw derivatives
     * @param regCost         L2 regularization strength
     * @return the regularization cost
     */
    double scaleAndRegularizeINDArray(MultiDimensionalMap<String, String, INDArray> derivatives, MultiDimensionalMap<String, String, INDArray> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (MultiDimensionalMap.Entry entry : currentMatrices.entrySet()) {
            INDArray current = (INDArray)entry.getValue();
            INDArray D = (INDArray)derivatives.get(entry.getFirstKey(), entry.getSecondKey());
            D = D.mul(scale).addi(current.mul(regCost));
            derivatives.put(entry.getFirstKey(), entry.getSecondKey(), (Object)D);
            cost += current.mul(current).sum(Integer.MAX_VALUE).getDouble(0) * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Entry point for back-propagation from the root of {@code tree}. The root receives no
     * error from above, so it starts with a fresh {@code numHidden x 1} delta created by
     * {@code Nd4j.create} (expected to be zero-filled).
     */
    private void backpropDerivativesAndError(Tree tree, MultiDimensionalMap<String, String, INDArray> binaryTD, MultiDimensionalMap<String, String, INDArray> binaryCD, MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD, Map<String, INDArray> unaryCD, Map<String, INDArray> wordVectorD) {
        this.backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, Nd4j.create((int)this.numHidden, (int)1));
    }

    /**
     * Recursively back-propagates classification error through {@code tree}. Parameter
     * derivatives are accumulated into the supplied maps, and each node's (class-weighted)
     * cross-entropy error is recorded via {@code tree.setError}.
     *
     * @param tree             subtree to process; leaves are skipped
     * @param binaryTD         accumulated derivatives of the binary transform matrices
     * @param binaryCD         accumulated derivatives of the binary classification matrices
     *                         (used only when {@code combineClassification} is false)
     * @param binaryINDArrayTD accumulated derivatives of the binary tensors
     * @param unaryCD          accumulated derivatives of the unary classification matrices
     * @param wordVectorD      accumulated word-vector derivatives keyed by vocabulary word
     * @param deltaUp          error signal arriving from the parent node
     */
    private void backpropDerivativesAndError(Tree tree, MultiDimensionalMap<String, String, INDArray> binaryTD, MultiDimensionalMap<String, String, INDArray> binaryCD, MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD, Map<String, INDArray> unaryCD, Map<String, INDArray> wordVectorD, INDArray deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        INDArray currentVector = tree.vector();
        String category = this.basicCategory(tree.label());
        // One-hot gold label; stays all-zero for unlabeled nodes (goldClass < 0).
        INDArray goldLabel = Nd4j.create((int)this.numOuts, (int)1);
        int goldClass = tree.goldLabel();
        if (goldClass >= 0) {
            // Valid labels are 0..numOuts-1 (the previous check allowed goldClass == numOuts).
            assert (goldClass < this.numOuts) : "Tried adding a label that was >= to the number of configured outputs " + this.numOuts + " with label " + goldClass;
            goldLabel.putScalar(goldClass, 1.0f);
        }
        Double nodeWeight = this.classWeights.get(goldClass);
        if (nodeWeight == null) {
            nodeWeight = 1.0;
        }
        INDArray predictions = tree.prediction();
        INDArray deltaClass;
        if (goldClass < 0) {
            deltaClass = Nd4j.create((int)predictions.rows(), (int)predictions.columns());
        } else if (predictions.data().dataType() == DataBuffer.Type.DOUBLE) {
            deltaClass = Nd4j.getBlasWrapper().scal(nodeWeight.doubleValue(), predictions.sub(goldLabel));
        } else {
            deltaClass = Nd4j.getBlasWrapper().scal((float)nodeWeight.doubleValue(), predictions.sub(goldLabel));
        }
        INDArray localCD = deltaClass.mmul(Nd4j.appendBias((INDArray[])new INDArray[]{currentVector}).transpose());
        double error = -Transforms.log((INDArray)predictions).muli(goldLabel).sum(Integer.MAX_VALUE).getDouble(0);
        error *= nodeWeight.doubleValue();
        tree.setError(error);
        // NOTE(review): the "derivative" terms below apply the activation function itself rather
        // than its derivative (for tanh, with v already activated, f'(z) = 1 - v^2); confirm intent.
        // The transforms run on dup() copies so an op that writes its result into its input
        // cannot overwrite the node vectors stored on the tree.
        if (tree.isPreTerminal()) {
            unaryCD.put(category, unaryCD.get(category).add(localCD));
            String word = this.getVocabWord(((Tree)tree.children().get(0)).label());
            INDArray currentVectorDerivative = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, currentVector.dup()));
            INDArray deltaFromClass = this.getUnaryClassification(category).transpose().mmul(deltaClass);
            deltaFromClass = deltaFromClass.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)this.numHidden), NDArrayIndex.interval((int)0, (int)1)}).mul(currentVectorDerivative);
            INDArray deltaFull = deltaFromClass.add(deltaUp);
            // The first occurrence of a word has no accumulated derivative yet.
            INDArray previous = wordVectorD.get(word);
            wordVectorD.put(word, previous == null ? deltaFull : previous.add(deltaFull));
        } else {
            INDArray deltaDown;
            String leftCategory = this.basicCategory(((Tree)tree.children().get(0)).label());
            String rightCategory = this.basicCategory(((Tree)tree.children().get(1)).label());
            if (this.combineClassification) {
                unaryCD.put("", unaryCD.get("").add(localCD));
            } else {
                binaryCD.put((Object)leftCategory, (Object)rightCategory, (Object)((INDArray)binaryCD.get((Object)leftCategory, (Object)rightCategory)).add(localCD));
            }
            INDArray currentVectorDerivative = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, currentVector.dup()));
            INDArray deltaFromClass = this.getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);
            INDArray mult = deltaFromClass.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)this.numHidden), NDArrayIndex.interval((int)0, (int)1)});
            deltaFromClass = mult.muli(currentVectorDerivative);
            INDArray deltaFull = deltaFromClass.add(deltaUp);
            INDArray leftVector = ((Tree)tree.children().get(0)).vector();
            INDArray rightVector = ((Tree)tree.children().get(1)).vector();
            INDArray childrenVector = Nd4j.appendBias((INDArray[])new INDArray[]{leftVector, rightVector});
            INDArray add = (INDArray)binaryTD.get((Object)leftCategory, (Object)rightCategory);
            // The transform gradient uses the full delta (own class error plus error from the
            // parent); using only deltaFromClass dropped deltaUp.
            INDArray W_df = deltaFull.mmul(childrenVector.transpose());
            binaryTD.put((Object)leftCategory, (Object)rightCategory, (Object)add.add(W_df));
            if (this.useDoubleTensors) {
                INDArray Wt_df = this.getINDArrayGradient(deltaFull, leftVector, rightVector);
                binaryINDArrayTD.put((Object)leftCategory, (Object)rightCategory, (Object)((INDArray)binaryINDArrayTD.get((Object)leftCategory, (Object)rightCategory)).add(Wt_df));
                deltaDown = this.computeINDArrayDeltaDown(deltaFull, leftVector, rightVector, this.getBinaryTransform(leftCategory, rightCategory), this.getBinaryINDArray(leftCategory, rightCategory));
            } else {
                deltaDown = this.getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
            }
            INDArray leftDerivative = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, leftVector.dup()));
            INDArray rightDerivative = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, rightVector.dup()));
            // Split deltaDown into the left-child half and the right-child half.
            INDArray leftDeltaDown = deltaDown.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)deltaFull.rows()), NDArrayIndex.interval((int)0, (int)1)});
            INDArray rightDeltaDown = deltaDown.get(new NDArrayIndex[]{NDArrayIndex.interval((int)deltaFull.rows(), (int)(deltaFull.rows() * 2)), NDArrayIndex.interval((int)0, (int)1)});
            this.backpropDerivativesAndError((Tree)tree.children().get(0), binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown));
            this.backpropDerivativesAndError((Tree)tree.children().get(1), binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown));
        }
    }

    /**
     * Computes the error passed down to the two children of a tensor-based binary node:
     * {@code (W^T delta)[0..2n) + sum_i (Wt_i + Wt_i^T) * (delta[i] * v)}, where {@code v}
     * concatenates the child vectors and {@code n = deltaFull.rows()}.
     *
     * @param deltaFull   node error, {@code n x 1}
     * @param leftVector  left child vector
     * @param rightVector right child vector
     * @param W           binary transform, {@code n x (2n + 1)}
     * @param Wt          bilinear tensor, {@code [n, 2n, 2n]}
     * @return {@code 2n x 1} delta for the concatenated children
     */
    private INDArray computeINDArrayDeltaDown(INDArray deltaFull, INDArray leftVector, INDArray rightVector, INDArray W, INDArray Wt) {
        int size = deltaFull.length();
        int childRows = deltaFull.rows() * 2;
        INDArray WTDelta = W.transpose().mmul(deltaFull);
        // Drop the bias entry: keep rows 0..2n of the (2n + 1) x 1 result. The matrix branch
        // previously took rows 0..1 and columns 0..2n+1, which is out of bounds for a column.
        INDArray WTDeltaNoBias = WTDelta.isMatrix()
                ? WTDelta.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)childRows), NDArrayIndex.interval((int)0, (int)1)})
                : WTDelta.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)childRows)});
        INDArray deltaINDArray = Nd4j.create((int)(size * 2), (int)1);
        INDArray fullVector = Nd4j.concat((int)0, (INDArray[])new INDArray[]{leftVector, rightVector});
        for (int slice = 0; slice < size; ++slice) {
            // mul() allocates a new array. BLAS scal scaled fullVector in place, so the scaling
            // accumulated across slices.
            INDArray scaledFullVector = fullVector.mul(deltaFull.getDouble(slice));
            INDArray tensorSlice = Wt.slice(slice);
            deltaINDArray.addi(tensorSlice.add(tensorSlice.transpose()).mmul(scaledFullVector));
        }
        return deltaINDArray.add(WTDeltaNoBias);
    }

    /**
     * Recursively computes node vectors and class predictions for {@code tree}, storing them
     * on each node via {@code setVector} / {@code setPrediction}.
     *
     * <p>A pre-terminal node uses the activated word vector of its single child, falling back
     * to the "UNK" vector when the lookup yields {@code null}. A binary node combines its
     * children with the transform matrix and, when enabled, the bilinear tensor.
     *
     * @param tree a binarized tree whose non-preterminal nodes have exactly two children
     * @throws AssertionError if called on a leaf or on a tree that is not properly binarized
     */
    public void forwardPropagateTree(Tree tree) {
        INDArray nodeVector;
        INDArray classification;
        if (tree.isLeaf()) {
            throw new AssertionError((Object)"We should not have reached leaves in forwardPropagate");
        }
        if (tree.isPreTerminal()) {
            classification = this.getUnaryClassification(tree.label());
            String word = ((Tree)tree.children().get(0)).value();
            INDArray wordVector = this.getFeatureVector(word);
            if (wordVector == null) {
                wordVector = this.featureVectors.vector("UNK");
            }
            // Activate a copy: if the transform op writes into its input, the embedding held by
            // the lookup table would otherwise be overwritten on every forward pass.
            nodeVector = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, wordVector.dup()));
        } else {
            if (tree.children().size() == 1) {
                throw new AssertionError((Object)"Non-preterminal nodes of size 1 should have already been collapsed");
            }
            if (tree.children().size() == 2) {
                Tree left = tree.firstChild();
                Tree right = tree.lastChild();
                this.forwardPropagateTree(left);
                this.forwardPropagateTree(right);
                String leftCategory = ((Tree)tree.children().get(0)).label();
                String rightCategory = ((Tree)tree.children().get(1)).label();
                INDArray W = this.getBinaryTransform(leftCategory, rightCategory);
                classification = this.getBinaryClassification(leftCategory, rightCategory);
                INDArray leftVector = ((Tree)tree.children().get(0)).vector();
                INDArray rightVector = ((Tree)tree.children().get(1)).vector();
                INDArray childrenVector = Nd4j.appendBias((INDArray[])new INDArray[]{leftVector, rightVector});
                if (this.useDoubleTensors) {
                    INDArray doubleT = this.getBinaryINDArray(leftCategory, rightCategory);
                    INDArray INDArrayIn = Nd4j.concat((int)0, (INDArray[])new INDArray[]{leftVector, rightVector});
                    INDArray INDArrayOut = Nd4j.bilinearProducts((INDArray)doubleT, (INDArray)INDArrayIn);
                    nodeVector = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, W.mmul(childrenVector).addi(INDArrayOut)));
                } else {
                    nodeVector = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.activationFunction, W.mmul(childrenVector)));
                }
            } else {
                throw new AssertionError((Object)"Tree not correctly binarized");
            }
        }
        INDArray inputWithBias = Nd4j.appendBias((INDArray[])new INDArray[]{nodeVector});
        // Orient the biased vector so it can be multiplied by the classification matrix.
        if (inputWithBias.rows() != classification.columns()) {
            inputWithBias = inputWithBias.transpose();
        }
        INDArray preAct = classification.mmul(inputWithBias);
        INDArray predictions = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(this.outputActivation, preAct));
        tree.setPrediction(predictions);
        tree.setVector(nodeVector);
    }

    /**
     * Computes the bilinear-tensor gradient for one binary node. Slice {@code i} is
     * {@code deltaFull[i] * v * v^T}, where {@code v} concatenates the child vectors.
     *
     * <p>The tensor is allocated as {@code [size, 2*size, 2*size]} so that slicing along the
     * first dimension yields the {@code 2*size x 2*size} outer products being stored (the
     * previous {@code [2*size, 2*size, size]} layout did not match {@code putSlice}).
     *
     * @param deltaFull   node error (length {@code size})
     * @param leftVector  left child vector
     * @param rightVector right child vector
     * @return tensor of shape {@code [size, 2*size, 2*size]}
     */
    private INDArray getDoubleTensorGradient(INDArray deltaFull, INDArray leftVector, INDArray rightVector) {
        int size = deltaFull.length();
        INDArray Wt_df = Nd4j.create((int[])new int[]{size, size * 2, size * 2});
        INDArray fullVector = Nd4j.concat((int)0, (INDArray[])new INDArray[]{leftVector, rightVector});
        INDArray fullVectorT = fullVector.transpose();
        for (int slice = 0; slice < size; ++slice) {
            // mul() allocates a new array; BLAS scal would rescale fullVector in place each slice.
            Wt_df.putSlice(slice, fullVector.mul(deltaFull.getDouble(slice)).mmul(fullVectorT));
        }
        return Wt_df;
    }

    /**
     * Forward-propagates each tree and gathers its prediction array.
     *
     * <p>Every tree is updated in place by {@code forwardPropagateTree}, which sets its vector and prediction.
     *
     * @param trees the trees to evaluate
     * @return the prediction of each tree, in the same order as {@code trees}
     */
    public List<INDArray> output(List<Tree> trees) {
        List<INDArray> predictions = new ArrayList<INDArray>(trees.size());
        Iterator<Tree> cursor = trees.iterator();
        while (cursor.hasNext()) {
            Tree current = cursor.next();
            forwardPropagateTree(current);
            predictions.add(current.prediction());
        }
        return predictions;
    }

    /**
     * Forward-propagates each tree and returns the index chosen from its prediction.
     *
     * <p>The index comes from BLAS {@code iamax}, i.e. the position of the largest absolute value. This
     * equals the arg-max only when predictions are non-negative (for example, softmax outputs).
     *
     * @param trees the trees to classify; each is updated in place by forward propagation
     * @return one predicted index per tree, in input order
     */
    public List<Integer> predict(List<Tree> trees) {
        List<Integer> labels = new ArrayList<Integer>(trees.size());
        for (Tree tree : trees) {
            forwardPropagateTree(tree);
            INDArray scores = tree.prediction();
            labels.add(Nd4j.getBlasWrapper().iamax(scores));
        }
        return labels;
    }

    /**
     * Copies a flat parameter vector back into the model's parameter groups.
     *
     * <p>Groups are filled in this order: binary transforms, binary classification, binary tensors, unary
     * classification, then feature vectors.
     *
     * @param params flat parameter vector; its length must equal {@link #getNumParameters()}
     * @throws IllegalStateException if {@code params} has the wrong length
     */
    public void setParams(INDArray params) {
        int expected = this.getNumParameters();
        if (params.length() == expected) {
            Iterator[] targets = new Iterator[]{
                this.binaryTransform.values().iterator(),
                this.binaryClassification.values().iterator(),
                this.binaryTensors.values().iterator(),
                this.unaryClassification.values().iterator(),
                this.featureVectors.vectors()
            };
            Nd4j.setParams(params, targets);
            return;
        }
        throw new IllegalStateException("Unable to set parameters of length " + params.length() + " must be of length " + expected);
    }

    /**
     * Returns the total number of scalar parameters in the model.
     *
     * <p>The total covers binary transforms, binary classification, binary tensors, unary classification
     * and feature vectors. It is computed on the first call, while {@code numParameters} is still
     * negative, and cached in {@code numParameters} afterwards.
     *
     * @return the summed {@code length()} of every parameter array
     */
    public int getNumParameters() {
        if (this.numParameters >= 0) {
            return this.numParameters;
        }
        Iterator[] groups = new Iterator[]{
            this.binaryTransform.values().iterator(),
            this.binaryClassification.values().iterator(),
            this.binaryTensors.values().iterator(),
            this.unaryClassification.values().iterator(),
            this.featureVectors.vectors()
        };
        int total = 0;
        for (int g = 0; g < groups.length; g++) {
            Iterator it = groups[g];
            while (it.hasNext()) {
                INDArray array = (INDArray) it.next();
                total += array.length();
            }
        }
        this.numParameters = total;
        return this.numParameters;
    }

    /**
     * Flattens all parameter groups into a single vector.
     *
     * <p>Groups are flattened in the same order that {@link #setParams(INDArray)} consumes them.
     *
     * @return a flat vector of length {@link #getNumParameters()}
     */
    public INDArray getParameters() {
        int total = this.getNumParameters();
        Iterator[] groups = new Iterator[]{
            this.binaryTransform.values().iterator(),
            this.binaryClassification.values().iterator(),
            this.binaryTensors.values().iterator(),
            this.unaryClassification.values().iterator(),
            this.featureVectors.vectors()
        };
        return Nd4j.toFlattened(total, groups);
    }

    /**
     * Computes the gradient of the RNTN objective over a batch of training trees.
     *
     * <p>Steps:
     * <ol>
     *   <li>Allocate zero-filled accumulators shaped like each parameter group.</li>
     *   <li>Forward-propagate a copy of every tree via {@code Parallelization.iterateInParallel}.</li>
     *   <li>Back-propagate each forward-propagated tree into the accumulators, summing each tree's
     *       {@code errorSum()}.</li>
     *   <li>Scale by {@code 1 / batchSize}, add regularization, and store the resulting objective in
     *       {@code value}.</li>
     *   <li>Flatten the accumulators and pass them through {@code paramAdaGrad}, which is created on
     *       first use.</li>
     * </ol>
     *
     * @param trainingBatch the trees to train on; {@code null} is treated as an empty batch
     * @return the AdaGrad-adjusted flat gradient
     * @throws IllegalStateException if the flattened gradient length differs from the parameter count
     */
    public INDArray getValueGradient(List<Tree> trainingBatch) {
        // A null batch is handled like an empty one. The scale computation below always
        // tolerated null, but the parallel forward pass used to receive the null list first.
        final List<Tree> batch = trainingBatch == null ? new ArrayList<Tree>() : trainingBatch;
        final MultiDimensionalMap binaryTD = MultiDimensionalMap.newTreeBackedMap();
        final MultiDimensionalMap binaryINDArrayTD = MultiDimensionalMap.newTreeBackedMap();
        final MultiDimensionalMap binaryCD = MultiDimensionalMap.newTreeBackedMap();
        final TreeMap<String, INDArray> unaryCD = new TreeMap<String, INDArray>();
        final TreeMap<String, INDArray> wordVectorD = new TreeMap<String, INDArray>();
        for (MultiDimensionalMap.Entry entry : this.binaryTransform.entrySet()) {
            int numRows = ((INDArray)entry.getValue()).rows();
            int numCols = ((INDArray)entry.getValue()).columns();
            binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), (Object)Nd4j.create((int)numRows, (int)numCols));
        }
        if (!this.combineClassification) {
            for (MultiDimensionalMap.Entry entry : this.binaryClassification.entrySet()) {
                int numRows = ((INDArray)entry.getValue()).rows();
                int numCols = ((INDArray)entry.getValue()).columns();
                binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), (Object)Nd4j.create((int)numRows, (int)numCols));
            }
        }
        if (this.useDoubleTensors) {
            for (MultiDimensionalMap.Entry entry : this.binaryTensors.entrySet()) {
                // NOTE(review): the accumulator is created as [size(1), size(2), slices()];
                // confirm this layout matches the tensor it accumulates gradients for.
                int numRows = ((INDArray)entry.getValue()).size(1);
                int numCols = ((INDArray)entry.getValue()).size(2);
                int numSlices = ((INDArray)entry.getValue()).slices();
                binaryINDArrayTD.put(entry.getFirstKey(), entry.getSecondKey(), (Object)Nd4j.create((int[])new int[]{numRows, numCols, numSlices}));
            }
        }
        for (Map.Entry entry : this.unaryClassification.entrySet()) {
            int numRows = ((INDArray)entry.getValue()).rows();
            int numCols = ((INDArray)entry.getValue()).columns();
            unaryCD.put((String)entry.getKey(), Nd4j.create((int)numRows, (int)numCols));
        }
        for (String word : this.vocabCache.words()) {
            INDArray vector = this.featureVectors.vector(word);
            int numRows = vector.rows();
            int numCols = vector.columns();
            wordVectorD.put(word, Nd4j.create((int)numRows, (int)numCols));
        }
        final CopyOnWriteArrayList forwardPropTrees = new CopyOnWriteArrayList();
        Parallelization.iterateInParallel(batch, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Tree>(){

            public void run(Tree currentItem, Object[] args) {
                // NOTE(review): only the child LIST is copied. The copy shares currentItem's
                // child objects, so forward propagation may also mutate them — confirm intended.
                Tree trainingTree = new Tree(currentItem);
                trainingTree.connect(new ArrayList(currentItem.children()));
                RNTN.this.forwardPropagateTree(trainingTree);
                forwardPropTrees.add(trainingTree);
            }
        }, (ActorSystem)this.rnTnActorSystem);
        final AtomicDouble atomicDouble = new AtomicDouble(0.0);
        if (!forwardPropTrees.isEmpty()) {
            // NOTE(review): this presumably runs concurrently, and every task writes into the
            // same shared accumulator maps. Confirm backpropDerivativesAndError guards those updates.
            Parallelization.iterateInParallel(forwardPropTrees, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Tree>(){

                public void run(Tree currentItem, Object[] args) {
                    RNTN.this.backpropDerivativesAndError(currentItem, (MultiDimensionalMap<String, String, INDArray>)binaryTD, (MultiDimensionalMap<String, String, INDArray>)binaryCD, (MultiDimensionalMap<String, String, INDArray>)binaryINDArrayTD, unaryCD, wordVectorD);
                    atomicDouble.addAndGet(currentItem.errorSum());
                }
            }, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Tree>(){

                public void run(Tree currentItem, Object[] args) {
                }
            }, (ActorSystem)this.rnTnActorSystem, (Object[])new Object[]{binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD});
        }
        // Computed in double precision. The former (double)(1.0f / (float) n) lost precision in
        // the float division.
        double scale = batch.isEmpty() ? 1.0 : 1.0 / batch.size();
        this.value = atomicDouble.doubleValue() * scale;
        this.value += this.scaleAndRegularize((MultiDimensionalMap<String, String, INDArray>)binaryTD, this.binaryTransform, scale, this.regTransformMatrix);
        this.value += this.scaleAndRegularize((MultiDimensionalMap<String, String, INDArray>)binaryCD, this.binaryClassification, scale, this.regClassification);
        this.value += this.scaleAndRegularizeINDArray((MultiDimensionalMap<String, String, INDArray>)binaryINDArrayTD, this.binaryTensors, scale, this.regTransformINDArray);
        this.value += this.scaleAndRegularize(unaryCD, this.unaryClassification, scale, this.regClassification);
        this.value += this.scaleAndRegularize(wordVectorD, this.featureVectors, scale, this.regWordVector);
        // NOTE(review): wordVectorD is flattened in sorted-key (TreeMap) order, while
        // getParameters() flattens featureVectors.vectors(). Confirm both orders agree, or
        // word-vector gradients will be misaligned with their parameters.
        INDArray derivative = Nd4j.toFlattened((int)this.getNumParameters(), (Iterator[])new Iterator[]{binaryTD.values().iterator(), binaryCD.values().iterator(), binaryINDArrayTD.values().iterator(), unaryCD.values().iterator(), wordVectorD.values().iterator()});
        if (derivative.length() != this.numParameters) {
            throw new IllegalStateException("Gradient has wrong number of parameters " + derivative.length() + " should have been " + this.numParameters);
        }
        if (this.paramAdaGrad == null) {
            this.paramAdaGrad = new AdaGrad(1, derivative.columns());
        }
        derivative = this.paramAdaGrad.getGradient(derivative);
        return derivative;
    }

    /**
     * Returns the objective value stored in the {@code value} field.
     *
     * <p>{@link #getValueGradient(List)} overwrites this field with the scaled, regularized batch error.
     *
     * @return the current contents of {@code value}
     */
    public double getValue() {
        return value;
    }

    /**
     * No-op: the body is empty, so calling this method performs no training.
     */
    public void fit() {
    }

    /**
     * No-op: the supplied gradient is ignored and no parameters are changed.
     *
     * @param gradient ignored
     */
    public void update(Gradient gradient) {
    }

    /**
     * Always returns {@code 0.0}.
     *
     * <p>NOTE(review): the objective computed by {@link #getValueGradient(List)} is stored in {@code value}
     * and exposed via {@link #getValue()}, not here. Confirm whether this method should report it.
     *
     * @return {@code 0.0}
     */
    public double score() {
        return 0.0;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param data ignored
     * @return {@code null}
     */
    public INDArray transform(INDArray data) {
        return null;
    }

    /**
     * Returns all model parameters as one flat vector; delegates to {@link #getParameters()}.
     *
     * @return the flattened parameters
     */
    public INDArray params() {
        return getParameters();
    }

    /**
     * Returns the number of scalar parameters; delegates to {@link #getNumParameters()}.
     *
     * @return the parameter count
     */
    public int numParams() {
        return getNumParameters();
    }

    /**
     * Unsupported for this layer.
     *
     * @param data ignored
     * @throws UnsupportedOperationException always
     */
    public void fit(INDArray data) {
        throw new UnsupportedOperationException();
    }

    /**
     * Unsupported for this layer.
     *
     * @param input ignored
     * @throws UnsupportedOperationException always
     */
    public void iterate(INDArray input) {
        throw new UnsupportedOperationException();
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @return {@code null}
     */
    public Gradient gradient() {
        return null;
    }

    /**
     * Not implemented: always returns {@code null} rather than a (gradient, score) pair.
     *
     * @return {@code null}
     */
    public Pair<Gradient, Double> gradientAndScore() {
        return null;
    }

    /**
     * Not implemented: always returns {@code 0}.
     *
     * @return {@code 0}
     */
    public int batchSize() {
        return 0;
    }

    /**
     * Reports this layer's type.
     *
     * @return {@code Layer.Type.RECURSIVE}
     */
    public Layer.Type type() {
        return Layer.Type.RECURSIVE;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param input ignored
     * @return {@code null}
     */
    public Gradient error(INDArray input) {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param input ignored
     * @return {@code null}
     */
    public INDArray derivativeActivation(INDArray input) {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param layerError ignored
     * @param indArray   ignored
     * @return {@code null}
     */
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param error ignored
     * @param input ignored
     * @return {@code null}
     */
    public Gradient errorSignal(Gradient error, INDArray input) {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param activation  ignored
     * @param errorSignal ignored
     * @return {@code null}
     */
    public Gradient backwardGradient(INDArray activation, Gradient errorSignal) {
        return null;
    }

    /**
     * No-op: the other layer is not merged into this one.
     *
     * @param layer     ignored
     * @param batchSize ignored
     */
    public void merge(Layer layer, int batchSize) {
    }

    /**
     * Not implemented: returns {@code null} for every key.
     *
     * @param param ignored
     * @return {@code null}
     */
    public INDArray getParam(String param) {
        return null;
    }

    /**
     * No-op: the body is empty and initializes nothing.
     */
    public void initParams() {
    }

    /**
     * Not implemented: returns {@code null}, not an empty map. Callers must handle {@code null}.
     *
     * @return {@code null}
     */
    public Map<String, INDArray> paramTable() {
        return null;
    }

    /**
     * No-op: the supplied table is ignored.
     *
     * @param paramTable ignored
     */
    public void setParamTable(Map<String, INDArray> paramTable) {
    }

    /**
     * No-op: the key and value are ignored.
     *
     * @param key ignored
     * @param val ignored
     */
    public void setParam(String key, INDArray val) {
    }

    /**
     * No-op: the body is empty and no state is cleared.
     */
    public void clear() {
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @return {@code null}
     */
    public INDArray activationMean() {
        return null;
    }

    /**
     * Returns the configuration held in the {@code conf} field.
     *
     * @return the current value of {@code conf}; {@code null} if it has never been assigned
     */
    public NeuralNetConfiguration conf() {
        // Previously returned null unconditionally, ignoring the field.
        return this.conf;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param x ignored
     * @return {@code null}
     */
    public INDArray preOutput(INDArray x) {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @return {@code null}
     */
    public INDArray activate() {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param input ignored
     * @return {@code null}
     */
    public INDArray activate(INDArray input) {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @return {@code null}
     */
    public Layer transpose() {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * <p>NOTE(review): callers expecting a copy will receive {@code null}, which breaks the usual
     * {@code clone()} contract. Consider throwing or implementing a real copy.
     *
     * @return {@code null}
     */
    public Layer clone() {
        return null;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @param errors             ignored
     * @param deltas             ignored
     * @param activation         ignored
     * @param previousActivation ignored
     * @return {@code null}
     */
    public Pair<Gradient, Gradient> backWard(Gradient errors, Gradient deltas, INDArray activation, String previousActivation) {
        return null;
    }

    /**
     * No-op: the body is empty; no score is computed or stored.
     */
    public void setScore() {
    }

    /**
     * No-op: the supplied amount is ignored.
     *
     * @param accum ignored
     */
    public void accumulateScore(double accum) {
    }

    /**
     * Stores the given configuration in the {@code conf} field.
     *
     * @param conf the configuration to keep; may be {@code null}
     */
    public void setConf(NeuralNetConfiguration conf) {
        // Previously an empty body that silently discarded the argument.
        this.conf = conf;
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @return {@code null}
     */
    public INDArray input() {
        return null;
    }

    /**
     * No-op: the body is empty and performs no validation.
     */
    public void validateInput() {
    }

    /**
     * Not implemented: always returns {@code null}.
     *
     * @return {@code null}
     */
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /**
     * Fluent builder for {@link RNTN}.
     *
     * <p>Each setter records one value and returns this builder. {@link #build()} passes the recorded
     * values to the {@code RNTN} constructor.
     *
     * <p>NOTE(review): {@code outputActivationFunction}, set via {@link #withOutputActivation(String)}, is
     * stored but is NOT among the arguments {@link #build()} passes to the constructor, so it currently
     * has no effect. Confirm this, and wire it through.
     */
    public static class Builder {
        private int numHidden;
        private Random rng;
        private boolean useINd4j;
        private boolean combineClassification = true;
        private boolean simplifiedModel = true;
        private boolean randomFeatureVectors;
        // Double literal: the former float literal 0.001f widened to 0.0010000000474974513.
        private double scalingForInit = 0.001;
        private boolean lowerCasefeatureNames;
        private String activationFunction = "sigmoid";
        private String outputActivationFunction = "softmax";
        private int adagradResetFrequency;
        private double regTransformINDArray;
        private WeightLookupTable featureVectors;
        private int numBinaryMatrices;
        private int binaryTransformSize;
        private int binaryINd4jize;
        private int binaryClassificationSize;
        private int numUnaryMatrices;
        private int unaryClassificationSize;
        private Map<Integer, Double> classWeights;
        private VocabCache vocabCache;

        public Builder vocabCache(VocabCache vocabCache) {
            this.vocabCache = vocabCache;
            return this;
        }

        /**
         * Records the output activation name (default {@code "softmax"}).
         * See the class-level NOTE: the value is not currently passed on by {@link #build()}.
         */
        public Builder withOutputActivation(String outputActivationFunction) {
            this.outputActivationFunction = outputActivationFunction;
            return this;
        }

        /**
         * Takes the vocabulary from {@code vec.vocab()} and the feature vectors from
         * {@code vec.lookupTable()}. Like {@link #setFeatureVectors(WeightLookupTable)}, this also
         * overwrites {@code numHidden}.
         */
        public Builder setFeatureVectors(Word2Vec vec) {
            this.vocabCache = vec.vocab();
            return this.setFeatureVectors(vec.lookupTable());
        }

        public Builder setNumHidden(int numHidden) {
            this.numHidden = numHidden;
            return this;
        }

        public Builder setRng(Random rng) {
            this.rng = rng;
            return this;
        }

        public Builder setUseTensors(boolean useINd4j) {
            this.useINd4j = useINd4j;
            return this;
        }

        public Builder setCombineClassification(boolean combineClassification) {
            this.combineClassification = combineClassification;
            return this;
        }

        public Builder setSimplifiedModel(boolean simplifiedModel) {
            this.simplifiedModel = simplifiedModel;
            return this;
        }

        public Builder setRandomFeatureVectors(boolean randomFeatureVectors) {
            this.randomFeatureVectors = randomFeatureVectors;
            return this;
        }

        public Builder setScalingForInit(double scalingForInit) {
            this.scalingForInit = scalingForInit;
            return this;
        }

        public Builder setLowerCasefeatureNames(boolean lowerCasefeatureNames) {
            this.lowerCasefeatureNames = lowerCasefeatureNames;
            return this;
        }

        public Builder setActivationFunction(String activationFunction) {
            this.activationFunction = activationFunction;
            return this;
        }

        public Builder setAdagradResetFrequency(int adagradResetFrequency) {
            this.adagradResetFrequency = adagradResetFrequency;
            return this;
        }

        public Builder setRegTransformINDArray(double regTransformINDArray) {
            this.regTransformINDArray = regTransformINDArray;
            return this;
        }

        /**
         * Records the feature vectors and overwrites {@code numHidden} with the column count of the first
         * array returned by {@code featureVectors.vectors()}. Any earlier {@link #setNumHidden(int)} value
         * is replaced. Presumably this fails when the table holds no vectors; confirm.
         */
        public Builder setFeatureVectors(WeightLookupTable featureVectors) {
            this.featureVectors = featureVectors;
            this.numHidden = featureVectors.vectors().next().columns();
            return this;
        }

        public Builder setNumBinaryMatrices(int numBinaryMatrices) {
            this.numBinaryMatrices = numBinaryMatrices;
            return this;
        }

        public Builder setBinaryTransformSize(int binaryTransformSize) {
            this.binaryTransformSize = binaryTransformSize;
            return this;
        }

        public Builder setBinaryINd4jize(int binaryINd4jize) {
            this.binaryINd4jize = binaryINd4jize;
            return this;
        }

        public Builder setBinaryClassificationSize(int binaryClassificationSize) {
            this.binaryClassificationSize = binaryClassificationSize;
            return this;
        }

        public Builder setNumUnaryMatrices(int numUnaryMatrices) {
            this.numUnaryMatrices = numUnaryMatrices;
            return this;
        }

        public Builder setUnaryClassificationSize(int unaryClassificationSize) {
            this.unaryClassificationSize = unaryClassificationSize;
            return this;
        }

        public Builder setClassWeights(Map<Integer, Double> classWeights) {
            this.classWeights = classWeights;
            return this;
        }

        /**
         * Creates an {@link RNTN} from the recorded values.
         *
         * @return a new RNTN instance
         */
        public RNTN build() {
            RNTN rt = new RNTN(this.numHidden, this.rng, this.useINd4j, this.combineClassification, this.simplifiedModel, this.randomFeatureVectors, this.scalingForInit, this.lowerCasefeatureNames, this.activationFunction, this.adagradResetFrequency, this.regTransformINDArray, this.featureVectors, this.vocabCache, this.numBinaryMatrices, this.binaryTransformSize, this.binaryINd4jize, this.binaryClassificationSize, this.numUnaryMatrices, this.unaryClassificationSize, this.classWeights);
            return rt;
        }
    }
}

