/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.layers.convolution;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import lombok.NonNull;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalizationDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.LocalResponseNormalizationConfig;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Local response normalization ("lrn") custom op.
 *
 * <p>The op is parameterised by a {@link LocalResponseNormalizationConfig}. Whenever a config is
 * attached, its values are pushed into the op's argument lists in a fixed order: floating point
 * arguments {@code bias, alpha, beta}, followed by the integer argument {@code depth}
 * (see {@link #addArgs()}).
 *
 * <p>Instances can be created directly, through {@link #sameDiffBuilder()}, or via the no-arg
 * constructor followed by {@link #initFromTensorFlow} / {@link #initFromOnnx} during graph import.
 */
public class LocalResponseNormalization
extends DynamicCustomOp {
    private static final Logger log = LoggerFactory.getLogger(LocalResponseNormalization.class);

    /** Default for the optional ONNX LRN attribute {@code alpha} (per the ONNX operator spec). */
    private static final float ONNX_DEFAULT_ALPHA = 1.0e-4f;
    /** Default for the optional ONNX LRN attribute {@code beta} (per the ONNX operator spec). */
    private static final float ONNX_DEFAULT_BETA = 0.75f;
    /** Default for the optional ONNX LRN attribute {@code bias} (per the ONNX operator spec). */
    private static final float ONNX_DEFAULT_BIAS = 1.0f;

    /** Op configuration; {@code null} until set by a constructor or one of the {@code initFrom*} methods. */
    protected LocalResponseNormalizationConfig config;

    /**
     * Creates the op inside a SameDiff graph.
     *
     * @param sameDiff       graph the op belongs to
     * @param inputFunctions input variables of the op
     * @param inPlace        in-place flag forwarded to the superclass
     * @param config         configuration whose values are added as op arguments
     */
    public LocalResponseNormalization(SameDiff sameDiff, SDVariable[] inputFunctions, boolean inPlace, LocalResponseNormalizationConfig config) {
        super(null, sameDiff, inputFunctions, inPlace);
        this.config = config;
        this.addArgs();
    }

    /**
     * Creates the op directly on arrays.
     *
     * @param input  input array, must not be {@code null}
     * @param output output array, passed through {@code wrapOrNull}
     * @param config configuration, must not be {@code null}
     * @throws NullPointerException if {@code input} or {@code config} is {@code null}
     */
    public LocalResponseNormalization(@NonNull INDArray input, INDArray output, @NonNull LocalResponseNormalizationConfig config) {
        super(new INDArray[]{input}, LocalResponseNormalization.wrapOrNull(output));
        // Java requires the super call first, so the null checks necessarily run after it.
        if (input == null) {
            throw new NullPointerException("input is marked @NonNull but is null");
        }
        if (config == null) {
            throw new NullPointerException("config is marked @NonNull but is null");
        }
        this.config = config;
        this.addArgs();
    }

    /**
     * No-arg constructor; leaves {@link #config} unset until one of the {@code initFrom*}
     * methods is called.
     */
    public LocalResponseNormalization() {
    }

    /**
     * Returns the configuration as a property map.
     *
     * @return {@code config.toProperties()}, or an empty map when no config has been set yet
     *         (e.g. an instance created with the no-arg constructor)
     */
    @Override
    public Map<String, Object> propertiesForFunction() {
        if (this.config == null) {
            return Collections.emptyMap();
        }
        return this.config.toProperties();
    }

    /**
     * Appends the config values to the op arguments: T-args {@code bias, alpha, beta},
     * then I-arg {@code depth}. The order is significant and must not be changed.
     */
    private void addArgs() {
        this.addTArgument(this.config.getBias());
        this.addTArgument(this.config.getAlpha());
        this.addTArgument(this.config.getBeta());
        this.addIArgument(this.config.getDepth());
    }

    /** @return always {@code true} */
    @Override
    public boolean isConfigProperties() {
        return true;
    }

    /** @return the name of the field holding the configuration, {@code "config"} */
    @Override
    public String configFieldName() {
        return "config";
    }

    /** @return the op name, {@code "lrn"} */
    @Override
    public String opName() {
        return "lrn";
    }

    /**
     * Builds the configuration from the TensorFlow node attributes {@code alpha}, {@code beta},
     * {@code bias} and {@code depth_radius}, then adds the op arguments.
     *
     * @param nodeDef           TensorFlow node to read the attributes from
     * @param initWith          not used by this implementation
     * @param attributesForNode not used by this implementation
     * @param graph             not used by this implementation
     * @throws IllegalArgumentException presumably, via {@code getAttrOrThrow}, if an attribute is absent
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        // The float attribute values are widened to double before reaching the config builder.
        double alpha = nodeDef.getAttrOrThrow("alpha").getF();
        double beta = nodeDef.getAttrOrThrow("beta").getF();
        double bias = nodeDef.getAttrOrThrow("bias").getF();
        int depth = (int) nodeDef.getAttrOrThrow("depth_radius").getI();

        this.config = LocalResponseNormalizationConfig.builder()
                .alpha(alpha)
                .beta(beta)
                .bias(bias)
                .depth(depth)
                .build();
        this.addArgs();
    }

    /**
     * Builds the configuration from the ONNX node attributes {@code alpha}, {@code beta},
     * {@code bias} and {@code size}, then adds the op arguments.
     *
     * <p>{@code alpha}, {@code beta} and {@code bias} are optional in ONNX and fall back to the
     * spec defaults when absent. {@code size} is a required integer attribute and is therefore
     * read from the integer field of the attribute (its float field is unset).
     *
     * @param node              not used by this implementation
     * @param initWith          not used by this implementation
     * @param attributesForNode attributes of the ONNX node, keyed by attribute name
     * @param graph             not used by this implementation
     * @throws IllegalStateException if the required {@code size} attribute is missing
     */
    @Override
    public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
        Onnx.AttributeProto aAlpha = attributesForNode.get("alpha");
        Onnx.AttributeProto aBeta = attributesForNode.get("beta");
        Onnx.AttributeProto aBias = attributesForNode.get("bias");
        Onnx.AttributeProto aSize = attributesForNode.get("size");
        if (aSize == null) {
            throw new IllegalStateException("ONNX LRN node is missing the required \"size\" attribute");
        }

        float alpha = aAlpha != null ? aAlpha.getF() : ONNX_DEFAULT_ALPHA;
        float beta = aBeta != null ? aBeta.getF() : ONNX_DEFAULT_BETA;
        float bias = aBias != null ? aBias.getF() : ONNX_DEFAULT_BIAS;
        // NOTE(review): ONNX "size" is a window size whereas the TensorFlow path supplies a
        // radius ("depth_radius"); the value is passed through unchanged here, as before --
        // confirm which of the two the native op expects.
        int depth = (int) aSize.getI();

        this.config = LocalResponseNormalizationConfig.builder()
                .alpha(alpha)
                .beta(beta)
                .bias(bias)
                .depth(depth)
                .build();
        this.addArgs();
    }

    /**
     * Describes how TensorFlow / ONNX attribute names map onto the config properties
     * {@code depth}, {@code alpha}, {@code beta} and {@code bias}.
     *
     * @return the same property mapping registered under {@link #tensorflowName()} and
     *         {@link #onnxName()} (both names are {@code "LRN"}, so the result has a single entry)
     */
    @Override
    public Map<String, Map<String, PropertyMapping>> mappingsForFunction() {
        Map<String, PropertyMapping> attributeMappings = new HashMap<>();
        attributeMappings.put("depth", PropertyMapping.builder()
                .tfAttrName("depth_radius")
                .propertyNames(new String[]{"depth"})
                .onnxAttrName("size")
                .build());
        attributeMappings.put("alpha", PropertyMapping.builder()
                .tfAttrName("alpha")
                .onnxAttrName("alpha")
                .propertyNames(new String[]{"alpha"})
                .build());
        attributeMappings.put("beta", PropertyMapping.builder()
                .tfAttrName("beta")
                .onnxAttrName("beta")
                .propertyNames(new String[]{"beta"})
                .build());
        attributeMappings.put("bias", PropertyMapping.builder()
                .tfAttrName("bias")
                .onnxAttrName("bias")
                .propertyNames(new String[]{"bias"})
                .build());

        Map<String, Map<String, PropertyMapping>> ret = new HashMap<>();
        ret.put(this.tensorflowName(), attributeMappings);
        ret.put(this.onnxName(), attributeMappings);
        return ret;
    }

    /**
     * Creates the backward op.
     *
     * @param f1 incoming gradients; only the first element is used
     * @return a single-element list holding the output variable of a
     *         {@link LocalResponseNormalizationDerivative} built from this op's input and
     *         {@code f1.get(0)}, using the same config and in-place flag
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        SDVariable[] gradFnInputs = new SDVariable[]{this.arg(), f1.get(0)};
        LocalResponseNormalizationDerivative lrnGrad = LocalResponseNormalizationDerivative.derivativeBuilder()
                .inPlace(this.inPlace)
                .sameDiff(this.sameDiff)
                .inputFunctions(gradFnInputs)
                .config(this.config)
                .build();
        return Collections.singletonList(lrnGrad.outputVariable());
    }

    /** @return the ONNX operator name, {@code "LRN"} */
    @Override
    public String onnxName() {
        return "LRN";
    }

    /** @return the TensorFlow operator name, {@code "LRN"} */
    @Override
    public String tensorflowName() {
        return "LRN";
    }

    /**
     * The output has the same data type as the first input, which must be a floating point type.
     *
     * @param inputDataTypes data types of the op inputs; must contain at least one element
     * @return single-element list containing {@code inputDataTypes.get(0)}
     * @throws IllegalStateException presumably, via {@code Preconditions.checkState}, if the list
     *         is null/empty or the first type is not floating point
     */
    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        Preconditions.checkState(inputDataTypes != null && !inputDataTypes.isEmpty(), "Expected at least 1 input datatype for %s, got %s", this.getClass(), inputDataTypes);
        DataType first = inputDataTypes.get(0);
        Preconditions.checkState(first.isFPType(), "Input 0 should be a floating point type for %s, got %s", this.getClass(), first);
        return Collections.singletonList(first);
    }

    /** @return a new builder targeting the SameDiff constructor */
    public static LocalResponseNormalizationBuilder sameDiffBuilder() {
        return new LocalResponseNormalizationBuilder();
    }

    /** @return the current configuration, possibly {@code null} */
    public LocalResponseNormalizationConfig getConfig() {
        return this.config;
    }

    /**
     * Fluent builder collecting the arguments of the SameDiff constructor
     * {@link LocalResponseNormalization#LocalResponseNormalization(SameDiff, SDVariable[], boolean, LocalResponseNormalizationConfig)}.
     */
    public static class LocalResponseNormalizationBuilder {
        private SameDiff sameDiff;
        private SDVariable[] inputFunctions;
        private boolean inPlace;
        private LocalResponseNormalizationConfig config;

        LocalResponseNormalizationBuilder() {
        }

        /** Sets the graph passed to the constructor. */
        public LocalResponseNormalizationBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        /** Sets the input variables passed to the constructor. */
        public LocalResponseNormalizationBuilder inputFunctions(SDVariable[] inputFunctions) {
            this.inputFunctions = inputFunctions;
            return this;
        }

        /** Sets the in-place flag passed to the constructor. */
        public LocalResponseNormalizationBuilder inPlace(boolean inPlace) {
            this.inPlace = inPlace;
            return this;
        }

        /** Sets the configuration passed to the constructor. */
        public LocalResponseNormalizationBuilder config(LocalResponseNormalizationConfig config) {
            this.config = config;
            return this;
        }

        /** @return a new op created from the collected values */
        public LocalResponseNormalization build() {
            return new LocalResponseNormalization(this.sameDiff, this.inputFunctions, this.inPlace, this.config);
        }

        @Override
        public String toString() {
            return "LocalResponseNormalization.LocalResponseNormalizationBuilder(sameDiff=" + this.sameDiff + ", inputFunctions=" + Arrays.deepToString(this.inputFunctions) + ", inPlace=" + this.inPlace + ", config=" + this.config + ")";
        }
    }
}

