/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.featuredetectors.autoencoder;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.models.featuredetectors.autoencoder.AutoEncoder;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.OutputLayer;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Deep auto-encoder network for semantic hashing, assembled by {@link Builder} from an existing
 * encoder network. The encoder's stacked {@link AutoEncoder}s are copied and then mirrored with
 * transposed weights to form a decoder. The network is closed with an {@link OutputLayer} whose
 * width matches the encoder's first-layer input.
 *
 * <p>Layer-wise pretraining is not supported. {@link #finetune(INDArray)} installs the given
 * matrix as both input and labels, so the network is trained to reconstruct its input.
 *
 * <p>Several {@code BaseMultiLayerNetwork} hooks are stubs in this class: {@code score},
 * {@code numLabels}, {@code labelProbabilities}, and two of the {@code fit} overloads. See their
 * individual documentation.
 */
public class SemanticHashing extends BaseMultiLayerNetwork {
    private static final long serialVersionUID = -3571832097247806784L;

    /**
     * Encoder this network is associated with. Only ever assigned through
     * {@link #setEncoder(BaseMultiLayerNetwork)}; {@link Builder#build()} does not populate it.
     */
    private BaseMultiLayerNetwork encoder;

    private static final Logger log = LoggerFactory.getLogger(SemanticHashing.class);

    /**
     * Layer-wise pretraining is not supported by this network.
     *
     * @throws IllegalStateException always
     */
    @Override
    public void pretrain(INDArray input, Object[] otherParams) {
        // NOTE(review): the DataSetIterator overload throws UnsupportedOperationException instead;
        // the exception type is kept for callers that may catch IllegalStateException.
        throw new IllegalStateException("Not implemented");
    }

    /**
     * Layers are created by {@link Builder#build()}, not through this factory hook.
     *
     * @throws IllegalStateException always
     */
    @Override
    public NeuralNetwork createLayer(INDArray input, INDArray W, INDArray hbias, INDArray vBias, int index) {
        throw new IllegalStateException("Not implemented");
    }

    /**
     * Back-propagates the R-activations for direction {@code v} and returns one matrix per layer:
     * each stacked auto-encoder, then the output layer, ordered by layer index. The R-activations
     * come from {@link #feedForwardR(List, INDArray)}. This appears to be the curvature-vector
     * product used for Gauss-Newton vector-product back-prop. NOTE(review): confirm against the
     * base class.
     *
     * <p>The running term {@code rix} starts as the last R-activation divided by the number of
     * input rows. For each layer {@code i}, from the top down, the matrix is
     * {@code activations[i]^T · rix}, and drop-connect is applied to it if enabled. A layer whose
     * configuration sets {@code constrainGradientToUnitNorm} has its matrix divided by its norm2.
     *
     * @param v packed direction vector; {@code unPack} splits it into per-layer (weight, bias) pairs
     * @return per-layer matrices, indexed like the network's layers
     */
    @Override
    public List<INDArray> computeDeltasR(INDArray v) {
        List<INDArray> deltaRet = new ArrayList<INDArray>();
        INDArray[] deltas = new INDArray[this.getnLayers() + 1];
        List<INDArray> activations = this.feedForward();
        List<INDArray> rActivations = this.feedForwardR(activations, v);
        List<INDArray> weights = new ArrayList<INDArray>();
        List<INDArray> biases = new ArrayList<INDArray>();
        List<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        this.collectLayerParameters(weights, biases, activationFunctions);

        // Seed with the output-layer R-activation, scaled by 1 / number of input rows.
        INDArray rix = rActivations.get(rActivations.size() - 1).div((Number) this.input.rows());
        for (int i = this.getnLayers(); i >= 0; --i) {
            deltas[i] = activations.get(i).transpose().mmul(rix);
            this.applyDropConnectIfNecessary(deltas[i]);
            if (i > 0) {
                // NOTE(review): unlike computeDeltas2, the bias row is added to the weights before
                // propagating down a layer here. Confirm this is intended.
                rix = rix.mmul(weights.get(i).addRowVector(biases.get(i)).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }
        }

        for (int i = 0; i < deltas.length; ++i) {
            if (((NeuralNetConfiguration) this.layerWiseConfigurations.get(i)).isConstrainGradientToUnitNorm()) {
                deltaRet.add(deltas[i].div(deltas[i].norm2(Integer.MAX_VALUE)));
            } else {
                deltaRet.add(deltas[i]);
            }
        }
        return deltaRet;
    }

    /**
     * Computes the R-activations for direction {@code v}, presumably the directional derivatives
     * of each layer's activation (Pearlmutter's R-operator).
     *
     * <p>The list starts with an all-zero matrix shaped like the input. For each auto-encoder
     * {@code i}, the next entry is
     * {@code (R[i]·W[i] + (acts[i]·(vW[i] (+row) vB[i]) + 1)) ⊙ f'(acts[i+1])}.
     * The final entry uses the output layer's weights, the second-to-last activation and the
     * output layer's activation derivative, with no {@code + 1} term.
     *
     * @param acts activations from {@link #feedForward()}, input first
     * @param v    packed direction vector; {@code unPack} splits it into per-layer (weight, bias) pairs
     * @return R-activations, one per entry of {@code acts} plus the output step
     */
    @Override
    public List<INDArray> feedForwardR(List<INDArray> acts, INDArray v) {
        List<INDArray> R = new ArrayList<INDArray>();
        R.add(Nd4j.zeros((int) this.input.rows(), (int) this.input.columns()));
        List<Pair<INDArray, INDArray>> vWvB = this.unPack(v);
        List<INDArray> W = this.weightMatrices();
        for (int i = 0; i < this.neuralNets.length; ++i) {
            AutoEncoder a = (AutoEncoder) this.getNeuralNets()[i];
            ActivationFunction derivative = a.conf().getActivationFunction();
            // NOTE(review): the "+ 1" below has no counterpart in the output-layer step. Confirm.
            R.add(R.get(i).mmul(W.get(i))
                    .add(acts.get(i).mmul(vWvB.get(i).getFirst().addRowVector(vWvB.get(i).getSecond())).add((Number) 1))
                    .mul(derivative.applyDerivative(acts.get(i + 1))));
        }
        int last = vWvB.size() - 1;
        R.add(R.get(R.size() - 1).mmul(W.get(W.size() - 1))
                .add(acts.get(acts.size() - 2).mmul(vWvB.get(last).getFirst().addRowVector(vWvB.get(last).getSecond())))
                .mul(this.getOutputLayer().conf().getActivationFunction().applyDerivative(acts.get(acts.size() - 1))));
        return R;
    }

    /**
     * Layer-wise pretraining is not supported by this network.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void pretrain(DataSetIterator iter, Object[] otherParams) {
        throw new UnsupportedOperationException("Not implemented");
    }

    /**
     * Returns an empty array of {@code numLayers} slots. The actual layers are filled in by
     * {@link Builder#build()}.
     */
    @Override
    public NeuralNetwork[] createNetworkLayers(int numLayers) {
        return new NeuralNetwork[numLayers];
    }

    /**
     * Fine-tunes the network to reconstruct {@code input}. The matrix is installed as both the
     * network input and the labels before delegating to the base implementation.
     *
     * @param input training examples, also used as reconstruction targets
     */
    @Override
    public void finetune(INDArray input) {
        this.input = input;
        this.setInput(input);
        this.setLabels(input);
        super.finetune(input);
    }

    /**
     * Back-propagates the output-minus-label error and returns, per layer, a pair of
     * (gradient, second term).
     *
     * <p>The error {@code ix} starts as {@code (output - labels) / rows(labels)}. For each layer
     * {@code i}, from the top down:
     * <ul>
     *   <li>the gradient is {@code activations[i]^T · ix};</li>
     *   <li>the second term is {@code (activations[i]^T)^2 · ix^2 * rows(labels)}, with
     *       element-wise squares. It is presumably a diagonal preconditioner estimate; confirm
     *       with the optimizer that consumes it.</li>
     * </ul>
     * Gradients of layers whose configuration sets {@code constrainGradientToUnitNorm} are
     * divided in place by their norm2.
     *
     * @return one (gradient, second term) pair per layer, indexed like the network's layers
     */
    @Override
    public List<Pair<INDArray, INDArray>> computeDeltas2() {
        List<Pair<INDArray, INDArray>> deltaRet = new ArrayList<Pair<INDArray, INDArray>>();
        List<INDArray> activations = this.feedForward();
        INDArray[] deltas = new INDArray[activations.size() - 1];
        INDArray[] preCons = new INDArray[activations.size() - 1];
        INDArray ix = activations.get(activations.size() - 1).sub(this.labels).divi((Number) this.labels.rows());
        List<INDArray> weights = new ArrayList<INDArray>();
        List<INDArray> biases = new ArrayList<INDArray>();
        List<ActivationFunction> activationFunctions = new ArrayList<ActivationFunction>();
        this.collectLayerParameters(weights, biases, activationFunctions);

        for (int i = weights.size() - 1; i >= 0; --i) {
            deltas[i] = activations.get(i).transpose().mmul(ix);
            preCons[i] = Transforms.pow((INDArray) activations.get(i).transpose(), (Number) 2)
                    .mmul(Transforms.pow((INDArray) ix.dup(), (Number) 2))
                    .muli((Number) this.labels.rows());
            this.applyDropConnectIfNecessary(deltas[i]);
            if (i > 0) {
                ix = ix.mmul(weights.get(i).transpose())
                        .muli(activationFunctions.get(i - 1).applyDerivative(activations.get(i)));
            }
        }

        for (int i = 0; i < deltas.length; ++i) {
            if (((NeuralNetConfiguration) this.layerWiseConfigurations.get(i)).isConstrainGradientToUnitNorm()) {
                deltaRet.add(new Pair<INDArray, INDArray>(deltas[i].divi(deltas[i].norm2(Integer.MAX_VALUE)), preCons[i]));
            } else {
                deltaRet.add(new Pair<INDArray, INDArray>(deltas[i], preCons[i]));
            }
        }
        return deltaRet;
    }

    /**
     * Appends, in layer order, the weight matrix, hidden bias and activation function of every
     * stacked auto-encoder, followed by those of the output layer.
     *
     * @throws ClassCastException if a stacked network is not an {@link AutoEncoder}
     */
    private void collectLayerParameters(List<INDArray> weights, List<INDArray> biases,
            List<ActivationFunction> activationFunctions) {
        NeuralNetwork[] nets = this.getNeuralNets();
        for (int j = 0; j < nets.length; ++j) {
            weights.add(nets[j].getW());
            biases.add(nets[j].gethBias());
            AutoEncoder a = (AutoEncoder) nets[j];
            activationFunctions.add(a.conf().getActivationFunction());
        }
        weights.add(this.getOutputLayer().getW());
        biases.add(this.getOutputLayer().getB());
        activationFunctions.add(this.getOutputLayer().conf().getActivationFunction());
    }

    /** Returns the encoder set via {@link #setEncoder}, or {@code null} if none was set. */
    public BaseMultiLayerNetwork getEncoder() {
        return this.encoder;
    }

    /** Associates an encoder network with this instance. */
    public void setEncoder(BaseMultiLayerNetwork encoder) {
        this.encoder = encoder;
    }

    /** Scoring is not implemented; always returns {@code 0.0}. */
    @Override
    public double score(org.nd4j.linalg.dataset.api.DataSet data) {
        return 0.0;
    }

    /** This network has no classification labels; always returns {@code 0}. */
    @Override
    public int numLabels() {
        return 0;
    }

    /** Label probabilities are not supported; always returns {@code null}. */
    @Override
    public INDArray labelProbabilities(INDArray examples) {
        return null;
    }

    /**
     * Wraps {@code examples} and {@code labels} in a {@link DataSet} and calls
     * {@link #fit(org.nd4j.linalg.dataset.api.DataSet)}.
     */
    @Override
    public void fit(INDArray examples, INDArray labels) {
        this.fit((org.nd4j.linalg.dataset.api.DataSet) new DataSet(examples, labels));
    }

    /**
     * Fine-tunes on {@code data}'s labels.
     *
     * <p>NOTE(review): the feature matrix assigned to {@code input} here is immediately replaced
     * by {@link #finetune(INDArray)}, which installs the labels as both input and target. To
     * train reconstruction, pass the same matrix as features and labels.
     */
    @Override
    public void fit(org.nd4j.linalg.dataset.api.DataSet data) {
        this.input = data.getFeatureMatrix();
        this.finetune(data.getLabels());
    }

    /**
     * Not supported: does nothing. NOTE(review): this is a silent no-op, so callers get no
     * indication that no training happened.
     */
    @Override
    public void fit(INDArray examples, int[] labels) {
    }

    /**
     * Not supported: does nothing. NOTE(review): this is a silent no-op, so callers get no
     * indication that no training happened.
     */
    @Override
    public void fit(INDArray data, Object[] params) {
    }

    /**
     * Builder that unrolls a pretrained encoder (set with {@link #withEncoder}) into a
     * {@link SemanticHashing} network. The configuration methods below delegate to the parent
     * builder and only narrow the return type to {@code Builder} so calls can be chained.
     */
    public static class Builder extends BaseMultiLayerNetwork.Builder<SemanticHashing> {
        private BaseMultiLayerNetwork encoder;

        public Builder() {
            this.clazz = SemanticHashing.class;
        }

        /**
         * Sets the encoder network to unroll. This is required before {@link #build()}.
         *
         * @param encoder network whose stacked auto-encoders and output loss function are reused
         * @return this builder
         */
        public Builder withEncoder(BaseMultiLayerNetwork encoder) {
            this.encoder = encoder;
            return this;
        }

        public Builder useGaussNewtonVectorProductBackProp(boolean useGaussNewtonVectorProductBackProp) {
            super.useGaussNewtonVectorProductBackProp(useGaussNewtonVectorProductBackProp);
            return this;
        }

        public Builder useDropConnection(boolean useDropConnect) {
            super.useDropConnection(useDropConnect);
            return this;
        }

        public Builder withVisibleBiasTransforms(Map<Integer, MatrixTransform> visibleBiasTransforms) {
            super.withVisibleBiasTransforms(visibleBiasTransforms);
            return this;
        }

        public Builder withHiddenBiasTransforms(Map<Integer, MatrixTransform> hiddenBiasTransforms) {
            super.withHiddenBiasTransforms(hiddenBiasTransforms);
            return this;
        }

        public Builder forceIterations() {
            super.forceIterations();
            return this;
        }

        public Builder disableBackProp() {
            super.disableBackProp();
            return this;
        }

        public Builder transformWeightsAt(int layer, MatrixTransform transform) {
            super.transformWeightsAt(layer, transform);
            return this;
        }

        public Builder transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            super.transformWeightsAt(transforms);
            return this;
        }

        public Builder hiddenLayerSizes(Integer[] hiddenLayerSizes) {
            super.hiddenLayerSizes(hiddenLayerSizes);
            return this;
        }

        public Builder hiddenLayerSizes(int[] hiddenLayerSizes) {
            super.hiddenLayerSizes(hiddenLayerSizes);
            return this;
        }

        public Builder layerWiseConfiguration(List<NeuralNetConfiguration> layerWiseConfiguration) {
            super.layerWiseConfiguration(layerWiseConfiguration);
            return this;
        }

        public Builder withInput(INDArray input) {
            super.withInput(input);
            return this;
        }

        public Builder withLabels(INDArray labels) {
            super.withLabels(labels);
            return this;
        }

        public Builder withClazz(Class<? extends BaseMultiLayerNetwork> clazz) {
            super.withClazz(clazz);
            return this;
        }

        @Override
        public SemanticHashing buildEmpty() {
            return (SemanticHashing) super.buildEmpty();
        }

        /**
         * Unrolls the configured encoder into a {@link SemanticHashing} network.
         *
         * <p>For an encoder with {@code n} stacked networks, this creates {@code 2n - 1}
         * auto-encoders:
         * <ul>
         *   <li>copies of encoder layers {@code 0..n-1}, with the last one switched to a linear
         *       activation;</li>
         *   <li>mirrors of encoder layers {@code n-1..1}, using transposed weights with the
         *       visible and hidden biases swapped.</li>
         * </ul>
         * An {@link OutputLayer} mirroring encoder layer 0 closes the network and takes the
         * encoder output layer's loss function.
         *
         * @return the assembled network
         * @throws IllegalStateException if no encoder was supplied, or it has no stacked networks
         */
        @Override
        public SemanticHashing build() {
            if (this.encoder == null) {
                throw new IllegalStateException("No encoder specified; call withEncoder(...) before build()");
            }
            NeuralNetwork[] encoderNets = this.encoder.getNeuralNets();
            int numEncoderNets = encoderNets.length;
            if (numEncoderNets == 0) {
                throw new IllegalStateException("Encoder has no layers to unroll");
            }
            NeuralNetwork[] autoEncoders = new NeuralNetwork[numEncoderNets * 2 - 1];
            // One hidden layer per auto-encoder, plus a trailing slot for the output layer.
            Layer[] hiddenLayers = new Layer[autoEncoders.length + 1];
            int inverseCount = numEncoderNets - 1;

            for (int i = 0; i < autoEncoders.length; ++i) {
                if (i < numEncoderNets) {
                    NeuralNetwork source = encoderNets[i];
                    // Copy the weights as well as the biases, so the encoder half starts from the
                    // same parameters that the decoder half mirrors (transposed) below.
                    AutoEncoder a = (AutoEncoder) new AutoEncoder.Builder()
                            .configure(source.conf().clone())
                            .withWeights(source.getW().dup())
                            .withVisibleBias(source.getvBias().dup())
                            .withHBias(source.gethBias().dup())
                            .build();
                    int nIn = a.getW().rows();
                    int nOut = a.getW().columns();
                    Layer h = this.encoder.getLayers()[i].clone();
                    h.setConfiguration(a.conf());
                    hiddenLayers[i] = h;
                    autoEncoders[i] = a;
                    hiddenLayers[i].setB(a.gethBias());
                    hiddenLayers[i].setW(a.getW());
                    hiddenLayers[i].conf().setnIn(nIn);
                    hiddenLayers[i].conf().setnOut(nOut);
                    autoEncoders[i].conf().setnIn(nIn);
                    autoEncoders[i].conf().setnOut(nOut);
                    if (i == numEncoderNets - 1) {
                        // Last encoder layer (middle of the unrolled network): linear activation.
                        a.conf().setActivationFunction(Activations.linear());
                    }
                } else {
                    NeuralNetwork mirrored = encoderNets[inverseCount];
                    NeuralNetConfiguration reverseConf = mirrored.conf().clone();
                    // Mirror layer: transposed weights, visible and hidden biases swapped.
                    AutoEncoder a = (AutoEncoder) new AutoEncoder.Builder()
                            .configure(reverseConf)
                            .withWeights(mirrored.getW().transpose())
                            .withVisibleBias(mirrored.gethBias().dup())
                            .withHBias(mirrored.getvBias().dup())
                            .build();
                    int nIn = a.getW().rows();
                    int nOut = a.getW().columns();
                    reverseConf.setnIn(nIn);
                    reverseConf.setnOut(nOut);
                    autoEncoders[i] = a;
                    hiddenLayers[i] = this.encoder.getLayers()[inverseCount].transpose();
                    hiddenLayers[i].setConfiguration(reverseConf);
                    hiddenLayers[i].setB(a.gethBias());
                    hiddenLayers[i].setW(a.getW());
                    --inverseCount;
                }
            }

            NeuralNetwork first = encoderNets[0];
            // Clone the configuration and copy the bias. The setters below mutate o.conf(), which
            // could otherwise be the encoder's own first-layer configuration object if the
            // builder keeps it by reference.
            OutputLayer o = new OutputLayer.Builder()
                    .configure(first.conf().clone())
                    .withBias(first.getvBias().dup())
                    .withWeights(first.getW().transpose())
                    .build();
            o.conf().setLossFunction(this.encoder.getOutputLayer().conf().getLossFunction());
            o.conf().setActivationType(NeuralNetConfiguration.ActivationType.HIDDEN_LAYER_ACTIVATION);
            o.conf().setnIn(o.getW().rows());
            o.conf().setnOut(o.getW().columns());
            hiddenLayers[hiddenLayers.length - 1] = o;

            SemanticHashing e = new SemanticHashing();
            e.setLayers(hiddenLayers);
            e.setNeuralNets(autoEncoders);
            e.setDefaultConfiguration(this.conf);
            e.setUseDropConnect(this.encoder.isUseDropConnect());
            e.setUseGaussNewtonVectorProductBackProp(this.encoder.isUseGaussNewtonVectorProductBackProp());
            e.setSampleFromHiddenActivations(this.encoder.isSampleFromHiddenActivations());
            e.setForceNumEpochs(this.shouldForceEpochs);

            // Per-layer configurations are taken from the assembled layers; the first becomes the default.
            ArrayList<NeuralNetConfiguration> confs = new ArrayList<NeuralNetConfiguration>();
            for (int i = 0; i < e.layers.length; ++i) {
                confs.add(e.layers[i].conf());
            }
            e.setLayerWiseConfigurations(confs);
            e.setDefaultConfiguration(confs.get(0));
            e.dimensionCheck();
            return e;
        }
    }
}

