/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.variational;

import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BasePretrainNetwork;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.layers.variational.GaussianReconstructionDistribution;
import org.deeplearning4j.nn.conf.layers.variational.LossFunctionWrapper;
import org.deeplearning4j.nn.conf.layers.variational.ReconstructionDistribution;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.params.VariationalAutoencoderParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.util.ArrayUtil;

public class VariationalAutoencoder
extends BasePretrainNetwork {
    private int[] encoderLayerSizes;
    private int[] decoderLayerSizes;
    private ReconstructionDistribution outputDistribution;
    private IActivation pzxActivationFn;
    private int numSamples;

    private VariationalAutoencoder(Builder builder) {
        super(builder);
        this.encoderLayerSizes = builder.encoderLayerSizes;
        this.decoderLayerSizes = builder.decoderLayerSizes;
        this.outputDistribution = builder.outputDistribution;
        this.pzxActivationFn = builder.pzxActivationFn;
        this.numSamples = builder.numSamples;
    }

    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        LayerValidation.assertNInNOutSet("VariationalAutoencoder", this.getLayerName(), layerIndex, this.getNIn(), this.getNOut());
        org.deeplearning4j.nn.layers.variational.VariationalAutoencoder ret = new org.deeplearning4j.nn.layers.variational.VariationalAutoencoder(conf);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    @Override
    public ParamInitializer initializer() {
        return VariationalAutoencoderParamInitializer.getInstance();
    }

    @Override
    public double getL1ByParam(String paramName) {
        if (paramName.endsWith("b")) {
            return this.l1Bias;
        }
        return this.l1;
    }

    @Override
    public double getL2ByParam(String paramName) {
        if (paramName.endsWith("b")) {
            return this.l2Bias;
        }
        return this.l2;
    }

    @Override
    public boolean isPretrainParam(String paramName) {
        if (paramName.startsWith("d")) {
            return true;
        }
        if (paramName.startsWith("pZXLogStd2")) {
            return true;
        }
        return paramName.startsWith("pXZ");
    }

    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        InputType outputType = this.getOutputType(-1, inputType);
        int actElementsPerEx = outputType.arrayElementsPerExample();
        int numParams = this.initializer().numParams(this);
        int updaterStateSize = (int)this.getIUpdater().stateSize((long)numParams);
        int inferenceWorkingMemSizePerEx = 0;
        for (int i = 1; i < this.encoderLayerSizes.length; ++i) {
            inferenceWorkingMemSizePerEx += this.encoderLayerSizes[i];
        }
        int decoderFwdSizeWorking = 4 * this.nOut;
        decoderFwdSizeWorking += this.numSamples * (2 * this.nOut + ArrayUtil.sum((int[])this.getDecoderLayerSizes()));
        int trainWorkingMemSize = 2 * (inferenceWorkingMemSizePerEx + (decoderFwdSizeWorking += this.nOut));
        if (this.getIDropout() != null) {
            trainWorkingMemSize += inputType.arrayElementsPerExample();
        }
        return new LayerMemoryReport.Builder(this.layerName, VariationalAutoencoder.class, inputType, outputType).standardMemory(numParams, updaterStateSize).workingMemory(0L, (long)inferenceWorkingMemSizePerEx, 0L, trainWorkingMemSize).cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS).build();
    }

    public int[] getEncoderLayerSizes() {
        return this.encoderLayerSizes;
    }

    public int[] getDecoderLayerSizes() {
        return this.decoderLayerSizes;
    }

    public ReconstructionDistribution getOutputDistribution() {
        return this.outputDistribution;
    }

    public IActivation getPzxActivationFn() {
        return this.pzxActivationFn;
    }

    public int getNumSamples() {
        return this.numSamples;
    }

    public void setEncoderLayerSizes(int[] encoderLayerSizes) {
        this.encoderLayerSizes = encoderLayerSizes;
    }

    public void setDecoderLayerSizes(int[] decoderLayerSizes) {
        this.decoderLayerSizes = decoderLayerSizes;
    }

    public void setOutputDistribution(ReconstructionDistribution outputDistribution) {
        this.outputDistribution = outputDistribution;
    }

    public void setPzxActivationFn(IActivation pzxActivationFn) {
        this.pzxActivationFn = pzxActivationFn;
    }

    public void setNumSamples(int numSamples) {
        this.numSamples = numSamples;
    }

    @Override
    public String toString() {
        return "VariationalAutoencoder(encoderLayerSizes=" + Arrays.toString(this.getEncoderLayerSizes()) + ", decoderLayerSizes=" + Arrays.toString(this.getDecoderLayerSizes()) + ", outputDistribution=" + this.getOutputDistribution() + ", pzxActivationFn=" + this.getPzxActivationFn() + ", numSamples=" + this.getNumSamples() + ")";
    }

    public VariationalAutoencoder() {
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof VariationalAutoencoder)) {
            return false;
        }
        VariationalAutoencoder other = (VariationalAutoencoder)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (!Arrays.equals(this.getEncoderLayerSizes(), other.getEncoderLayerSizes())) {
            return false;
        }
        if (!Arrays.equals(this.getDecoderLayerSizes(), other.getDecoderLayerSizes())) {
            return false;
        }
        ReconstructionDistribution this$outputDistribution = this.getOutputDistribution();
        ReconstructionDistribution other$outputDistribution = other.getOutputDistribution();
        if (this$outputDistribution == null ? other$outputDistribution != null : !this$outputDistribution.equals(other$outputDistribution)) {
            return false;
        }
        IActivation this$pzxActivationFn = this.getPzxActivationFn();
        IActivation other$pzxActivationFn = other.getPzxActivationFn();
        if (this$pzxActivationFn == null ? other$pzxActivationFn != null : !this$pzxActivationFn.equals(other$pzxActivationFn)) {
            return false;
        }
        return this.getNumSamples() == other.getNumSamples();
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof VariationalAutoencoder;
    }

    @Override
    public int hashCode() {
        int PRIME = 59;
        int result = super.hashCode();
        result = result * 59 + Arrays.hashCode(this.getEncoderLayerSizes());
        result = result * 59 + Arrays.hashCode(this.getDecoderLayerSizes());
        ReconstructionDistribution $outputDistribution = this.getOutputDistribution();
        result = result * 59 + ($outputDistribution == null ? 43 : $outputDistribution.hashCode());
        IActivation $pzxActivationFn = this.getPzxActivationFn();
        result = result * 59 + ($pzxActivationFn == null ? 43 : $pzxActivationFn.hashCode());
        result = result * 59 + this.getNumSamples();
        return result;
    }

    public static class Builder
    extends BasePretrainNetwork.Builder<Builder> {
        private int[] encoderLayerSizes = new int[]{100};
        private int[] decoderLayerSizes = new int[]{100};
        private ReconstructionDistribution outputDistribution = new GaussianReconstructionDistribution(Activation.TANH);
        private IActivation pzxActivationFn = new ActivationIdentity();
        private int numSamples = 1;

        public Builder encoderLayerSizes(int ... encoderLayerSizes) {
            if (encoderLayerSizes == null || encoderLayerSizes.length < 1) {
                throw new IllegalArgumentException("Encoder layer sizes array must have length > 0");
            }
            this.encoderLayerSizes = encoderLayerSizes;
            return this;
        }

        public Builder decoderLayerSizes(int ... decoderLayerSizes) {
            if (decoderLayerSizes == null || decoderLayerSizes.length < 1) {
                throw new IllegalArgumentException("Decoder layer sizes array must have length > 0");
            }
            this.decoderLayerSizes = decoderLayerSizes;
            return this;
        }

        public Builder reconstructionDistribution(ReconstructionDistribution distribution) {
            this.outputDistribution = distribution;
            return this;
        }

        public Builder lossFunction(IActivation outputActivationFn, LossFunctions.LossFunction lossFunction) {
            return this.lossFunction(outputActivationFn, lossFunction.getILossFunction());
        }

        public Builder lossFunction(Activation outputActivationFn, LossFunctions.LossFunction lossFunction) {
            return this.lossFunction(outputActivationFn.getActivationFunction(), lossFunction.getILossFunction());
        }

        public Builder lossFunction(IActivation outputActivationFn, ILossFunction lossFunction) {
            return this.reconstructionDistribution(new LossFunctionWrapper(outputActivationFn, lossFunction));
        }

        public Builder pzxActivationFn(IActivation activationFunction) {
            this.pzxActivationFn = activationFunction;
            return this;
        }

        public Builder pzxActivationFunction(Activation activation) {
            return this.pzxActivationFn(activation.getActivationFunction());
        }

        @Override
        public Builder nOut(int nOut) {
            super.nOut(nOut);
            return this;
        }

        public Builder numSamples(int numSamples) {
            this.numSamples = numSamples;
            return this;
        }

        @Override
        public VariationalAutoencoder build() {
            return new VariationalAutoencoder(this);
        }
    }
}

