/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.layers.samediff.SDLayerParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayerUtils;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.autodiff.samediff.SDIndex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

/**
 * Configuration of a 2D locally connected layer defined through SameDiff.
 *
 * <p>The layer cuts one {@code kernel[0] x kernel[1]} patch out of the input for every output
 * position and multiplies each patch with its own weight matrix: the weight parameter has shape
 * {@code [outH * outW, featureDim, nOut]}, i.e. weights are not shared between output positions.
 *
 * <p>The spatial input size must be known before parameters can be defined. It is supplied either
 * through {@link Builder#setInputSize(int...)} or through {@link #getOutputType(int, InputType)},
 * which also triggers {@link #computeOutputSize()}.
 */
@JsonIgnoreProperties(value={"paramShapes"})
public class LocallyConnected2D extends SameDiffLayer {
    /** Name of the weight parameter in the parameter table. */
    private static final String WEIGHT_KEY = "W";
    /** Name of the bias parameter in the parameter table. */
    private static final String BIAS_KEY = "b";

    private static final List<String> WEIGHT_KEYS = Collections.singletonList(WEIGHT_KEY);
    private static final List<String> BIAS_KEYS = Collections.singletonList(BIAS_KEY);
    private static final List<String> PARAM_KEYS = Arrays.asList(BIAS_KEY, WEIGHT_KEY);

    private long nIn;
    private long nOut;
    private Activation activation;
    private int[] kernel;
    private int[] stride;
    private int[] padding;
    private ConvolutionMode cm;
    private int[] dilation;
    private boolean hasBias;
    private int[] inputSize;
    private int[] outputSize;
    /** Number of values in one input patch: {@code kernel[0] * kernel[1] * nIn}. */
    private int featureDim;

    /**
     * Creates the layer configuration from a builder.
     *
     * @param builder builder holding the configured values
     */
    protected LocallyConnected2D(Builder builder) {
        super(builder);
        this.nIn = builder.nIn;
        this.nOut = builder.nOut;
        this.activation = builder.activation;
        this.kernel = builder.kernel;
        this.stride = builder.stride;
        this.padding = builder.padding;
        this.cm = builder.cm;
        this.dilation = builder.dilation;
        this.hasBias = builder.hasBias;
        this.inputSize = builder.inputSize;
        this.featureDim = featureDimFor(this.kernel, this.nIn);
    }

    /** No-arg constructor; not used anywhere in this class. */
    private LocallyConnected2D() {
    }

    /** Computes the patch size {@code kernel[0] * kernel[1] * nIn} for the given kernel and channel count. */
    private static int featureDimFor(int[] kernel, long nIn) {
        return kernel[0] * kernel[1] * (int) nIn;
    }

    /** Fails with a descriptive message when the output size has not been computed yet. */
    private void requireOutputSize() {
        if (this.outputSize == null) {
            throw new IllegalStateException("Output size of locally connected 2D layer has not been computed yet: "
                    + "set the input size and call computeOutputSize() (or getOutputType(...)) first.");
        }
    }

    /**
     * Computes {@code outputSize} from the current input size, kernel, stride, dilation and
     * convolution mode. In {@link ConvolutionMode#Same} mode the padding is also replaced by the
     * top-left "same" padding derived from the computed output size.
     *
     * @throws IllegalArgumentException if no input size has been set
     */
    public void computeOutputSize() {
        if (this.inputSize == null) {
            throw new IllegalArgumentException("Input size has to be specified for locally connected layers.");
        }
        int channels = (int) this.getNIn();
        // Only the shape of this array is of interest: it is handed to the shape-inference helper.
        int[] inputShape = new int[]{1, channels, this.inputSize[0], this.inputSize[1]};
        INDArray dummyInputForShapeInference = Nd4j.ones(inputShape);
        if (this.cm == ConvolutionMode.Same) {
            this.outputSize = ConvolutionUtils.getOutputSize(dummyInputForShapeInference, this.kernel, this.stride, null, this.cm, this.dilation);
            this.padding = ConvolutionUtils.getSameModeTopLeftPadding(this.outputSize, this.inputSize, this.kernel, this.stride, this.dilation);
        } else {
            this.outputSize = ConvolutionUtils.getOutputSize(dummyInputForShapeInference, this.kernel, this.stride, this.padding, this.cm, this.dilation);
        }
    }

