/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasFlattenRnnPreprocessor
extends BaseInputPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(KerasFlattenRnnPreprocessor.class);
    private int tsLength;
    private int depth;

    public KerasFlattenRnnPreprocessor(int depth, int tsLength) {
        this.tsLength = Math.abs(tsLength);
        this.depth = depth;
    }

    public INDArray preProcess(INDArray input, int miniBatchSize) {
        INDArray output = input.dup('c');
        output.reshape(input.size(0), this.depth * this.tsLength);
        return output;
    }

    public INDArray backprop(INDArray epsilons, int miniBatchSize) {
        return epsilons.dup().reshape(new int[]{miniBatchSize, this.depth, this.tsLength});
    }

    public KerasFlattenRnnPreprocessor clone() {
        return (KerasFlattenRnnPreprocessor)super.clone();
    }

    public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
        return InputType.feedForward((int)(this.depth * this.tsLength));
    }
}

