/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Input pre-processor that flattens convolutional activations into a
 * two-dimensional array on the forward pass and restores a four-dimensional
 * array on the backward pass.
 *
 * <p>The forward pass ({@link #preProcess(INDArray, Layer)}) records the height,
 * width and (for rank-4 input) channel count of the array it flattens; the
 * backward pass ({@link #backprop(INDArray, Layer)}) uses those recorded values
 * to reshape a flat array into {@code [size(0), numChannels, inputHeight, inputWidth]}.
 *
 * <p>Instances are mutable: {@code preProcess} updates the recorded dimensions,
 * and those dimensions take part in {@link #equals(Object)} and {@link #hashCode()}.
 */
public class CnnToFeedForwardPreProcessor
implements InputPreProcessor {
    private int inputHeight;
    private int inputWidth;
    private int numChannels;

    /**
     * Creates a pre-processor with explicit dimensions. Also used by Jackson
     * when deserializing.
     *
     * @param inputHeight height used when reshaping in {@link #backprop(INDArray, Layer)}
     * @param inputWidth  width used when reshaping in {@link #backprop(INDArray, Layer)}
     * @param numChannels channel count used when reshaping in {@link #backprop(INDArray, Layer)}
     */
    @JsonCreator
    public CnnToFeedForwardPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
    }

    /**
     * Creates a pre-processor with the given height and width and a channel count of 1.
     *
     * @param inputHeight height used when reshaping in {@link #backprop(INDArray, Layer)}
     * @param inputWidth  width used when reshaping in {@link #backprop(INDArray, Layer)}
     */
    public CnnToFeedForwardPreProcessor(int inputHeight, int inputWidth) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = 1;
    }

    /**
     * Creates a pre-processor with all dimensions left at 0; they are filled in
     * by the first rank-3 or rank-4 call to {@link #preProcess(INDArray, Layer)}
     * or through the setters.
     */
    public CnnToFeedForwardPreProcessor() {
    }

    /**
     * Flattens every dimension after the first into a single dimension.
     *
     * <ul>
     *   <li>Rank 2: the input is returned unchanged and the recorded dimensions
     *       are left untouched.</li>
     *   <li>Rank 3: height and width are recorded from the last two dimensions;
     *       {@code numChannels} is left unchanged.</li>
     *   <li>Rank 4: height, width and channel count are recorded from the last
     *       three dimensions.</li>
     * </ul>
     *
     * @param input the array to flatten
     * @param layer not used by this implementation
     * @return {@code input} itself for rank 2, otherwise {@code input} reshaped to
     *         {@code [shape[0], product of the remaining dimensions]}
     * @throws IllegalArgumentException if the rank of {@code input} is not 2, 3 or 4
     */
    @Override
    public INDArray preProcess(INDArray input, Layer layer) {
        int[] inputShape = input.shape();
        int rank = inputShape.length;

        // Already flat: return before touching the recorded dimensions, otherwise
        // the row/column counts of the flat array would overwrite the configured
        // height/width and break the column check in backprop.
        if (rank == 2) {
            return input;
        }
        if (rank != 3 && rank != 4) {
            throw new IllegalArgumentException("Invalid input: expected an array of rank 2, 3 or 4 but got shape " + Arrays.toString(inputShape));
        }

        this.inputHeight = input.size(-2);
        this.inputWidth = input.size(-1);
        if (rank == 4) {
            this.numChannels = input.size(-3);
        }

        // Everything after the leading dimension is collapsed into one dimension.
        int[] trailingDims = Arrays.copyOfRange(inputShape, 1, rank);
        int[] flatShape = new int[]{inputShape[0], ArrayUtil.prod(trailingDims)};
        return input.reshape(flatShape);
    }

    /**
     * Reshapes a flat array back to {@code [size(0), numChannels, inputHeight, inputWidth]}
     * using the currently recorded dimensions.
     *
     * @param output the array coming back from the following layer
     * @param layer  not used by this implementation
     * @return {@code output} itself when it is already rank 4, otherwise the reshaped array
     * @throws IllegalArgumentException if {@code output.columns()} differs from
     *         {@code inputWidth * inputHeight * numChannels}
     */
    @Override
    public INDArray backprop(INDArray output, Layer layer) {
        if (output.shape().length == 4) {
            return output;
        }
        int expectedColumns = this.inputWidth * this.inputHeight * this.numChannels;
        if (output.columns() != expectedColumns) {
            throw new IllegalArgumentException("Invalid input: expect output columns must be equal to rows " + this.inputHeight + " x columns " + this.inputWidth + " x depth " + this.numChannels + " but was instead " + Arrays.toString(output.shape()));
        }
        return output.reshape(new int[]{output.size(0), this.numChannels, this.inputHeight, this.inputWidth});
    }

    /**
     * Returns a copy produced by {@code super.clone()}.
     *
     * @throws RuntimeException wrapping {@link CloneNotSupportedException} if cloning is not supported
     */
    @Override
    public CnnToFeedForwardPreProcessor clone() {
        try {
            return (CnnToFeedForwardPreProcessor)super.clone();
        }
        catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    /** @return the currently recorded input height */
    public int getInputHeight() {
        return this.inputHeight;
    }

    /** @return the currently recorded input width */
    public int getInputWidth() {
        return this.inputWidth;
    }

    /** @return the currently recorded channel count */
    public int getNumChannels() {
        return this.numChannels;
    }

    /** @param inputHeight the input height to record */
    public void setInputHeight(int inputHeight) {
        this.inputHeight = inputHeight;
    }

    /** @param inputWidth the input width to record */
    public void setInputWidth(int inputWidth) {
        this.inputWidth = inputWidth;
    }

    /** @param numChannels the channel count to record */
    public void setNumChannels(int numChannels) {
        this.numChannels = numChannels;
    }

    /**
     * Two pre-processors are equal when their height, width and channel count
     * match and the other object agrees via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof CnnToFeedForwardPreProcessor)) {
            return false;
        }
        CnnToFeedForwardPreProcessor other = (CnnToFeedForwardPreProcessor)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.getInputHeight() == other.getInputHeight()
                && this.getInputWidth() == other.getInputWidth()
                && this.getNumChannels() == other.getNumChannels();
    }

    /**
     * Hook that lets subclasses refuse equality with instances of this class.
     *
     * @param other the object attempting the comparison
     * @return {@code true} if {@code other} is a {@code CnnToFeedForwardPreProcessor}
     */
    protected boolean canEqual(Object other) {
        return other instanceof CnnToFeedForwardPreProcessor;
    }

    /** Hash of height, width and channel count, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getInputHeight();
        result = result * prime + this.getInputWidth();
        result = result * prime + this.getNumChannels();
        return result;
    }

    @Override
    public String toString() {
        return "CnnToFeedForwardPreProcessor(inputHeight=" + this.getInputHeight() + ", inputWidth=" + this.getInputWidth() + ", numChannels=" + this.getNumChannels() + ")";
    }
}

