/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasFlattenRnnPreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(KerasFlattenRnnPreprocessor.class);
    private long tsLength;
    private long depth;

    public KerasFlattenRnnPreprocessor(long depth, long tsLength) {
        this.tsLength = Math.abs(tsLength);
        this.depth = depth;
    }

    public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        INDArray output = workspaceMgr.dup((Enum)ArrayType.ACTIVATIONS, input, 'c');
        return output.reshape(input.size(0), this.depth * this.tsLength);
    }

    public INDArray backprop(INDArray epsilons, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
        return workspaceMgr.dup((Enum)ArrayType.ACTIVATION_GRAD, epsilons, 'c').reshape(new long[]{miniBatchSize, this.depth, this.tsLength});
    }

    public KerasFlattenRnnPreprocessor clone() {
        return (KerasFlattenRnnPreprocessor)super.clone();
    }

    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        return InputType.feedForward((long)(this.depth * this.tsLength));
    }
}

