/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Output layer that accepts rank-3 (time series) activations and labels.
 *
 * <p>Internally the dense output-layer math of {@link BaseOutputLayer} operates on rank-2
 * matrices, so this class flattens rank-3 arrays to rank 2 before delegating to the superclass
 * and restores rank 3 on the way out (for outputs, activations, pre-outputs and the epsilon
 * returned by backprop).
 *
 * <p>NOTE(review): the rank-3 layout is presumably {@code [miniBatch, layerSize, timeSeriesLength]}
 * (this is what the permute/reshape pairs below are consistent with) — confirm against the
 * preprocessors feeding this layer.
 */
public class RnnOutputLayer
extends BaseOutputLayer<org.deeplearning4j.nn.conf.layers.RnnOutputLayer> {
    public RnnOutputLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    public RnnOutputLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Flattens a rank-3 array {@code [d0, d1, d2]} into a rank-2 array with
     * {@code d0 * d2} rows and {@code d1} columns (general path).
     *
     * @param in rank-3 array
     * @return rank-2 view or copy of {@code in}
     * @throws IllegalArgumentException if {@code in} is not rank 3
     */
    private INDArray reshape3dTo2d(INDArray in) {
        if (in.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 3");
        }
        int[] shape = in.shape();
        if (shape[0] == 1) {
            // Edge case d0 == 1: take the single [d1, d2] slice without a permute+reshape.
            // NOTE(review): the general path below would yield d2 x d1 here; confirm that this
            // ND4J version's tensorAlongDimension over {1, 2} returns the same orientation
            // (later DL4J versions transpose this result).
            return in.tensorAlongDimension(0, new int[]{1, 2});
        }
        if (shape[2] == 1) {
            // Edge case d2 == 1: a single time step, taken as a 2d tensor over dims {1, 0}.
            return in.tensorAlongDimension(0, new int[]{1, 0});
        }
        // Move dim 2 next to dim 0 so the reshape groups rows by (d0, d2) pairs,
        // which is exactly the order reshape2dTo3d undoes.
        INDArray permuted = in.permute(new int[]{0, 2, 1});
        return permuted.reshape(shape[0] * shape[2], shape[1]);
    }

    /**
     * Inverse of the general path of {@link #reshape3dTo2d}: reshapes {@code [rows, cols]} to
     * {@code [mb, rows / mb, cols]} and permutes to {@code [mb, cols, rows / mb]}, where
     * {@code mb} is {@link #getInputMiniBatchSize()}.
     *
     * @param in rank-2 array
     * @return rank-3 array
     * @throws IllegalArgumentException if {@code in} is not rank 2
     * @throws IllegalStateException if the row count cannot be split into the minibatch size
     */
    private INDArray reshape2dTo3d(INDArray in) {
        if (in.rank() != 2) {
            throw new IllegalArgumentException("Invalid input: expect NDArray with rank 2");
        }
        int[] shape = in.shape();
        int miniBatchSize = this.getInputMiniBatchSize();
        // Guard the integer division below: a zero minibatch size would throw a bare
        // ArithmeticException, and a non-divisor would only fail later inside reshape.
        if (miniBatchSize <= 0 || shape[0] % miniBatchSize != 0) {
            throw new IllegalStateException("Cannot reshape 2d array with " + shape[0]
                    + " rows to 3d: input minibatch size is " + miniBatchSize);
        }
        INDArray reshaped = in.reshape(new int[]{miniBatchSize, shape[0] / miniBatchSize, shape[1]});
        return reshaped.permute(new int[]{0, 2, 1});
    }

    /**
     * Runs the superclass backprop, then reshapes the returned 2d epsilon back to rank 3.
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        Pair<Gradient, INDArray> gradAndEpsilonNext = super.backpropGradient(epsilon);
        INDArray epsilon2d = gradAndEpsilonNext.getSecond();
        INDArray epsilon3d = this.reshape2dTo3d(epsilon2d);
        return new Pair<>(gradAndEpsilonNext.getFirst(), epsilon3d);
    }

    /** Returns the superclass output reshaped to rank 3. */
    @Override
    public INDArray output(boolean training) {
        INDArray output2d = super.output(training);
        return this.reshape2dTo3d(output2d);
    }

    /**
     * Computes the F1 score, flattening any rank-3 {@code examples} or {@code labels} to rank 2
     * first so the superclass can score them.
     */
    @Override
    public double f1Score(INDArray examples, INDArray labels) {
        if (examples.rank() == 3) {
            examples = this.reshape3dTo2d(examples);
        }
        if (labels.rank() == 3) {
            labels = this.reshape3dTo2d(labels);
        }
        return super.f1Score(examples, labels);
    }

    /** Returns the stored input array as-is (no reshaping). */
    @Override
    public INDArray getInput() {
        return this.input;
    }

    /** Returns the superclass activations reshaped to rank 3. */
    @Override
    public INDArray activate(boolean training) {
        INDArray activations2d = super.activate(training);
        return this.reshape2dTo3d(activations2d);
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }

    /** Computes the 2d pre-output for {@code x} and reshapes it to rank 3. */
    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.reshape2dTo3d(this.preOutput2d(x, training));
    }

    /**
     * Computes the pre-output as a rank-2 matrix, flattening a rank-3 {@code input} first.
     */
    @Override
    protected INDArray preOutput2d(INDArray input, boolean training) {
        if (input.rank() == 3) {
            input = this.reshape3dTo2d(input);
        }
        return super.preOutput(input, training);
    }

    /**
     * Returns the output for {@code input} flattened to rank 2; {@code output(...)} yields
     * rank 3 here, so it is flattened for superclass code that expects 2d output.
     */
    @Override
    protected INDArray output2d(INDArray input) {
        return this.reshape3dTo2d(this.output(input));
    }

    /** Returns the labels as a rank-2 matrix, flattening them if they are rank 3. */
    @Override
    protected INDArray getLabels2d() {
        if (this.labels.rank() == 3) {
            return this.reshape3dTo2d(this.labels);
        }
        return this.labels;
    }
}

