/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.weights;

import java.util.Arrays;
import org.deeplearning4j.nn.weights.IWeightInit;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Weight initialization that produces an identity mapping, optionally multiplied by a scale.
 *
 * <ul>
 *   <li>Rank-2 parameters: the weights are set to an identity matrix. The shape must be square.</li>
 *   <li>Rank-3, rank-4 and rank-5 parameters (convolution kernels): every weight is zero except
 *       the element at {@code [i, i, k2/2, k3/2, ...]} for each {@code i}, which is set to one.
 *       The first two dimensions must be equal, and every remaining (kernel) dimension must be odd
 *       so that a single center element exists.</li>
 * </ul>
 *
 * <p>Any other rank is rejected with an {@link IllegalStateException}. When {@link #getScale()} is
 * non-null, the resulting weights are multiplied by it.
 */
public class WeightInitIdentity
implements IWeightInit {
    /** Optional multiplier applied to the identity weights; {@code null} means no scaling. */
    private Double scale;

    /**
     * Creates an identity initializer with the given scale.
     *
     * @param scale multiplier for the identity weights, or {@code null} for no scaling
     */
    public WeightInitIdentity(@JsonProperty(value="scale") Double scale) {
        this.scale = scale;
    }

    /** Creates an identity initializer with no scaling. */
    public WeightInitIdentity() {
    }

    /**
     * Writes identity weights into {@code paramView} and returns them reshaped to {@code shape}.
     *
     * @param fanIn     unused by this initializer
     * @param fanOut    unused by this initializer
     * @param shape     target parameter shape; rank 2 (square matrix) or rank 3 to 5 (convolution)
     * @param order     memory order ('c' or 'f') used for flattening and reshaping
     * @param paramView flat parameter view that is overwritten with the initialized values
     * @return the initialized parameters, reshaped to {@code shape}
     * @throws IllegalStateException if the rank is unsupported, {@code shape[0] != shape[1]}, or a
     *                               convolution kernel dimension is even
     */
    @Override
    public INDArray init(double fanIn, double fanOut, long[] shape, char order, INDArray paramView) {
        // Validate the rank before touching shape[1]. Otherwise a rank-0 or rank-1 shape would fail
        // with ArrayIndexOutOfBoundsException instead of this descriptive error.
        if (shape.length < 2 || shape.length > 5) {
            throw new IllegalStateException("Identity mapping for " + shape.length + " dimensions not defined!");
        }
        if (shape[0] != shape[1]) {
            throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape " + Arrays.toString(shape) + ": weights must be a square matrix for identity");
        }
        return shape.length == 2
                ? setIdentity2D(shape, order, paramView)
                : setIdentityConv(shape, order, paramView);
    }

    /** Fills {@code paramView} with a (scaled) square identity matrix of size {@code shape[0]}. */
    private INDArray setIdentity2D(long[] shape, char order, INDArray paramView) {
        INDArray identity = Nd4j.eye(shape[0]);
        if (scale != null) {
            identity.muli(scale);
        }
        // toFlattened flattens in the requested order regardless of how `identity` is stored.
        paramView.assign(Nd4j.toFlattened(order, identity));
        return paramView.reshape(order, shape);
    }

    /**
     * Zeroes the parameters, then sets the kernel-center element of each {@code [i, i]} filter
     * slice to one and applies the optional scale.
     */
    private INDArray setIdentityConv(long[] shape, char order, INDArray paramView) {
        // Index of the kernel center. Entries 0 and 1 are filled in per channel in the loop below.
        long[] position = new long[shape.length];
        for (int d = 2; d < shape.length; d++) {
            if (shape[d] % 2L == 0L) {
                throw new IllegalStateException("Cannot use IDENTITY init with parameters of shape " + Arrays.toString(shape) + "! Must have odd sized kernels!");
            }
            position[d] = shape[d] / 2L;
        }
        paramView.assign(0);
        INDArray params = paramView.reshape(order, shape);
        for (long i = 0; i < shape[0]; i++) {
            position[0] = i;
            position[1] = i;
            params.putScalar(position, 1.0);
        }
        if (scale != null) {
            params.muli(scale);
        }
        return params;
    }

    /** @return the multiplier applied to the identity weights, or {@code null} for none */
    public Double getScale() {
        return this.scale;
    }

    /** @param scale the multiplier applied to the identity weights, or {@code null} for none */
    public void setScale(Double scale) {
        this.scale = scale;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof WeightInitIdentity)) {
            return false;
        }
        WeightInitIdentity other = (WeightInitIdentity) o;
        if (!other.canEqual(this)) {
            return false;
        }
        Double mine = this.getScale();
        Double theirs = other.getScale();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    /** Symmetric-equality hook: subclasses may override to refuse equality with this type. */
    protected boolean canEqual(Object other) {
        return other instanceof WeightInitIdentity;
    }

    @Override
    public int hashCode() {
        // Same formula as before (1 * 59 + field hash, 43 for null) so hash values are unchanged.
        Double s = this.getScale();
        return 59 + (s == null ? 43 : s.hashCode());
    }

    @Override
    public String toString() {
        return "WeightInitIdentity(scale=" + this.getScale() + ")";
    }
}

