/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TensorFlowCnnToFeedForwardPreProcessor
extends CnnToFeedForwardPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(TensorFlowCnnToFeedForwardPreProcessor.class);

    @JsonCreator
    public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        super(inputHeight, inputWidth, numChannels);
    }

    public TensorFlowCnnToFeedForwardPreProcessor(int inputHeight, int inputWidth) {
        super(inputHeight, inputWidth);
    }

    public TensorFlowCnnToFeedForwardPreProcessor() {
    }

    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (input.rank() == 2) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input);
        }
        INDArray permuted = workspaceMgr.dup((Enum)ArrayType.ACTIVATIONS, input.permute(new int[]{0, 2, 3, 1}), 'c');
        int[] inShape = input.shape();
        int[] outShape = new int[]{inShape[0], inShape[1] * inShape[2] * inShape[3]};
        return permuted.reshape('c', outShape);
    }

    public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        if (epsilons.ordering() != 'c' || !Shape.hasDefaultStridesForShape((INDArray)epsilons)) {
            epsilons = workspaceMgr.dup((Enum)ArrayType.ACTIVATION_GRAD, epsilons, 'c');
        }
        INDArray epsilonsReshaped = epsilons.reshape('c', new int[]{epsilons.size(0), this.inputHeight, this.inputWidth, this.numChannels});
        return epsilonsReshaped.permute(new int[]{0, 3, 1, 2});
    }

    public TensorFlowCnnToFeedForwardPreProcessor clone() {
        return (TensorFlowCnnToFeedForwardPreProcessor)super.clone();
    }
}