    /**
     * Records the spatial input size from the given CNN input type, recomputes the output size and
     * returns the resulting output type.
     *
     * @param layerIndex index of this layer, forwarded to the output-type helper
     * @param inputType  input type; must be of type {@code CNN}
     * @return the output type computed by {@link InputTypeUtil#getOutputTypeCnnLayers}
     * @throws IllegalArgumentException if {@code inputType} is null or not of type {@code CNN}
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN) {
            throw new IllegalArgumentException("Provided input type for locally connected 2D layers has to be of CNN type, got: " + inputType);
        }
        InputType.InputTypeConvolutional cnnType = (InputType.InputTypeConvolutional) inputType;
        this.inputSize = new int[]{(int) cnnType.getHeight(), (int) cnnType.getWidth()};
        this.computeOutputSize();
        // NOTE(review): a fixed dilation of {1, 1} is passed here, while computeOutputSize() uses
        // this.dilation - confirm which of the two is intended for dilation != 1.
        return InputTypeUtil.getOutputTypeCnnLayers(inputType, this.kernel, this.stride, this.padding, new int[]{1, 1}, this.cm, this.nOut, layerIndex, this.getLayerName(), LocallyConnected2D.class);
    }

    /**
     * Takes the number of input channels from the given input type when {@code nIn} is not yet
     * set (or when {@code override} is true) and refreshes {@code featureDim} accordingly.
     *
     * @param inputType input type; cast to {@link InputType.InputTypeConvolutional}
     * @param override  whether an already configured {@code nIn} should be replaced
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (this.nIn <= 0L || override) {
            InputType.InputTypeConvolutional c = (InputType.InputTypeConvolutional) inputType;
            this.nIn = c.getChannels();
            // featureDim was derived from the builder's nIn in the constructor; it has to follow
            // the inferred channel count, otherwise weight and patch shapes are wrong.
            this.featureDim = featureDimFor(this.kernel, this.nIn);
        }
    }

    /** Delegates to {@link InputTypeUtil#getPreProcessorForInputTypeCnnLayers} with this layer's name. */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        return InputTypeUtil.getPreProcessorForInputTypeCnnLayers(inputType, this.getLayerName());
    }

    /**
     * Registers the weight parameter {@code "W"} with shape {@code [outH * outW, featureDim, nOut]}
     * and, if the layer has a bias, the bias parameter {@code "b"} with shape {@code [1, nOut]}.
     *
     * @param params parameter definition to fill; cleared first
     * @throws IllegalStateException if the output size has not been computed yet or the feature
     *                               dimension is not positive
     */
    @Override
    public void defineParameters(SDLayerParams params) {
        requireOutputSize();
        if (this.featureDim <= 0) {
            throw new IllegalStateException("Cannot initialize layer: Feature dimension is set to " + this.featureDim);
        }
        params.clear();
        // Widen before multiplying so the number of output positions cannot overflow int.
        long outputPositions = (long) this.outputSize[0] * this.outputSize[1];
        params.addWeightParam(WEIGHT_KEY, new long[]{outputPositions, this.featureDim, this.nOut});
        if (this.hasBias) {
            params.addBiasParam(BIAS_KEY, new long[]{1L, this.nOut});
        }
    }

    /**
     * Initializes the given parameter arrays: the bias {@code "b"} is set to zero, every other
     * parameter is initialized through {@link WeightInitUtil#initWeights} with the layer's weight
     * initialization scheme.
     *
     * @param params parameter arrays keyed by parameter name
     */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
        try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            for (Map.Entry<String, INDArray> e : params.entrySet()) {
                INDArray param = e.getValue();
                if (BIAS_KEY.equals(e.getKey())) {
                    param.assign(0);
                    continue;
                }
                double fanIn = this.nIn * (long) this.kernel[0] * (long) this.kernel[1];
                double fanOut = (double) (this.nOut * (long) this.kernel[0] * (long) this.kernel[1])
                        / ((double) this.stride[0] * (double) this.stride[1]);
                WeightInitUtil.initWeights(fanIn, fanOut, param.shape(), this.weightInit, null, 'c', param);
            }
        }
    }

    /**
     * Builds the SameDiff graph of the layer.
     *
     * <p>For every output position (i, j) a {@code kH x kW} patch is sliced from the input and
     * reshaped to {@code [1, miniBatch, featureDim]}. The patches are concatenated along
     * dimension 0, multiplied with {@code "W"}, reshaped to {@code [outH, outW, miniBatch, nOut]}
     * and permuted to {@code [miniBatch, nOut, outH, outW]} before bias and activation are applied.
     *
     * @param sameDiff   graph to add the operations to
     * @param layerInput layer input; dimension 0 is read as the mini-batch size
     * @param paramTable parameter variables keyed by parameter name
     * @return the activated output variable
     * @throws IllegalStateException if the output size has not been computed yet
     */
    @Override
    public SDVariable defineLayer(SameDiff sameDiff, SDVariable layerInput, Map<String, SDVariable> paramTable) {
        requireOutputSize();
        SDVariable w = paramTable.get(WEIGHT_KEY);
        long[] inputShape = layerInput.getShape();
        long miniBatch = inputShape[0];
        int outH = this.outputSize[0];
        int outW = this.outputSize[1];
        int sH = this.stride[0];
        int sW = this.stride[1];
        int kH = this.kernel[0];
        int kW = this.kernel[1];

        // NOTE(review): patches are sliced directly from layerInput; this.padding and this.dilation
        // are not applied here - confirm whether the input is expected to be padded beforehand.
        SDVariable[] patches = new SDVariable[outH * outW];
        for (int i = 0; i < outH; ++i) {
            int rowStart = i * sH;
            for (int j = 0; j < outW; ++j) {
                int colStart = j * sW;
                SDVariable slice = layerInput.get(new SDIndex[]{
                        SDIndex.all(),
                        SDIndex.all(),
                        SDIndex.interval((Integer) rowStart, (Integer) (rowStart + kH)),
                        SDIndex.interval((Integer) colStart, (Integer) (colStart + kW))});
                // Row-major position: must agree with the reshape to {outH, outW, ...} below.
                patches[i * outW + j] = sameDiff.reshape(slice, new long[]{1L, miniBatch, this.featureDim});
            }
        }

        SDVariable concatOutput = sameDiff.concat(0, patches);
        SDVariable mmulResult = sameDiff.mmul(concatOutput, w);
        SDVariable reshapeResult = sameDiff.reshape(mmulResult, new long[]{outH, outW, miniBatch, this.nOut});
        SDVariable permutedResult = sameDiff.permute(reshapeResult, new int[]{2, 3, 0, 1});

        // Only add the zero placeholder to the graph when there is no real bias parameter.
        SDVariable b = this.hasBias
                ? paramTable.get(BIAS_KEY)
                : sameDiff.zero("bias", new long[]{1L, this.nOut});
        SDVariable biasAddedResult = sameDiff.biasAdd(permutedResult, b);
        return this.activation.asSameDiff("out", sameDiff, biasAddedResult);
    }

    /**
     * Fills in the activation function and convolution mode from the global configuration when
     * they are not set on this layer.
     *
     * @param globalConfig global network configuration builder
     */
    @Override
    public void applyGlobalConfigToLayer(NeuralNetConfiguration.Builder globalConfig) {
        if (this.activation == null) {
            this.activation = SameDiffLayerUtils.fromIActivation(globalConfig.getActivationFn());
        }
        if (this.cm == null) {
            this.cm = globalConfig.getConvolutionMode();
        }
    }

    /** @return number of input channels */
    public long getNIn() {
        return this.nIn;
    }

    /** @return number of output channels */
    public long getNOut() {
        return this.nOut;
    }

    /** @return activation function of this layer */
    public Activation getActivation() {
        return this.activation;
    }

    /** @return kernel size array (not copied) */
    public int[] getKernel() {
        return this.kernel;
    }

    /** @return stride array (not copied) */
    public int[] getStride() {
        return this.stride;
    }

    /** @return padding array (not copied) */
    public int[] getPadding() {
        return this.padding;
    }

    /** @return convolution mode */
    public ConvolutionMode getCm() {
        return this.cm;
    }

    /** @return dilation array (not copied) */
    public int[] getDilation() {
        return this.dilation;
    }

    /** @return whether the layer has a bias parameter */
    public boolean isHasBias() {
        return this.hasBias;
    }

    /** @return spatial input size, or null if not set yet */
    public int[] getInputSize() {
        return this.inputSize;
    }

    /** @return spatial output size, or null if not computed yet */
    public int[] getOutputSize() {
        return this.outputSize;
    }

    /** @return number of values in one input patch */
    public int getFeatureDim() {
        return this.featureDim;
    }

    /** Sets the number of input channels; {@code featureDim} is left unchanged by this setter. */
    public void setNIn(long nIn) {
        this.nIn = nIn;
    }

    public void setNOut(long nOut) {
        this.nOut = nOut;
    }

    public void setActivation(Activation activation) {
        this.activation = activation;
    }

    /** Sets the kernel size; {@code featureDim} is left unchanged by this setter. */
    public void setKernel(int[] kernel) {
        this.kernel = kernel;
    }

    public void setStride(int[] stride) {
        this.stride = stride;
    }

    public void setPadding(int[] padding) {
        this.padding = padding;
    }

    public void setCm(ConvolutionMode cm) {
        this.cm = cm;
    }

    public void setDilation(int[] dilation) {
        this.dilation = dilation;
    }

    public void setHasBias(boolean hasBias) {
        this.hasBias = hasBias;
    }

    public void setInputSize(int[] inputSize) {
        this.inputSize = inputSize;
    }

    public void setOutputSize(int[] outputSize) {
        this.outputSize = outputSize;
    }

    public void setFeatureDim(int featureDim) {
        this.featureDim = featureDim;
    }

    @Override
    public String toString() {
        return "LocallyConnected2D(nIn=" + this.getNIn()
                + ", nOut=" + this.getNOut()
                + ", activation=" + this.getActivation()
                + ", kernel=" + Arrays.toString(this.getKernel())
                + ", stride=" + Arrays.toString(this.getStride())
                + ", padding=" + Arrays.toString(this.getPadding())
                + ", cm=" + this.getCm()
                + ", dilation=" + Arrays.toString(this.getDilation())
                + ", hasBias=" + this.isHasBias()
                + ", inputSize=" + Arrays.toString(this.getInputSize())
                + ", outputSize=" + Arrays.toString(this.getOutputSize())
                + ", featureDim=" + this.getFeatureDim() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LocallyConnected2D)) {
            return false;
        }
        LocallyConnected2D other = (LocallyConnected2D) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (this.getNIn() != other.getNIn()) {
            return false;
        }
        if (this.getNOut() != other.getNOut()) {
            return false;
        }
        Activation thisActivation = this.getActivation();
        Activation otherActivation = other.getActivation();
        if (thisActivation == null ? otherActivation != null : !thisActivation.equals(otherActivation)) {
            return false;
        }
        if (!Arrays.equals(this.getKernel(), other.getKernel())) {
            return false;
        }
        if (!Arrays.equals(this.getStride(), other.getStride())) {
            return false;
        }
        if (!Arrays.equals(this.getPadding(), other.getPadding())) {
            return false;
        }
        ConvolutionMode thisCm = this.getCm();
        ConvolutionMode otherCm = other.getCm();
        if (thisCm == null ? otherCm != null : !thisCm.equals(otherCm)) {
            return false;
        }
        if (!Arrays.equals(this.getDilation(), other.getDilation())) {
            return false;
        }
        if (this.isHasBias() != other.isHasBias()) {
            return false;
        }
        if (!Arrays.equals(this.getInputSize(), other.getInputSize())) {
            return false;
        }
        if (!Arrays.equals(this.getOutputSize(), other.getOutputSize())) {
            return false;
        }
        return this.getFeatureDim() == other.getFeatureDim();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof LocallyConnected2D;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        long nInValue = this.getNIn();
        result = result * prime + (int) (nInValue >>> 32 ^ nInValue);
        long nOutValue = this.getNOut();
        result = result * prime + (int) (nOutValue >>> 32 ^ nOutValue);
        Activation activationValue = this.getActivation();
        result = result * prime + (activationValue == null ? 43 : activationValue.hashCode());
        result = result * prime + Arrays.hashCode(this.getKernel());
        result = result * prime + Arrays.hashCode(this.getStride());
        result = result * prime + Arrays.hashCode(this.getPadding());
        ConvolutionMode cmValue = this.getCm();
        result = result * prime + (cmValue == null ? 43 : cmValue.hashCode());
        result = result * prime + Arrays.hashCode(this.getDilation());
        result = result * prime + (this.isHasBias() ? 79 : 97);
        result = result * prime + Arrays.hashCode(this.getInputSize());
        result = result * prime + Arrays.hashCode(this.getOutputSize());
        result = result * prime + this.getFeatureDim();
        return result;
    }

    /**
     * Builder for {@link LocallyConnected2D}.
     *
     * <p>Defaults: tanh activation, 2x2 kernel, stride 1x1, padding 0x0, dilation 1x1,
     * {@link ConvolutionMode#Same} and no bias.
     */
    public static class Builder extends SameDiffLayer.Builder<Builder> {
        private int nIn;
        private int nOut;
        private Activation activation = Activation.TANH;
        private int[] kernel = new int[]{2, 2};
        private int[] stride = new int[]{1, 1};
        private int[] padding = new int[]{0, 0};
        private int[] dilation = new int[]{1, 1};
        private int[] inputSize;
        private ConvolutionMode cm = ConvolutionMode.Same;
        private boolean hasBias = false;

        /** @param nIn number of input channels */
        public Builder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** @param nOut number of output channels */
        public Builder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** @param activation activation function of the layer */
        public Builder activation(Activation activation) {
            this.activation = activation;
            return this;
        }

        /** @param k kernel size; validated in {@link #build()} */
        public Builder kernelSize(int ... k) {
            this.kernel = k;
            return this;
        }

        /** @param s stride; validated in {@link #build()} */
        public Builder stride(int ... s) {
            this.stride = s;
            return this;
        }

        /** @param p padding; validated in {@link #build()} */
        public Builder padding(int ... p) {
            this.padding = p;
            return this;
        }

        /** @param cm convolution mode */
        public Builder convolutionMode(ConvolutionMode cm) {
            this.cm = cm;
            return this;
        }

        /** @param d dilation */
        public Builder dilation(int ... d) {
            this.dilation = d;
            return this;
        }

        /** @param hasBias whether the layer gets a bias parameter */
        public Builder hasBias(boolean hasBias) {
            this.hasBias = hasBias;
            return this;
        }

        /**
         * Sets the spatial input size.
         *
         * @param inputSize exactly two values; they are used as height and width of the input
         * @throws IllegalStateException if {@code inputSize} is null or does not have length 2
         */
        public Builder setInputSize(int ... inputSize) {
            boolean valid = inputSize != null && inputSize.length == 2;
            String actual = inputSize == null ? "null" : String.valueOf(inputSize.length);
            Preconditions.checkState(valid, "Input size argument of a locally connected layer has to have length 2, got " + actual);
            this.inputSize = inputSize;
            return this;
        }

        /**
         * Validates convolution mode, kernel, stride and padding and creates the layer configuration.
         *
         * @return the configured layer
         */
        @Override
        public LocallyConnected2D build() {
            ConvolutionUtils.validateConvolutionModePadding(this.cm, this.padding);
            ConvolutionUtils.validateCnnKernelStridePadding(this.kernel, this.stride, this.padding);
            return new LocallyConnected2D(this);
        }
    }
}

