/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.rntn;

import akka.actor.ActorSystem;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.rntn.Tree;
import org.deeplearning4j.models.word2vec.Word2Vec;
import org.deeplearning4j.parallel.Parallelization;
import org.deeplearning4j.util.MultiDimensionalMap;
import org.deeplearning4j.util.MultiDimensionalSet;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class RNTN
implements Serializable {
    // Objective value (scaled error sum plus regularization cost) written by getValueGradient.
    protected double value = 0.0;
    // Number of output classes; rows of every classification matrix and length of the gold label vector.
    private int numOuts = 3;
    // Dimensionality of word vectors and of every node vector.
    private int numHidden = 25;
    // Random source for parameter initialisation; init() substitutes MersenneTwister(123) when null.
    private RandomGenerator rng;
    // When true, a third-order tensor (binaryINd4j) is added to each binary composition.
    private boolean useFloatTensors = true;
    // When true, the single matrix unaryClassification[""] is used to classify every node.
    private boolean combineClassification = true;
    // Only the simplified model (every category collapsed to "") is implemented.
    private boolean simplifiedModel = true;
    // NOTE(review): not read anywhere in the visible part of this file.
    private boolean randomFeatureVectors = true;
    // Multiplier applied to randomly initialised transform/classification/tensor parameters.
    private double scalingForInit = 1.0;
    // Feature-vector key used for words not present in featureVectors.
    public static final String UNKNOWN_FEATURE = "UNK";
    // When true, words are lower-cased before the feature-vector lookup.
    private boolean lowerCasefeatureNames;
    // Non-linearity applied when composing node vectors (tanh by default).
    protected ActivationFunction activationFunction = Activations.tanh();
    // Applied to the classifier pre-activation to produce per-node predictions.
    protected ActivationFunction outputActivation = Activations.softMaxRows();
    // NOTE(review): not read anywhere in the visible part of this file.
    protected AdaGrad paramAdaGrad;
    // L2 regularization coefficients used by getValueGradient.
    // NOTE(review): these are float literals widened to double, so they hold the float approximation (e.g. 0.001f).
    private double regTransformMatrix = 0.001f;
    private double regClassification = 1.0E-4f;
    private double regWordVector = 1.0E-4f;
    // NOTE(review): not read anywhere in the visible part of this file.
    private int adagradResetFrequency = 1;
    // L2 regularization coefficient for the binary tensors.
    private double regTransformINDArray = 0.001f;
    // (left category, right category) -> composition matrix of shape numHidden x (2 * numHidden + 1).
    private MultiDimensionalMap<String, String, INDArray> binaryTransform;
    // (left category, right category) -> tensor of shape [numHidden, 2 * numHidden, 2 * numHidden].
    private MultiDimensionalMap<String, String, INDArray> binaryINd4j;
    // category -> classification matrix of shape numOuts x (numHidden + 1).
    private Map<String, INDArray> unaryClassification;
    // word -> feature vector; init() adds a random vector under UNKNOWN_FEATURE.
    private Map<String, INDArray> featureVectors;
    // (left, right) -> classification matrix; only populated when combineClassification is false.
    private MultiDimensionalMap<String, String, INDArray> binaryClassification;
    // Parameter bookkeeping recomputed by init(); "binaryINd4jize" is the element count of one binary tensor.
    private int numBinaryMatrices;
    private int binaryTransformSize;
    private int binaryINd4jize;
    private int binaryClassificationSize;
    private int numUnaryMatrices;
    private int unaryClassificationSize;
    // numHidden x numHidden identity added to every random transform block.
    private INDArray identity;
    // Batch recorded by fit() and consumed by getValueGradient().
    private List<Tree> trainingTrees;
    // Gold class -> weight applied to that node's error; classes without an entry use weight 1.0.
    private Map<Integer, Float> classWeights;
    private static Logger log = LoggerFactory.getLogger(RNTN.class);
    // Akka actor system used by getValueGradient to run propagation through Parallelization.
    // Created eagerly per instance; transient so it is not serialised.
    // NOTE(review): nothing in the visible code shuts it down.
    private transient ActorSystem rnTnActorSystem = ActorSystem.create((String)"RNTN");

    /**
     * Creates a model from explicit hyper-parameters, then builds all parameter blocks via {@link #init()}.
     *
     * <p>The size counters passed in (numBinaryMatrices ... unaryClassificationSize) are stored first
     * but are recomputed by {@code init()}.
     */
    private RNTN(int numHidden, RandomGenerator rng, boolean useFloatTensors, boolean combineClassification, boolean simplifiedModel, boolean randomFeatureVectors, double scalingForInit, boolean lowerCasefeatureNames, ActivationFunction activationFunction, int adagradResetFrequency, double regTransformINDArray, Map<String, INDArray> featureVectors, int numBinaryMatrices, int binaryTransformSize, int binaryINd4jize, int binaryClassificationSize, int numUnaryMatrices, int unaryClassificationSize, Map<Integer, Float> classWeights) {
        // network shape and initialisation
        this.numHidden = numHidden;
        this.rng = rng;
        this.scalingForInit = scalingForInit;
        this.activationFunction = activationFunction;
        // model variant switches
        this.useFloatTensors = useFloatTensors;
        this.combineClassification = combineClassification;
        this.simplifiedModel = simplifiedModel;
        this.randomFeatureVectors = randomFeatureVectors;
        this.lowerCasefeatureNames = lowerCasefeatureNames;
        // training settings
        this.adagradResetFrequency = adagradResetFrequency;
        this.regTransformINDArray = regTransformINDArray;
        this.classWeights = classWeights;
        // vocabulary and parameter bookkeeping
        this.featureVectors = featureVectors;
        this.numBinaryMatrices = numBinaryMatrices;
        this.binaryTransformSize = binaryTransformSize;
        this.binaryINd4jize = binaryINd4jize;
        this.binaryClassificationSize = binaryClassificationSize;
        this.numUnaryMatrices = numUnaryMatrices;
        this.unaryClassificationSize = unaryClassificationSize;
        init();
    }

    /**
     * Builds every parameter block of the model from scratch.
     *
     * <p>Only the simplified model is supported: every category collapses to {@code ""}, so a single
     * binary production and a single unary production are created. Random draws consume {@link #rng}
     * in a fixed order (binary transform, binary tensor when {@code useFloatTensors}, binary classifier
     * when not {@code combineClassification}, unary classifier, UNK word vector). A seeded generator
     * therefore reproduces the same weights.
     *
     * <p>Caller-supplied {@code featureVectors} and {@code classWeights} are kept. Empty maps are
     * substituted only when they are {@code null}; previously {@code classWeights} was always replaced.
     *
     * @throws UnsupportedOperationException if {@code simplifiedModel} is false
     */
    private void init() {
        if (this.rng == null) {
            this.rng = new MersenneTwister(123);
        }
        if (!this.simplifiedModel) {
            throw new UnsupportedOperationException("Not yet implemented");
        }
        // In the simplified model every category is "", so one production of each kind suffices.
        MultiDimensionalSet<String, String> binaryProductions = MultiDimensionalSet.hashSet();
        binaryProductions.add("", "");
        HashSet<String> unaryProductions = new HashSet<String>();
        unaryProductions.add("");

        // identity must exist before randomTransformMatrix() is called.
        this.identity = Nd4j.eye(this.numHidden);
        this.binaryTransform = MultiDimensionalMap.newTreeBackedMap();
        this.binaryINd4j = MultiDimensionalMap.newTreeBackedMap();
        this.binaryClassification = MultiDimensionalMap.newTreeBackedMap();
        for (Pair<String, String> binary : binaryProductions) {
            String left = this.basicCategory(binary.getFirst());
            String right = this.basicCategory(binary.getSecond());
            if (this.binaryTransform.contains(left, right)) {
                continue;
            }
            this.binaryTransform.put(left, right, this.randomTransformMatrix());
            if (this.useFloatTensors) {
                this.binaryINd4j.put(left, right, this.randomBinaryINDArray());
            }
            if (!this.combineClassification) {
                this.binaryClassification.put(left, right, this.randomClassificationMatrix());
            }
        }
        this.numBinaryMatrices = this.binaryTransform.size();
        // W is numHidden x (2 * numHidden + 1): [left | right | bias].
        this.binaryTransformSize = this.numHidden * (2 * this.numHidden + 1);
        // Each tensor is numHidden x (2 * numHidden) x (2 * numHidden).
        this.binaryINd4jize = this.useFloatTensors ? this.numHidden * this.numHidden * this.numHidden * 4 : 0;
        this.binaryClassificationSize = this.combineClassification ? 0 : this.numOuts * (this.numHidden + 1);

        this.unaryClassification = new TreeMap<String, INDArray>();
        for (String unary : unaryProductions) {
            String category = this.basicCategory(unary);
            if (!this.unaryClassification.containsKey(category)) {
                this.unaryClassification.put(category, this.randomClassificationMatrix());
            }
        }
        this.numUnaryMatrices = this.unaryClassification.size();
        this.unaryClassificationSize = this.numOuts * (this.numHidden + 1);

        if (this.featureVectors == null) {
            // TreeMap iterates keys in the same sorted order as the TreeMap derivative maps in getValueGradient.
            this.featureVectors = new TreeMap<String, INDArray>();
        }
        this.featureVectors.put(UNKNOWN_FEATURE, this.randomWordVector());
        if (this.classWeights == null) {
            this.classWeights = new HashMap<Integer, Float>();
        }
    }

    /**
     * Draws a fresh binary composition tensor of shape [numHidden, 2*numHidden, 2*numHidden].
     * Values lie in [-1/(4*numHidden), 1/(4*numHidden)] and are then multiplied by scalingForInit.
     */
    INDArray randomBinaryINDArray() {
        final int doubled = this.numHidden * 2;
        final int[] shape = new int[]{this.numHidden, doubled, doubled};
        // computed in float precision, then widened
        final double bound = 1.0f / (4.0f * (float) this.numHidden);
        final INDArray tensor = Nd4j.rand(shape, -bound, bound, this.rng);
        return tensor.muli((Number) this.scalingForInit);
    }

    /**
     * Builds a random composition matrix of shape numHidden x (2*numHidden + 1).
     * The left and right numHidden x numHidden blocks each come from randomTransformBlock(),
     * the bias column is left as created, and the whole matrix is scaled by scalingForInit.
     */
    public INDArray randomTransformMatrix() {
        final int n = this.numHidden;
        final INDArray result = Nd4j.create(n, n * 2 + 1);

        final INDArray leftBlock = this.randomTransformBlock();
        final int blockRows = leftBlock.rows();
        final int blockCols = leftBlock.columns();
        result.put(new NDArrayIndex[]{NDArrayIndex.interval(0, blockRows), NDArrayIndex.interval(0, blockCols)}, leftBlock);

        final NDArrayIndex[] rightRegion = new NDArrayIndex[]{NDArrayIndex.interval(0, blockRows), NDArrayIndex.interval(n, n + blockCols)};
        result.put(rightRegion, this.randomTransformBlock());

        return Nd4j.getBlasWrapper().scal(this.scalingForInit, result);
    }

    /**
     * Returns identity + noise, where the noise is a numHidden x numHidden random matrix
     * with values in [-1/(2*sqrt(numHidden)), 1/(2*sqrt(numHidden))].
     */
    public INDArray randomTransformBlock() {
        final double halfWidth = 1.0 / (2.0 * Math.sqrt(this.numHidden));
        final INDArray noise = Nd4j.rand(this.numHidden, this.numHidden, -halfWidth, halfWidth, this.rng);
        return noise.add(this.identity);
    }

    /**
     * Builds a numOuts x (numHidden + 1) classifier. The first numHidden columns hold random values
     * in [-1/sqrt(numHidden), 1/sqrt(numHidden)], the bias column stays zero, and the result is
     * scaled by scalingForInit.
     */
    INDArray randomClassificationMatrix() {
        final double bound = 1.0 / Math.sqrt(this.numHidden);
        final INDArray classifier = Nd4j.zeros(this.numOuts, this.numHidden + 1);
        final INDArray weights = Nd4j.rand(this.numOuts, this.numHidden, -bound, bound, this.rng);
        final NDArrayIndex[] weightRegion = new NDArrayIndex[]{
                NDArrayIndex.interval(0, this.numOuts),
                NDArrayIndex.interval(0, this.numHidden)};
        classifier.put(weightRegion, weights);
        return Nd4j.getBlasWrapper().scal(this.scalingForInit, classifier);
    }

    /** Draws a random numHidden x 1 column vector from {@link #rng}. */
    INDArray randomWordVector() {
        final int rows = this.numHidden;
        final int columns = 1;
        return Nd4j.rand(rows, columns, this.rng);
    }

    /**
     * Trains on a batch. The batch is recorded as the training set. Then, for every tree,
     * the tree is forward-propagated and one gradient step over the recorded batch is taken.
     */
    public void fit(List<Tree> trainingBatch) {
        this.trainingTrees = trainingBatch;
        Iterator<Tree> cursor = trainingBatch.iterator();
        while (cursor.hasNext()) {
            Tree current = cursor.next();
            forwardPropagateTree(current);
            // Snapshot the parameters BEFORE computing the gradient, matching the original evaluation order.
            INDArray params = getParameters();
            setParameters(params.subi(getValueGradient(0)));
        }
    }

    /**
     * Copies the flat vector {@code theta} element-by-element into the matrices produced by
     * {@code matrices}, in iterator order.
     *
     * @throws AssertionError if theta has elements left over after every matrix is filled
     */
    public void setParams(INDArray theta, Iterator<? extends INDArray> ... matrices) {
        int cursor = 0;
        for (int m = 0; m < matrices.length; m++) {
            Iterator<? extends INDArray> source = matrices[m];
            while (source.hasNext()) {
                INDArray target = source.next();
                int count = target.length();
                for (int i = 0; i < count; i++) {
                    target.put(i, theta.getScalar(cursor));
                    cursor++;
                }
            }
        }
        if (cursor != theta.length()) {
            throw new AssertionError("Did not entirely use the theta vector");
        }
    }

    /**
     * Returns the binary composition matrix for a node with exactly two children.
     *
     * @throws AssertionError for unary nodes or any other child count
     */
    public INDArray getWForNode(Tree node) {
        final int arity = node.children().size();
        switch (arity) {
            case 2: {
                String leftBasic = basicCategory(node.children().get(0).value());
                String rightBasic = basicCategory(node.children().get(1).value());
                return (INDArray) binaryTransform.get(leftBasic, rightBasic);
            }
            case 1:
                throw new AssertionError("No unary applyTransformToOrigin matrices, only unary classification");
            default:
                throw new AssertionError("Unexpected tree children size of " + arity);
        }
    }

    /**
     * Returns the binary composition tensor for a node with exactly two children.
     *
     * @throws AssertionError if tensors are disabled, for unary nodes, or for any other child count
     */
    public INDArray getINDArrayForNode(Tree node) {
        if (!useFloatTensors) {
            throw new AssertionError("Not using INd4j");
        }
        final int arity = node.children().size();
        if (arity == 1) {
            throw new AssertionError("No unary applyTransformToOrigin matrices, only unary classification");
        }
        if (arity != 2) {
            throw new AssertionError("Unexpected tree children size of " + arity);
        }
        String leftKey = basicCategory(node.children().get(0).value());
        String rightKey = basicCategory(node.children().get(1).value());
        return (INDArray) binaryINd4j.get(leftKey, rightKey);
    }

    /**
     * Returns the classification matrix for a node. With combined classification this is always
     * the shared matrix stored under "". Otherwise it is chosen by the node's child categories.
     *
     * @throws AssertionError if the node has neither one nor two children
     */
    public INDArray getClassWForNode(Tree node) {
        if (combineClassification) {
            return unaryClassification.get("");
        }
        switch (node.children().size()) {
            case 2:
                return (INDArray) binaryClassification.get(
                        basicCategory(node.children().get(0).value()),
                        basicCategory(node.children().get(1).value()));
            case 1:
                return unaryClassification.get(basicCategory(node.children().get(0).value()));
            default:
                throw new AssertionError("Unexpected tree children size of " + node.children().size());
        }
    }

    /**
     * Computes the gradient of the binary tensor term for one node.
     *
     * <p>For each output unit k, slice k is deltaFull[k] * v * v^T, where v = [left; right]
     * is the concatenated child vector. The result has the same
     * [size, 2*size, 2*size] layout as the tensor built by randomBinaryINDArray.
     *
     * @param deltaFull   error at this node (one entry per hidden unit)
     * @param leftVector  left child vector
     * @param rightVector right child vector
     * @return tensor gradient of shape [size, 2*size, 2*size]
     */
    private INDArray getINDArrayGradient(INDArray deltaFull, INDArray leftVector, INDArray rightVector) {
        int size = deltaFull.length();
        INDArray Wt_df = Nd4j.create(new int[]{size, size * 2, size * 2});
        INDArray fullVector = Nd4j.concat(0, new INDArray[]{leftVector, rightVector});
        INDArray fullVectorT = fullVector.transpose();
        for (int slice = 0; slice < size; ++slice) {
            // mul() returns a new array. BLAS scal() scales its argument in place, which would
            // compound the scaling of fullVector across slices. getDouble() avoids casting element()
            // to Double, which fails for float-backed arrays.
            INDArray scaled = fullVector.mul((Number) deltaFull.getDouble(slice));
            Wt_df.putSlice(slice, scaled.mmul(fullVectorT));
        }
        return Wt_df;
    }

    /**
     * Looks up the feature vector for {@code word}, falling back to the UNK vector, and returns
     * it as a column vector (row vectors are transposed).
     */
    public INDArray getFeatureVector(String word) {
        INDArray vector = featureVectors.get(getVocabWord(word));
        return vector.isRowVector() ? vector.transpose() : vector;
    }

    /**
     * Maps a surface word to its feature-vector key. The word is optionally lower-cased
     * (default locale); words without a vector map to {@link #UNKNOWN_FEATURE}.
     */
    public String getVocabWord(String word) {
        String key = lowerCasefeatureNames ? word.toLowerCase() : word;
        return featureVectors.containsKey(key) ? key : UNKNOWN_FEATURE;
    }

    /**
     * Collapses a category label to its basic form. The simplified model maps every label to "".
     *
     * @throws IllegalStateException if the simplified model is disabled
     */
    public String basicCategory(String category) {
        if (!simplifiedModel) {
            throw new IllegalStateException("Only simplified model enabled");
        }
        return "";
    }

    /** Returns the unary classification matrix for the basic form of {@code category}. */
    public INDArray getUnaryClassification(String category) {
        return unaryClassification.get(basicCategory(category));
    }

    /**
     * Returns the classification matrix for a binary node. With combined classification this is
     * the shared matrix stored under "".
     */
    public INDArray getBinaryClassification(String left, String right) {
        if (!combineClassification) {
            String leftBasic = basicCategory(left);
            String rightBasic = basicCategory(right);
            return (INDArray) binaryClassification.get(leftBasic, rightBasic);
        }
        return unaryClassification.get("");
    }

    /** Returns the composition matrix W for the basic forms of the two child categories. */
    public INDArray getBinaryTransform(String left, String right) {
        String leftBasic = basicCategory(left);
        String rightBasic = basicCategory(right);
        return (INDArray) binaryTransform.get(leftBasic, rightBasic);
    }

    /** Returns the composition tensor for the basic forms of the two child categories. */
    public INDArray getBinaryINDArray(String left, String right) {
        String leftBasic = basicCategory(left);
        String rightBasic = basicCategory(right);
        return (INDArray) binaryINd4j.get(leftBasic, rightBasic);
    }

    /**
     * Returns the total number of scalar parameters, i.e. the length of the vector produced by
     * {@link #getParameters()}.
     *
     * <p>Every binary production owns one transform matrix, an optional classification matrix and an
     * optional tensor. Every unary production owns one classification matrix. Every word owns a
     * numHidden-dimensional vector. The per-matrix element counts (binaryTransformSize etc.) are used
     * here, not the map entry counts that the previous formula multiplied by mistake.
     */
    public int getNumParameters() {
        int perBinaryProduction = this.binaryTransformSize + this.binaryClassificationSize + this.binaryINd4jize;
        int totalSize = this.numBinaryMatrices * perBinaryProduction;
        totalSize += this.numUnaryMatrices * this.unaryClassificationSize;
        totalSize += this.featureVectors.size() * this.numHidden;
        return totalSize;
    }

    /**
     * Flattens all parameters into one vector in the fixed order: binary transforms, binary
     * classifiers, binary tensors, unary classifiers, feature vectors.
     */
    public INDArray getParameters() {
        final int length = getNumParameters();
        final Iterator[] sources = new Iterator[]{
                binaryTransform.values().iterator(),
                binaryClassification.values().iterator(),
                binaryINd4j.values().iterator(),
                unaryClassification.values().iterator(),
                featureVectors.values().iterator()};
        return Nd4j.toFlattened(length, sources);
    }

    /**
     * Scales each accumulated derivative and adds its L2 regularization gradient. The result is
     * {@code D = scale * D + regCost * P} for every parameter matrix P, stored back into
     * {@code derivatives}.
     *
     * <p>Parameters are NOT modified. The previous BLAS {@code scal(regCost, P)} scaled P in place,
     * corrupting the model and the cost below.
     *
     * @return the regularization cost {@code sum(P .* P) * regCost / 2} over all matrices
     */
    double scaleAndRegularize(MultiDimensionalMap<String, String, INDArray> derivatives, MultiDimensionalMap<String, String, INDArray> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (MultiDimensionalMap.Entry entry : currentMatrices.entrySet()) {
            String first = (String) entry.getFirstKey();
            String second = (String) entry.getSecondKey();
            INDArray current = (INDArray) entry.getValue();
            INDArray derivative = (INDArray) derivatives.get(first, second);
            INDArray updated = derivative.mul((Number) scale).add(current.mul((Number) regCost));
            derivatives.put(first, second, updated);
            // Number rather than Double: element() may be Float for float-backed arrays.
            double sumOfSquares = ((Number) current.mul(current).sum(Integer.MAX_VALUE).element()).doubleValue();
            cost += sumOfSquares * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Scales each accumulated derivative and adds its L2 regularization gradient. The result is
     * {@code D = scale * D + regCost * P} for every parameter matrix P, stored back into
     * {@code derivatives}.
     *
     * <p>Parameters are NOT modified. The previous BLAS {@code scal(regCost, P)} scaled P in place,
     * corrupting the model and the cost below.
     *
     * @return the regularization cost {@code sum(P .* P) * regCost / 2} over all matrices
     */
    double scaleAndRegularize(Map<String, INDArray> derivatives, Map<String, INDArray> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (Map.Entry<String, INDArray> entry : currentMatrices.entrySet()) {
            INDArray current = entry.getValue();
            INDArray derivative = derivatives.get(entry.getKey());
            INDArray updated = derivative.mul((Number) scale).add(current.mul((Number) regCost));
            derivatives.put(entry.getKey(), updated);
            // Number rather than Double: element() may be Float for float-backed arrays.
            double sumOfSquares = ((Number) current.mul(current).sum(Integer.MAX_VALUE).element()).doubleValue();
            cost += sumOfSquares * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Tensor variant of scaleAndRegularize. For every tensor P it stores
     * {@code D = scale * D + regCost * P} into {@code derivatives}.
     *
     * <p>Uses non-mutating {@code mul}. The previous {@code P.muli(regCost)} multiplied the live model
     * tensor by regCost on every call, and the cost was then computed from that shrunken tensor.
     *
     * @return the regularization cost {@code sum(P .* P) * regCost / 2} over all tensors
     */
    double scaleAndRegularizeINDArray(MultiDimensionalMap<String, String, INDArray> derivatives, MultiDimensionalMap<String, String, INDArray> currentMatrices, double scale, double regCost) {
        double cost = 0.0;
        for (MultiDimensionalMap.Entry entry : currentMatrices.entrySet()) {
            String first = (String) entry.getFirstKey();
            String second = (String) entry.getSecondKey();
            INDArray current = (INDArray) entry.getValue();
            INDArray derivative = (INDArray) derivatives.get(first, second);
            INDArray updated = derivative.mul((Number) scale).add(current.mul((Number) regCost));
            derivatives.put(first, second, updated);
            double sumOfSquares = ((Number) current.mul(current).sum(Integer.MAX_VALUE).element()).doubleValue();
            cost += sumOfSquares * regCost / 2.0;
        }
        return cost;
    }

    /**
     * Entry point for backpropagation from the root of {@code tree}. The root receives no error
     * from a parent, so the recursion starts from a freshly created numHidden x 1 delta.
     */
    private void backpropDerivativesAndError(Tree tree, MultiDimensionalMap<String, String, INDArray> binaryTD, MultiDimensionalMap<String, String, INDArray> binaryCD, MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD, Map<String, INDArray> unaryCD, Map<String, INDArray> wordVectorD) {
        backpropDerivativesAndError(tree, binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD,
                Nd4j.create(this.numHidden, 1));
    }

    /**
     * Recursively backpropagates classification error through {@code tree}. Gradients are accumulated
     * into the derivative maps, and each non-leaf node's (weighted) cross-entropy error is recorded.
     *
     * <p>Fixes relative to the previous version:
     * <ul>
     *   <li>the tanh-layer derivative is taken with {@code applyDerivative} (as already done for the
     *       current node of binary nodes) instead of re-applying the activation;</li>
     *   <li>the gradient of W uses the full delta (class error plus error from the parent);</li>
     *   <li>the gold-label assertion rejects {@code goldClass == numOuts}.</li>
     * </ul>
     *
     * <p>NOTE(review): getValueGradient invokes this via Parallelization.iterateInParallel on shared
     * TreeMaps with get-then-put updates. If those calls really run concurrently, the accumulation
     * races; confirm.
     *
     * @param deltaUp error arriving from the parent (numHidden x 1)
     */
    private void backpropDerivativesAndError(Tree tree, MultiDimensionalMap<String, String, INDArray> binaryTD, MultiDimensionalMap<String, String, INDArray> binaryCD, MultiDimensionalMap<String, String, INDArray> binaryINDArrayTD, Map<String, INDArray> unaryCD, Map<String, INDArray> wordVectorD, INDArray deltaUp) {
        if (tree.isLeaf()) {
            return;
        }
        INDArray currentVector = tree.vector();
        String category = this.basicCategory(tree.label());

        // One-hot gold label; stays all-zero for unlabelled nodes (goldClass < 0).
        INDArray goldLabel = Nd4j.create(this.numOuts, 1);
        int goldClass = tree.goldLabel();
        if (goldClass >= 0) {
            assert goldClass < this.numOuts : "Tried adding a label that was >= to the number of configured outputs " + this.numOuts + " with label " + goldClass;
            goldLabel.putScalar(goldClass, 1.0f);
        }
        Float configuredWeight = this.classWeights.get(goldClass);
        float nodeWeight = configuredWeight == null ? 1.0f : configuredWeight.floatValue();

        INDArray predictions = tree.prediction();
        INDArray deltaClass = goldClass >= 0
                ? Nd4j.getBlasWrapper().scal(nodeWeight, predictions.sub(goldLabel))
                : Nd4j.create(predictions.rows(), predictions.columns());
        INDArray localCD = deltaClass.mmul(Nd4j.appendBias(new INDArray[]{currentVector}).transpose());
        // Number rather than Double: element() may be Float for float-backed arrays.
        double error = -((Number) Transforms.log(predictions).muli(goldLabel).sum(Integer.MAX_VALUE).element()).doubleValue();
        error *= (double) nodeWeight;
        tree.setError(error);

        NDArrayIndex[] hiddenRows = new NDArrayIndex[]{NDArrayIndex.interval(0, this.numHidden), NDArrayIndex.interval(0, 1)};

        if (tree.isPreTerminal()) {
            unaryCD.put(category, unaryCD.get(category).add(localCD));
            String word = this.getVocabWord(tree.children().get(0).label());
            INDArray currentVectorDerivative = this.activationFunction.applyDerivative(currentVector);
            INDArray deltaFromClass = this.getUnaryClassification(category).transpose().mmul(deltaClass);
            deltaFromClass = deltaFromClass.get(hiddenRows).mul(currentVectorDerivative);
            INDArray deltaFull = deltaFromClass.add(deltaUp);
            wordVectorD.put(word, wordVectorD.get(word).add(deltaFull));
            return;
        }

        String leftCategory = this.basicCategory(tree.children().get(0).label());
        String rightCategory = this.basicCategory(tree.children().get(1).label());
        if (this.combineClassification) {
            unaryCD.put("", unaryCD.get("").add(localCD));
        } else {
            INDArray accumulatedCD = (INDArray) binaryCD.get(leftCategory, rightCategory);
            binaryCD.put(leftCategory, rightCategory, accumulatedCD.add(localCD));
        }

        INDArray currentVectorDerivative = this.activationFunction.applyDerivative(currentVector);
        INDArray deltaFromClass = this.getBinaryClassification(leftCategory, rightCategory).transpose().mmul(deltaClass);
        deltaFromClass = deltaFromClass.get(hiddenRows).mul(currentVectorDerivative);
        INDArray deltaFull = deltaFromClass.add(deltaUp);

        INDArray leftVector = tree.children().get(0).vector();
        INDArray rightVector = tree.children().get(1).vector();
        INDArray childrenVector = Nd4j.appendBias(new INDArray[]{leftVector, rightVector});
        // dE/dW for W * [left; right; 1] uses the full node error, not just the class error.
        INDArray W_df = deltaFull.mmul(childrenVector.transpose());
        INDArray accumulatedTD = (INDArray) binaryTD.get(leftCategory, rightCategory);
        binaryTD.put(leftCategory, rightCategory, accumulatedTD.add(W_df));

        INDArray deltaDown;
        if (this.useFloatTensors) {
            INDArray Wt_df = this.getINDArrayGradient(deltaFull, leftVector, rightVector);
            INDArray accumulatedTensorD = (INDArray) binaryINDArrayTD.get(leftCategory, rightCategory);
            binaryINDArrayTD.put(leftCategory, rightCategory, accumulatedTensorD.add(Wt_df));
            deltaDown = this.computeINDArrayDeltaDown(deltaFull, leftVector, rightVector, this.getBinaryTransform(leftCategory, rightCategory), this.getBinaryINDArray(leftCategory, rightCategory));
        } else {
            deltaDown = this.getBinaryTransform(leftCategory, rightCategory).transpose().mmul(deltaFull);
        }

        // Chain rule through the children's activation: use the derivative, not the activation itself.
        INDArray leftDerivative = this.activationFunction.applyDerivative(leftVector);
        INDArray rightDerivative = this.activationFunction.applyDerivative(rightVector);
        int rows = deltaFull.rows();
        INDArray leftDeltaDown = deltaDown.get(new NDArrayIndex[]{NDArrayIndex.interval(0, rows), NDArrayIndex.interval(0, 1)});
        INDArray rightDeltaDown = deltaDown.get(new NDArrayIndex[]{NDArrayIndex.interval(rows, rows * 2), NDArrayIndex.interval(0, 1)});
        this.backpropDerivativesAndError(tree.children().get(0), binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, leftDerivative.mul(leftDeltaDown));
        this.backpropDerivativesAndError(tree.children().get(1), binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD, rightDerivative.mul(rightDeltaDown));
    }

    /**
     * Computes the error passed down to the children of a tensor-composed node.
     *
     * <p>The matrix part is W^T * deltaFull with the trailing bias row dropped. The tensor part is
     * the sum over k of (Wt[k] + Wt[k]^T) * (deltaFull[k] * v), where v = [left; right].
     *
     * @param W  composition matrix, assumed rows x (2*rows + 1) as built by randomTransformMatrix
     * @param Wt composition tensor, sliced along its first dimension
     * @return a (2*rows) x 1 vector: left child's delta then right child's delta
     */
    private INDArray computeINDArrayDeltaDown(INDArray deltaFull, INDArray leftVector, INDArray rightVector, INDArray W, INDArray Wt) {
        int rows = deltaFull.rows();
        INDArray WTDelta = W.transpose().mmul(deltaFull);
        // WTDelta is a (2*rows + 1) x 1 column; keep the first 2*rows entries (drop the bias).
        // The matrix branch previously sliced [0:1, 0:2*rows+1], i.e. one row and too many columns.
        INDArray WTDeltaNoBias = WTDelta.isMatrix()
                ? WTDelta.get(new NDArrayIndex[]{NDArrayIndex.interval(0, rows * 2), NDArrayIndex.interval(0, 1)})
                : WTDelta.get(new NDArrayIndex[]{NDArrayIndex.interval(0, rows * 2)});
        int size = deltaFull.length();
        INDArray deltaINDArray = Nd4j.create(size * 2, 1);
        INDArray fullVector = Nd4j.concat(0, new INDArray[]{leftVector, rightVector});
        for (int slice = 0; slice < size; ++slice) {
            // mul() does not touch fullVector; BLAS scal() would rescale it in place on every slice.
            INDArray scaledFullVector = fullVector.mul((Number) deltaFull.getDouble(slice));
            INDArray sliceW = Wt.slice(slice);
            deltaINDArray = deltaINDArray.add(sliceW.add(sliceW.transpose()).mmul(scaledFullVector));
        }
        return deltaINDArray.add(WTDeltaNoBias);
    }

    /**
     * Computes the vector and class prediction for every non-leaf node of {@code tree}, bottom-up,
     * storing them on the nodes via setVector/setPrediction.
     *
     * <p>Pre-terminal nodes take the activation of their word's feature vector. Binary nodes compose
     * their children with W (and the tensor when enabled) before the activation.
     *
     * @throws AssertionError on leaves, on non-preterminal unary nodes, or on non-binary nodes
     */
    public void forwardPropagateTree(Tree tree) {
        if (tree.isLeaf()) {
            throw new AssertionError("We should not have reached leaves in forwardPropagate");
        }

        final INDArray classifier;
        final INDArray nodeVector;

        if (tree.isPreTerminal()) {
            classifier = getUnaryClassification(tree.label());
            INDArray wordVector = getFeatureVector(tree.children().get(0).value());
            if (wordVector == null) {
                wordVector = featureVectors.get(UNKNOWN_FEATURE);
            }
            nodeVector = (INDArray) activationFunction.apply(wordVector);
        } else {
            final int arity = tree.children().size();
            if (arity == 1) {
                throw new AssertionError("Non-preterminal nodes of size 1 should have already been collapsed");
            }
            if (arity != 2) {
                throw new AssertionError("Tree not correctly binarized");
            }
            Tree first = tree.firstChild();
            Tree last = tree.lastChild();
            forwardPropagateTree(first);
            forwardPropagateTree(last);

            String leftLabel = tree.children().get(0).label();
            String rightLabel = tree.children().get(1).label();
            INDArray transform = getBinaryTransform(leftLabel, rightLabel);
            classifier = getBinaryClassification(leftLabel, rightLabel);
            INDArray leftVec = tree.children().get(0).vector();
            INDArray rightVec = tree.children().get(1).vector();
            INDArray linear = transform.mmul(Nd4j.appendBias(new INDArray[]{leftVec, rightVec}));

            if (!useFloatTensors) {
                nodeVector = (INDArray) activationFunction.apply(linear);
            } else {
                INDArray tensor = getBinaryINDArray(leftLabel, rightLabel);
                INDArray stacked = Nd4j.concat(0, new INDArray[]{leftVec, rightVec});
                INDArray bilinear = Nd4j.bilinearProducts(tensor, stacked);
                nodeVector = (INDArray) activationFunction.apply(linear.add(bilinear));
            }
        }

        INDArray biased = Nd4j.appendBias(new INDArray[]{nodeVector});
        if (biased.rows() != classifier.columns()) {
            biased = biased.transpose();
        }
        INDArray predictions = (INDArray) outputActivation.apply(classifier.mmul(biased));
        tree.setPrediction(predictions);
        tree.setVector(nodeVector);
    }

    /**
     * Computes the tensor gradient for one node: slice k is deltaFull[k] * v * v^T,
     * where v = [left; right].
     *
     * <p>The result is laid out like the tensor from randomBinaryINDArray: [size, 2*size, 2*size],
     * one (2*size x 2*size) slice per output unit along dimension 0. The previous shape
     * {2*size, 2*size, size} did not match the slices written into it. Scaling uses non-mutating
     * {@code mul} so fullVector is not rescaled cumulatively.
     */
    private INDArray getFloatTensorGradient(INDArray deltaFull, INDArray leftVector, INDArray rightVector) {
        int size = deltaFull.length();
        INDArray Wt_df = Nd4j.create(new int[]{size, size * 2, size * 2});
        INDArray fullVector = Nd4j.concat(0, new INDArray[]{leftVector, rightVector});
        INDArray fullVectorT = fullVector.transpose();
        for (int slice = 0; slice < size; ++slice) {
            INDArray scaled = fullVector.mul((Number) deltaFull.getDouble(slice));
            Wt_df.putSlice(slice, scaled.mmul(fullVectorT));
        }
        return Wt_df;
    }

    /** Forward-propagates each tree and returns the root predictions in input order. */
    public List<INDArray> output(List<Tree> trees) {
        List<INDArray> rootPredictions = new ArrayList<INDArray>();
        Iterator<Tree> it = trees.iterator();
        while (it.hasNext()) {
            Tree tree = it.next();
            forwardPropagateTree(tree);
            rootPredictions.add(tree.prediction());
        }
        return rootPredictions;
    }

    /**
     * Forward-propagates each tree and returns, in input order, the index chosen by BLAS iamax
     * on its root prediction.
     */
    public List<Integer> predict(List<Tree> trees) {
        List<Integer> labels = new ArrayList<Integer>();
        Iterator<Tree> it = trees.iterator();
        while (it.hasNext()) {
            Tree tree = it.next();
            forwardPropagateTree(tree);
            int best = Nd4j.getBlasWrapper().iamax(tree.prediction());
            labels.add(best);
        }
        return labels;
    }

    /**
     * Writes a flat parameter vector back into the model. It uses the same order as
     * getParameters(): binary transforms, binary classifiers, binary tensors, unary classifiers,
     * feature vectors.
     */
    public void setParameters(INDArray params) {
        this.setParams(params,
                binaryTransform.values().iterator(),
                binaryClassification.values().iterator(),
                binaryINd4j.values().iterator(),
                unaryClassification.values().iterator(),
                featureVectors.values().iterator());
    }

    public INDArray getValueGradient(int iterations) {
        int numCols;
        int numRows;
        final MultiDimensionalMap binaryTD = MultiDimensionalMap.newTreeBackedMap();
        final MultiDimensionalMap binaryINDArrayTD = MultiDimensionalMap.newTreeBackedMap();
        final MultiDimensionalMap binaryCD = MultiDimensionalMap.newTreeBackedMap();
        final TreeMap<String, INDArray> unaryCD = new TreeMap<String, INDArray>();
        final TreeMap<String, INDArray> wordVectorD = new TreeMap<String, INDArray>();
        for (MultiDimensionalMap.Entry entry : this.binaryTransform.entrySet()) {
            numRows = ((INDArray)entry.getValue()).rows();
            numCols = ((INDArray)entry.getValue()).columns();
            binaryTD.put(entry.getFirstKey(), entry.getSecondKey(), (Object)Nd4j.create((int)numRows, (int)numCols));
        }
        if (!this.combineClassification) {
            for (MultiDimensionalMap.Entry entry : this.binaryClassification.entrySet()) {
                numRows = ((INDArray)entry.getValue()).rows();
                numCols = ((INDArray)entry.getValue()).columns();
                binaryCD.put(entry.getFirstKey(), entry.getSecondKey(), (Object)Nd4j.create((int)numRows, (int)numCols));
            }
        }
        if (this.useFloatTensors) {
            for (MultiDimensionalMap.Entry entry : this.binaryINd4j.entrySet()) {
                numRows = ((INDArray)entry.getValue()).size(1);
                numCols = ((INDArray)entry.getValue()).size(2);
                int numSlices = ((INDArray)entry.getValue()).slices();
                binaryINDArrayTD.put(entry.getFirstKey(), entry.getSecondKey(), (Object)Nd4j.create((int[])new int[]{numRows, numCols, numSlices}));
            }
        }
        for (Map.Entry entry : this.unaryClassification.entrySet()) {
            numRows = ((INDArray)entry.getValue()).rows();
            numCols = ((INDArray)entry.getValue()).columns();
            unaryCD.put((String)entry.getKey(), Nd4j.create((int)numRows, (int)numCols));
        }
        for (Map.Entry entry : this.featureVectors.entrySet()) {
            numRows = ((INDArray)entry.getValue()).rows();
            numCols = ((INDArray)entry.getValue()).columns();
            wordVectorD.put((String)entry.getKey(), Nd4j.create((int)numRows, (int)numCols));
        }
        final CopyOnWriteArrayList forwardPropTrees = new CopyOnWriteArrayList();
        Parallelization.iterateInParallel(this.trainingTrees, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Tree>(){

            public void run(Tree currentItem, Object[] args) {
                Tree trainingTree = new Tree(currentItem);
                trainingTree.connect(new ArrayList<Tree>(currentItem.children()));
                RNTN.this.forwardPropagateTree(trainingTree);
                forwardPropTrees.add(trainingTree);
            }
        }, (ActorSystem)this.rnTnActorSystem);
        final AtomicDouble atomicDouble = new AtomicDouble(0.0);
        Parallelization.iterateInParallel(forwardPropTrees, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Tree>(){

            public void run(Tree currentItem, Object[] args) {
                RNTN.this.backpropDerivativesAndError(currentItem, (MultiDimensionalMap<String, String, INDArray>)binaryTD, (MultiDimensionalMap<String, String, INDArray>)binaryCD, (MultiDimensionalMap<String, String, INDArray>)binaryINDArrayTD, unaryCD, wordVectorD);
                atomicDouble.addAndGet(currentItem.errorSum());
            }
        }, (Parallelization.RunnableWithParams)new Parallelization.RunnableWithParams<Tree>(){

            public void run(Tree currentItem, Object[] args) {
            }
        }, (ActorSystem)this.rnTnActorSystem, (Object[])new Object[]{binaryTD, binaryCD, binaryINDArrayTD, unaryCD, wordVectorD});
        double scale = 1.0f / (float)this.trainingTrees.size();
        this.value = atomicDouble.doubleValue() * scale;
        this.value += this.scaleAndRegularize((MultiDimensionalMap<String, String, INDArray>)binaryTD, this.binaryTransform, scale, this.regTransformMatrix);
        this.value += this.scaleAndRegularize((MultiDimensionalMap<String, String, INDArray>)binaryCD, this.binaryClassification, scale, this.regClassification);
        this.value += this.scaleAndRegularizeINDArray((MultiDimensionalMap<String, String, INDArray>)binaryINDArrayTD, this.binaryINd4j, scale, this.regTransformINDArray);
        this.value += this.scaleAndRegularize(unaryCD, this.unaryClassification, scale, this.regClassification);
        this.value += this.scaleAndRegularize(wordVectorD, this.featureVectors, scale, this.regWordVector);
        INDArray derivative = Nd4j.toFlattened((int)this.getNumParameters(), (Iterator[])new Iterator[]{binaryTD.values().iterator(), binaryCD.values().iterator(), binaryINDArrayTD.values().iterator(), unaryCD.values().iterator(), wordVectorD.values().iterator()});
        if (this.paramAdaGrad == null) {
            this.paramAdaGrad = new AdaGrad(1, derivative.columns());
        }
        derivative.muli(this.paramAdaGrad.getLearningRates(derivative));
        return derivative;
    }

    /**
     * Returns the objective value currently held by this model.
     *
     * <p>This is whatever was last stored in the {@code value} field. During gradient evaluation
     * that field is set to the scaled error sum plus the regularization terms.
     *
     * @return the last recorded objective value (0.0 if nothing has been computed yet)
     */
    public double getValue() {
        final double current = value;
        return current;
    }

    /**
     * Fluent builder for {@link RNTN} instances.
     *
     * <p>Every setter returns this builder so calls can be chained. {@link #build()} passes the
     * collected settings to the {@code RNTN} constructor.
     */
    public static class Builder {
        private int numHidden;
        private RandomGenerator rng;
        private boolean useINd4j;
        private boolean combineClassification = true;
        private boolean simplifiedModel = true;
        private boolean randomFeatureVectors;
        private double scalingForInit = 0.001f;
        private boolean lowerCasefeatureNames;
        private ActivationFunction activationFunction = Activations.sigmoid();
        // null means "not configured": build() then leaves the RNTN's own default output
        // activation untouched instead of silently replacing it.
        private ActivationFunction outputActivationFunction;
        private int adagradResetFrequency;
        private double regTransformINDArray;
        private Map<String, INDArray> featureVectors;
        private int numBinaryMatrices;
        private int binaryTransformSize;
        private int binaryINd4jize;
        private int binaryClassificationSize;
        private int numUnaryMatrices;
        private int unaryClassificationSize;
        private Map<Integer, Float> classWeights;

        /**
         * Sets the activation function used for the model's output.
         *
         * <p>It is applied to the built model's {@code outputActivation} field. If this method is
         * never called, the model keeps its own default.
         *
         * @param outputActivationFunction the output activation to use
         * @return this builder
         */
        public Builder withOutputActivation(ActivationFunction outputActivationFunction) {
            this.outputActivationFunction = outputActivationFunction;
            return this;
        }

        /**
         * Uses the word vectors of a trained {@link Word2Vec} model as feature vectors.
         *
         * <p>For every word in the model's vocabulary, the word's vector is added to a new
         * word-to-vector map. The hidden size is taken from the model's layer size.
         *
         * @param vec the trained word2vec model to read vectors from
         * @return this builder
         */
        public Builder setFeatureVectors(Word2Vec vec) {
            // The previous implementation called setFeatureVectors(vec), which resolved to this
            // same overload and recursed until StackOverflowError. Convert to the Map form instead.
            Map<String, INDArray> vectors = new HashMap<String, INDArray>();
            for (String word : vec.vocab().words()) {
                vectors.put(word, vec.getWordVectorMatrix(word));
            }
            this.setFeatureVectors(vectors);
            this.numHidden = vec.getLayerSize();
            return this;
        }

        /**
         * Sets the hidden-layer size.
         *
         * @param numHidden hidden size passed to the RNTN constructor
         * @return this builder
         */
        public Builder setNumHidden(int numHidden) {
            this.numHidden = numHidden;
            return this;
        }

        /**
         * Sets the random generator handed to the RNTN constructor.
         *
         * @param rng random generator; left {@code null} if never set
         * @return this builder
         */
        public Builder setRng(RandomGenerator rng) {
            this.rng = rng;
            return this;
        }

        /**
         * Sets whether tensors are used.
         *
         * <p>The value is passed as the constructor's tensor flag. The builder default is
         * {@code false}.
         *
         * @param useINd4j tensor flag
         * @return this builder
         */
        public Builder setUseTensors(boolean useINd4j) {
            this.useINd4j = useINd4j;
            return this;
        }

        /**
         * Sets the combine-classification flag (builder default {@code true}).
         *
         * @param combineClassification flag passed to the RNTN constructor
         * @return this builder
         */
        public Builder setCombineClassification(boolean combineClassification) {
            this.combineClassification = combineClassification;
            return this;
        }

        /**
         * Sets the simplified-model flag (builder default {@code true}).
         *
         * @param simplifiedModel flag passed to the RNTN constructor
         * @return this builder
         */
        public Builder setSimplifiedModel(boolean simplifiedModel) {
            this.simplifiedModel = simplifiedModel;
            return this;
        }

        /**
         * Sets the random-feature-vectors flag (builder default {@code false}).
         *
         * @param randomFeatureVectors flag passed to the RNTN constructor
         * @return this builder
         */
        public Builder setRandomFeatureVectors(boolean randomFeatureVectors) {
            this.randomFeatureVectors = randomFeatureVectors;
            return this;
        }

        /**
         * Sets the initialization scaling factor (builder default {@code 0.001}).
         *
         * @param scalingForInit scaling factor passed to the RNTN constructor
         * @return this builder
         */
        public Builder setScalingForInit(double scalingForInit) {
            this.scalingForInit = scalingForInit;
            return this;
        }

        /**
         * Sets the lower-case-feature-names flag.
         *
         * @param lowerCasefeatureNames flag passed to the RNTN constructor
         * @return this builder
         */
        public Builder setLowerCasefeatureNames(boolean lowerCasefeatureNames) {
            this.lowerCasefeatureNames = lowerCasefeatureNames;
            return this;
        }

        /**
         * Sets the activation function (builder default {@code Activations.sigmoid()}).
         *
         * @param activationFunction activation passed to the RNTN constructor
         * @return this builder
         */
        public Builder setActivationFunction(ActivationFunction activationFunction) {
            this.activationFunction = activationFunction;
            return this;
        }

        /**
         * Sets the AdaGrad reset frequency.
         *
         * @param adagradResetFrequency value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setAdagradResetFrequency(int adagradResetFrequency) {
            this.adagradResetFrequency = adagradResetFrequency;
            return this;
        }

        /**
         * Sets the regularization strength for the tensor parameters.
         *
         * @param regTransformINDArray value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setRegTransformINDArray(double regTransformINDArray) {
            this.regTransformINDArray = regTransformINDArray;
            return this;
        }

        /**
         * Sets the word-to-vector feature map.
         *
         * <p>The map reference is stored as-is; it is not copied.
         *
         * @param featureVectors word-to-vector map passed to the RNTN constructor
         * @return this builder
         */
        public Builder setFeatureVectors(Map<String, INDArray> featureVectors) {
            this.featureVectors = featureVectors;
            return this;
        }

        /**
         * Sets the number of binary matrices.
         *
         * @param numBinaryMatrices value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setNumBinaryMatrices(int numBinaryMatrices) {
            this.numBinaryMatrices = numBinaryMatrices;
            return this;
        }

        /**
         * Sets the binary transform size.
         *
         * @param binaryTransformSize value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setBinaryTransformSize(int binaryTransformSize) {
            this.binaryTransformSize = binaryTransformSize;
            return this;
        }

        /**
         * Sets the binary tensor size.
         *
         * @param binaryINd4jize value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setBinaryINd4jize(int binaryINd4jize) {
            this.binaryINd4jize = binaryINd4jize;
            return this;
        }

        /**
         * Sets the binary classification size.
         *
         * @param binaryClassificationSize value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setBinaryClassificationSize(int binaryClassificationSize) {
            this.binaryClassificationSize = binaryClassificationSize;
            return this;
        }

        /**
         * Sets the number of unary matrices.
         *
         * @param numUnaryMatrices value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setNumUnaryMatrices(int numUnaryMatrices) {
            this.numUnaryMatrices = numUnaryMatrices;
            return this;
        }

        /**
         * Sets the unary classification size.
         *
         * @param unaryClassificationSize value passed to the RNTN constructor
         * @return this builder
         */
        public Builder setUnaryClassificationSize(int unaryClassificationSize) {
            this.unaryClassificationSize = unaryClassificationSize;
            return this;
        }

        /**
         * Sets per-class weights.
         *
         * <p>The map reference is stored as-is; it is not copied.
         *
         * @param classWeights class-index-to-weight map passed to the RNTN constructor
         * @return this builder
         */
        public Builder setClassWeights(Map<Integer, Float> classWeights) {
            this.classWeights = classWeights;
            return this;
        }

        /**
         * Creates the {@link RNTN} from the collected settings.
         *
         * <p>If {@link #withOutputActivation(ActivationFunction)} was called, its value is applied
         * to the new model's {@code outputActivation}. Previously that setting was silently
         * discarded.
         *
         * @return a new RNTN
         */
        public RNTN build() {
            RNTN rntn = new RNTN(this.numHidden, this.rng, this.useINd4j, this.combineClassification, this.simplifiedModel, this.randomFeatureVectors, this.scalingForInit, this.lowerCasefeatureNames, this.activationFunction, this.adagradResetFrequency, this.regTransformINDArray, this.featureVectors, this.numBinaryMatrices, this.binaryTransformSize, this.binaryINd4jize, this.binaryClassificationSize, this.numUnaryMatrices, this.unaryClassificationSize, this.classWeights);
            // The constructor has no output-activation parameter. Builder is nested in RNTN, so it
            // may set the protected field directly. Only override when explicitly configured.
            if (this.outputActivationFunction != null) {
                rntn.outputActivation = this.outputActivationFunction;
            }
            return rntn;
        }
    }
}

