/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.util.ArrayUtil;

/**
 * Input pre-processor that reshapes 2D feed-forward activations into the 4D layout
 * {@code [minibatch, numChannels, inputHeight, inputWidth]} consumed by convolutional layers,
 * and flattens epsilons back to their pre-reshape shape during backprop.
 *
 * <p>NOTE(review): {@link #preProcess} records the last input shape in a mutable field that
 * {@link #backprop} later reads, with no synchronization; an instance therefore should not be
 * shared across concurrent forward/backward passes.
 */
public class FeedForwardToCnnPreProcessor
implements InputPreProcessor {
    private int inputHeight;
    private int inputWidth;
    private int numChannels;
    /** Shape of the last input seen by {@link #preProcess}; used to undo the reshape in backprop. */
    private int[] shape;

    /**
     * Creates a pre-processor for the given spatial size and channel count.
     *
     * @param inputHeight height of each example once reshaped
     * @param inputWidth width of each example once reshaped
     * @param numChannels number of channels (depth) of each example once reshaped
     */
    @JsonCreator
    public FeedForwardToCnnPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
    }

    /**
     * Creates a single-channel pre-processor.
     *
     * <p>Beware: unlike the three-argument constructor, this one takes <em>width first</em>,
     * then height. The order is kept as-is for backward compatibility with existing callers.
     *
     * @param inputWidth width of each example once reshaped
     * @param inputHeight height of each example once reshaped
     */
    public FeedForwardToCnnPreProcessor(int inputWidth, int inputHeight) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = 1;
    }

    /**
     * Reshapes a 2D {@code [minibatch, numChannels * inputHeight * inputWidth]} input into
     * {@code [minibatch, numChannels, inputHeight, inputWidth]}. A rank-4 input is returned
     * unchanged. The input's shape is remembered for {@link #backprop}.
     *
     * @param input activations from the preceding layer
     * @param layer the layer the output is fed to (unused here)
     * @return the reshaped activations
     * @throws IllegalArgumentException if the column count does not match
     *     {@code numChannels * inputHeight * inputWidth}
     */
    @Override
    public INDArray preProcess(INDArray input, Layer layer) {
        this.shape = input.shape();
        if (this.shape.length == 4) {
            return input;
        }
        int expectedColumns = this.inputWidth * this.inputHeight * this.numChannels;
        if (input.columns() != expectedColumns) {
            throw new IllegalArgumentException("Invalid input: expected number of columns (" + input.columns()
                    + ") to equal numChannels " + this.numChannels + " x inputHeight " + this.inputHeight
                    + " x inputWidth " + this.inputWidth + " = " + expectedColumns
                    + "; input shape was " + Arrays.toString(this.shape));
        }
        // reshape() below interprets elements in 'c' order, so an 'f'-ordered input is copied first.
        if (input.ordering() == 'f') {
            input = Shape.toOffsetZeroCopy((INDArray)input, (char)'c');
        }
        return input.reshape(new int[]{input.size(0), this.numChannels, this.inputHeight, this.inputWidth});
    }

    /**
     * Reshapes the epsilon flowing back from the CNN layer into the shape recorded by the last
     * {@link #preProcess} call. If no shape was recorded, or its element count does not match,
     * an epsilon of rank greater than 2 is flattened to {@code [size(0), product of remaining dims]};
     * an epsilon of rank 2 or less is returned unchanged.
     *
     * @param output epsilon from the following layer
     * @param layer the layer (unused here)
     * @return the reshaped epsilon
     */
    @Override
    public INDArray backprop(INDArray output, Layer layer) {
        if (this.shape == null || ArrayUtil.prod((int[])this.shape) != output.length()) {
            int[] outShape = output.shape();
            if (outShape.length <= 2) {
                return output;
            }
            // Flatten every dimension after the minibatch dimension, for any rank > 2
            // (previously only ranks 3 and 4 were handled; other ranks threw an NPE).
            int[] trailingDims = Arrays.copyOfRange(outShape, 1, outShape.length);
            this.shape = new int[]{outShape[0], ArrayUtil.prod((int[])trailingDims)};
        }
        // reshape() interprets elements in 'c' order, so an 'f'-ordered epsilon is copied first.
        if (output.ordering() == 'f') {
            output = Shape.toOffsetZeroCopy((INDArray)output, (char)'c');
        }
        return output.reshape(this.shape);
    }

    /**
     * Returns a copy of this pre-processor; the recorded {@code shape} array is deep-copied so
     * the clone does not share mutable state with the original.
     */
    @Override
    public FeedForwardToCnnPreProcessor clone() {
        try {
            FeedForwardToCnnPreProcessor clone = (FeedForwardToCnnPreProcessor)super.clone();
            if (clone.shape != null) {
                clone.shape = clone.shape.clone();
            }
            return clone;
        }
        catch (CloneNotSupportedException e) {
            throw new RuntimeException(e);
        }
    }

    public int getInputHeight() {
        return this.inputHeight;
    }

    public int getInputWidth() {
        return this.inputWidth;
    }

    public int getNumChannels() {
        return this.numChannels;
    }

    public void setInputHeight(int inputHeight) {
        this.inputHeight = inputHeight;
    }

    public void setInputWidth(int inputWidth) {
        this.inputWidth = inputWidth;
    }

    public void setNumChannels(int numChannels) {
        this.numChannels = numChannels;
    }

    /**
     * Compares configuration fields and the recorded {@code shape}.
     *
     * <p>NOTE(review): {@code shape} is per-call state written by {@link #preProcess}, so two
     * identically configured instances compare unequal once one of them has been used. Kept for
     * backward compatibility; consider excluding it.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof FeedForwardToCnnPreProcessor)) {
            return false;
        }
        FeedForwardToCnnPreProcessor other = (FeedForwardToCnnPreProcessor)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (this.getInputHeight() != other.getInputHeight()) {
            return false;
        }
        if (this.getInputWidth() != other.getInputWidth()) {
            return false;
        }
        if (this.getNumChannels() != other.getNumChannels()) {
            return false;
        }
        return Arrays.equals(this.shape, other.shape);
    }

    protected boolean canEqual(Object other) {
        return other instanceof FeedForwardToCnnPreProcessor;
    }

    /** Hash consistent with {@link #equals}; also mixes in the mutable {@code shape}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getInputHeight();
        result = result * prime + this.getInputWidth();
        result = result * prime + this.getNumChannels();
        result = result * prime + Arrays.hashCode(this.shape);
        return result;
    }

    @Override
    public String toString() {
        return "FeedForwardToCnnPreProcessor(inputHeight=" + this.getInputHeight() + ", inputWidth=" + this.getInputWidth() + ", numChannels=" + this.getNumChannels() + ", shape=" + Arrays.toString(this.shape) + ")";
    }
}

