/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseUpsamplingLayer;
import org.deeplearning4j.nn.conf.layers.Convolution3D;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.ValidationUtils;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Configuration for a 3D upsampling layer.
 *
 * <p>The output type's depth, height and width are the input's depth, height and width
 * multiplied by {@code size[0]}, {@code size[1]} and {@code size[2]} respectively. The channel
 * count is unchanged. The layer has no parameters (its memory report declares zero standard
 * memory).
 */
public class Upsampling3D
extends BaseUpsamplingLayer {
    /** Upsampling factors in (depth, height, width) order. */
    protected int[] size;
    /** Activation layout; defaults to NCDHW. */
    protected Convolution3D.DataFormat dataFormat = Convolution3D.DataFormat.NCDHW;

    /**
     * Creates the configuration from a builder.
     *
     * @param builder source of the upsampling factors and data format
     */
    protected Upsampling3D(Builder builder) {
        super(builder);
        this.size = builder.size;
        this.dataFormat = builder.dataFormat;
    }

    /**
     * Returns a copy of this configuration.
     *
     * <p>The {@code size} array is copied so that mutating the clone's factors (for example via
     * {@code getSize()[0] = ...}) cannot affect the original, and vice versa.
     *
     * @return an independent copy of this configuration
     */
    @Override
    public Upsampling3D clone() {
        Upsampling3D copy = (Upsampling3D) super.clone();
        if (copy.size != null) {
            copy.size = copy.size.clone();
        }
        return copy;
    }

    /**
     * Instantiates the runtime upsampling layer for this configuration.
     *
     * @param conf               network configuration for the layer
     * @param iterationListeners listeners to attach to the layer
     * @param layerIndex         index of the layer within the network
     * @param layerParamsView    view of the parameter array assigned to this layer
     * @param initializeParams   whether the parameter initializer should initialize values
     * @param networkDataType    data type of the network
     * @return the configured runtime layer
     */
    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D ret = new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(conf, networkDataType);
        ret.setListeners(iterationListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    /**
     * Computes the output type by scaling each spatial dimension of a CNN3D input.
     *
     * @param layerIndex index of the layer (unused)
     * @param inputType  input type; must be non-null and of type {@link InputType.Type#CNN3D}
     * @return a CNN3D type of size {@code (size[0]*depth, size[1]*height, size[2]*width, channels)}
     * @throws IllegalStateException if {@code inputType} is null or not CNN3D
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
            throw new IllegalStateException("Invalid input for Upsampling 3D layer (layer name=\"" + this.getLayerName() + "\"): Expected CNN3D input, got " + inputType);
        }
        InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D) inputType;
        // Stay in long arithmetic: narrowing to int before multiplying could truncate or overflow.
        long inDepth = i.getDepth();
        long inHeight = i.getHeight();
        long inWidth = i.getWidth();
        long inChannels = i.getChannels();
        return InputType.convolutional3D(this.size[0] * inDepth, this.size[1] * inHeight, this.size[2] * inWidth, inChannels);
    }

    /**
     * Returns the preprocessor (if any) needed to feed {@code inputType} into a CNN3D layer.
     *
     * @param inputType input type of this layer
     * @return the preprocessor chosen by {@link InputTypeUtil#getPreProcessorForInputTypeCnn3DLayers}
     * @throws IllegalStateException if {@code inputType} is null
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input for Upsampling 3D layer (layer name=\"" + this.getLayerName() + "\"): input is null");
        }
        return InputTypeUtil.getPreProcessorForInputTypeCnn3DLayers(inputType, this.getLayerName());
    }

    /**
     * Estimates memory use for this layer.
     *
     * <p>No standard (parameter) memory is used. The per-example working memory is
     * {@code channels * outDepth * outHeight * outWidth * size[0] * size[1] * size[2]}. During
     * training, the input's per-example element count is added when dropout is configured.
     *
     * @param inputType input type of this layer; must be CNN3D
     * @return the memory report
     * @throws IllegalStateException if {@code inputType} is null or not CNN3D
     */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        // getOutputType validates the input type, so call it before the cast below;
        // otherwise a wrong type would surface as an opaque ClassCastException.
        InputType.InputTypeConvolutional3D outputType = (InputType.InputTypeConvolutional3D) this.getOutputType(-1, inputType);
        InputType.InputTypeConvolutional3D c = (InputType.InputTypeConvolutional3D) inputType;
        // NOTE(review): the output dims already include the upsampling factors, so multiplying by
        // size[] again looks like an over-estimate; kept as-is pending confirmation of intent.
        long im2colSizePerEx = c.getChannels() * outputType.getDepth() * outputType.getHeight() * outputType.getWidth()
                * (long) this.size[0] * (long) this.size[1] * (long) this.size[2];
        long trainingWorkingSizePerEx = im2colSizePerEx;
        if (this.getIDropout() != null) {
            // Dropout needs a copy of the input during training.
            trainingWorkingSizePerEx += inputType.arrayElementsPerExample();
        }
        return new LayerMemoryReport.Builder(this.layerName, Upsampling3D.class, inputType, outputType).standardMemory(0L, 0L).workingMemory(0L, im2colSizePerEx, 0L, trainingWorkingSizePerEx).cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS).build();
    }

    /**
     * Returns the upsampling factors in (depth, height, width) order.
     *
     * @return the factors array (not a copy)
     */
    @Override
    public int[] getSize() {
        return this.size;
    }

    /**
     * Returns the data format of this layer.
     *
     * @return the data format of this layer
     */
    public Convolution3D.DataFormat getDataFormat() {
        return this.dataFormat;
    }

    /**
     * Sets the upsampling factors in (depth, height, width) order.
     *
     * @param size new factors array (stored without copying)
     */
    @Override
    public void setSize(int[] size) {
        this.size = size;
    }

    /**
     * Sets the data format of this layer.
     *
     * @param dataFormat new data format
     */
    public void setDataFormat(Convolution3D.DataFormat dataFormat) {
        this.dataFormat = dataFormat;
    }

    /** No-arg constructor; leaves {@code size} null and {@code dataFormat} as NCDHW. */
    public Upsampling3D() {
    }

    @Override
    public String toString() {
        return "Upsampling3D(super=" + super.toString() + ", size=" + Arrays.toString(this.getSize()) + ", dataFormat=" + this.getDataFormat() + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Upsampling3D)) {
            return false;
        }
        Upsampling3D other = (Upsampling3D) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (!Arrays.equals(this.getSize(), other.getSize())) {
            return false;
        }
        // Enum equality is identity, so == is a null-safe equivalent of equals here.
        return this.getDataFormat() == other.getDataFormat();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof Upsampling3D;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        result = result * prime + Arrays.hashCode(this.getSize());
        Convolution3D.DataFormat format = this.getDataFormat();
        result = result * prime + (format == null ? 43 : format.hashCode());
        return result;
    }

    /** Builder for {@link Upsampling3D}. */
    public static class Builder
    extends BaseUpsamplingLayer.UpsamplingBuilder<Builder> {
        /** Data format for the built layer; defaults to NCDHW. */
        protected Convolution3D.DataFormat dataFormat = Convolution3D.DataFormat.NCDHW;

        /**
         * Creates a builder that uses the same upsampling factor for depth, height and width.
         *
         * @param size factor applied to all three spatial dimensions
         */
        public Builder(int size) {
            super(new int[]{size, size, size});
        }

        /**
         * Creates a builder with an explicit data format and a uniform upsampling factor.
         *
         * @param dataFormat data format; must not be null
         * @param size       factor applied to all three spatial dimensions
         * @throws NullPointerException if {@code dataFormat} is null
         */
        public Builder(@NonNull Convolution3D.DataFormat dataFormat, int size) {
            super(new int[]{size, size, size});
            if (dataFormat == null) {
                throw new NullPointerException("dataFormat is marked @NonNull but is null");
            }
            this.dataFormat = dataFormat;
        }

        /**
         * Sets the data format.
         *
         * @param dataFormat data format; must not be null
         * @return this builder
         * @throws NullPointerException if {@code dataFormat} is null
         */
        public Builder dataFormat(@NonNull Convolution3D.DataFormat dataFormat) {
            if (dataFormat == null) {
                throw new NullPointerException("dataFormat is marked @NonNull but is null");
            }
            this.dataFormat = dataFormat;
            return this;
        }

        /**
         * Uses the same upsampling factor for depth, height and width.
         *
         * @param size factor applied to all three spatial dimensions
         * @return this builder
         */
        public Builder size(int size) {
            this.setSize(size, size, size);
            return this;
        }

        /**
         * Sets per-dimension upsampling factors in (depth, height, width) order.
         *
         * @param size array of exactly three factors
         * @return this builder
         * @throws IllegalArgumentException if {@code size} does not have length 3
         */
        public Builder size(int[] size) {
            Preconditions.checkArgument(size.length == 3, "Upsampling 3D size must be an int array of length 3, got length %s", size.length);
            this.setSize(size);
            return this;
        }

        @Override
        public Upsampling3D build() {
            return new Upsampling3D(this);
        }

        /**
         * Sets the upsampling factors after validating them with
         * {@link ValidationUtils#validate3NonNegative}.
         *
         * @param size the factors to validate and store
         */
        @Override
        public void setSize(int ... size) {
            this.size = ValidationUtils.validate3NonNegative(size, "size");
        }

        /** No-arg builder; relies on the superclass default for {@code size}. */
        public Builder() {
        }
    }
}

