/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonProperty;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

/**
 * Input pre-processor that converts rank-3 recurrent (RNN) activations into rank-4 convolutional
 * (CNN) activations, and converts gradients back in {@link #backprop}.
 *
 * <p>{@link #preProcess} folds dimensions 0 and 2 of its input into a single leading dimension of
 * size {@code shape[0] * shape[2]}. It splits dimension 1 into
 * {@code [numChannels, inputHeight, inputWidth]}. {@link #backprop} undoes this, using the layer's
 * input mini-batch size to recover dimension 0. The rank-3 layout is presumably
 * {@code [miniBatchSize, numChannels * inputHeight * inputWidth, timeSeriesLength]} (TODO confirm
 * against callers).
 */
public class RnnToCnnPreProcessor implements InputPreProcessor {
    /** Height of each CNN input example. */
    private int inputHeight;
    /** Width of each CNN input example. */
    private int inputWidth;
    /** Number of channels of each CNN input example. */
    private int numChannels;
    /** Cached {@code inputHeight * inputWidth * numChannels}: the number of values per example. */
    private int product;

    /**
     * Creates a pre-processor producing CNN activations of the given per-example dimensions.
     *
     * @param inputHeight height of each CNN example
     * @param inputWidth width of each CNN example
     * @param numChannels number of channels of each CNN example
     */
    public RnnToCnnPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
        this.product = inputHeight * inputWidth * numChannels;
    }

    /**
     * Reshapes rank-3 RNN activations into rank-4 CNN activations.
     *
     * @param input rank-3 activations whose dimension 1 has size
     *     {@code numChannels * inputHeight * inputWidth}
     * @param layer the layer being fed; not used by this method
     * @return an array of shape
     *     {@code [shape[0] * shape[2], numChannels, inputHeight, inputWidth]}
     * @throws IllegalArgumentException if {@code input} is not rank 3, or if its dimension 1 does not
     *     equal {@code numChannels * inputHeight * inputWidth}
     */
    @Override
    public INDArray preProcess(INDArray input, Layer layer) {
        int[] shape = input.shape();
        if (shape.length != 3) {
            throw new IllegalArgumentException("Invalid input: expected rank 3 RNN activations, got array of rank " + shape.length);
        }
        if (shape[1] != this.product) {
            throw new IllegalArgumentException("Invalid input: dimension 1 has size " + shape[1]
                    + " but numChannels * inputHeight * inputWidth = " + this.numChannels + " * "
                    + this.inputHeight + " * " + this.inputWidth + " = " + this.product);
        }
        // Size of the leading (CNN batch) dimension of the result.
        int rows = shape[0] * shape[2];
        INDArray in2d;
        if (shape[0] == 1) {
            // Edge case: dimension 0 (presumably the mini-batch) has size 1, so take the single slice
            // directly instead of permuting and reshaping.
            // NOTE(review): relies on tensorAlongDimension laying the slice out so that the reshape
            // below yields the same element order as the general branch — confirm for this ND4J version.
            in2d = input.tensorAlongDimension(0, new int[]{1, 2});
        } else if (shape[2] == 1) {
            // Edge case: dimension 2 (presumably the time dimension) has size 1.
            // NOTE(review): same orientation assumption as the branch above.
            in2d = input.tensorAlongDimension(0, new int[]{1, 0});
        } else {
            // Move dimension 2 next to dimension 0 so that the 2d reshape groups rows by (dim 0, dim 2).
            INDArray permuted = input.permute(new int[]{0, 2, 1});
            in2d = permuted.reshape(rows, shape[1]);
        }
        return in2d.reshape(new int[]{rows, this.numChannels, this.inputHeight, this.inputWidth});
    }

    /**
     * Reshapes CNN gradients back into the rank-3 RNN layout.
     *
     * @param output gradients whose dimension 0 has size {@code miniBatchSize * T} (for some
     *     {@code T}), holding {@code numChannels * inputHeight * inputWidth} values per entry of
     *     dimension 0
     * @param layer layer whose {@code getInputMiniBatchSize()} gives the size of dimension 0 of the
     *     result
     * @return an array of shape
     *     {@code [miniBatchSize, numChannels * inputHeight * inputWidth, shape[0] / miniBatchSize]}
     * @throws IllegalStateException if the layer's input mini-batch size is not positive, or if it
     *     does not evenly divide dimension 0 of {@code output}
     */
    @Override
    public INDArray backprop(INDArray output, Layer layer) {
        // The reshape below relies on row-major element order, so convert 'f'-ordered gradients to 'c' first.
        if (output.ordering() == 'f') {
            output = Shape.toOffsetZeroCopy(output, 'c');
        }
        int[] shape = output.shape();
        int miniBatchSize = layer.getInputMiniBatchSize();
        if (miniBatchSize <= 0 || shape[0] % miniBatchSize != 0) {
            throw new IllegalStateException("Cannot reshape gradients: dimension 0 has size " + shape[0]
                    + ", which is not a multiple of the input mini-batch size " + miniBatchSize);
        }
        INDArray reshaped = output.reshape(new int[]{miniBatchSize, shape[0] / miniBatchSize, this.product});
        // Swap the last two dimensions to restore the [dim 0, features, dim 2] layout taken by preProcess.
        return reshaped.permute(new int[]{0, 2, 1});
    }

    /**
     * Returns a new pre-processor with the same height, width and channel configuration.
     *
     * @return an independent copy of this pre-processor
     */
    @Override
    public RnnToCnnPreProcessor clone() {
        return new RnnToCnnPreProcessor(this.inputHeight, this.inputWidth, this.numChannels);
    }
}

