/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.zoo.model.helper;

import java.util.Map;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ActivationLayer;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.conf.layers.SeparableConvolution2D;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.layers.ZeroPaddingLayer;
import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.primitives.Pair;

/**
 * Static helpers that append NASNet building blocks (separable-convolution blocks, the
 * "adjust" block, and the Normal-A / Reduction-A cells) to a DL4J
 * {@link ComputationGraphConfiguration.GraphBuilder}.
 *
 * <p>The cell wiring follows the Keras NASNet reference implementation
 * ({@code keras.applications.nasnet}). Every public method derives its layer/vertex names from a
 * fixed prefix plus the caller-supplied {@code blockId} and returns the name of the vertex that
 * carries the block's output, so calls can be chained.
 */
public class NASNetHelper {

    /** Epsilon used by every BatchNormalization layer added by this helper. */
    private static final double BN_EPSILON = 0.001;

    // NOTE(review): Keras uses BatchNormalization(momentum=0.9997). Per the DL4J
    // BatchNormalization.Builder docs, gamma(...) is only used together with lockGammaBeta(true),
    // so this value is probably meant to be decay(...). Kept unchanged to preserve current
    // behaviour - confirm intent before changing.
    private static final double BN_GAMMA = 0.9997;

    /**
     * Appends a separable-convolution block:
     * ReLU -> SeparableConv(k x k, stride) -> BN -> ReLU -> SeparableConv(k x k, stride 1) -> BN.
     *
     * <p>Only the first separable convolution is strided, so with {@link ConvolutionMode#Same}
     * the block downsamples by exactly {@code stride}, matching the pooling branches it is
     * added to inside the cells.
     *
     * @param graphBuilder graph to append to
     * @param filters      output channels of both separable convolutions
     * @param kernelSize   square kernel size of both separable convolutions
     * @param stride       stride of the first separable convolution
     * @param blockId      suffix appended to {@code "sepConvBlock"} to form the name prefix
     * @param input        name of the vertex feeding this block
     * @return name of the block's output vertex ({@code "sepConvBlock" + blockId + "_conv2_bn"})
     */
    public static String sepConvBlock(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, int kernelSize, int stride, String blockId, String input) {
        String prefix = "sepConvBlock" + blockId;
        graphBuilder
                .addLayer(prefix + "_act", new ActivationLayer(Activation.RELU), input)
                .addLayer(prefix + "_sepconv1", separableConv(filters, kernelSize, stride), prefix + "_act")
                .addLayer(prefix + "_conv1_bn", batchNorm(), prefix + "_sepconv1")
                .addLayer(prefix + "_act2", new ActivationLayer(Activation.RELU), prefix + "_conv1_bn")
                // The second convolution is never strided; striding twice would downsample by
                // stride^2 and break the element-wise adds in reductionA.
                .addLayer(prefix + "_sepconv2", separableConv(filters, kernelSize, 1), prefix + "_act2")
                .addLayer(prefix + "_conv2_bn", batchNorm(), prefix + "_sepconv2");
        return prefix + "_conv2_bn";
    }

    /**
     * Ensures {@code input} has {@code filters} channels, without any spatial matching
     * (equivalent to {@code adjustBlock(graphBuilder, filters, blockId, input, null)}).
     *
     * @return name of the vertex carrying the adjusted activations (may be {@code input} itself)
     */
    public static String adjustBlock(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, String blockId, String input) {
        return adjustBlock(graphBuilder, filters, blockId, input, null);
    }

    /**
     * Adjusts {@code input} (the previous-cell activations) so that it has the same spatial size
     * as {@code inputToMatch} and exactly {@code filters} channels.
     *
     * <p>If the spatial sizes differ, a two-tower factorized reduction is appended. Each tower
     * average-pools with stride 2 and projects to {@code filters / 2} channels; the towers are
     * concatenated and batch-normalized. If the resulting channel count still differs from
     * {@code filters}, a ReLU -> 1x1 conv -> BN projection is appended.
     *
     * @param graphBuilder graph to append to; its activation types must be computable
     *                     ({@link ComputationGraphConfiguration.GraphBuilder#getLayerActivationTypes()})
     * @param filters      required channel count of the result
     * @param blockId      suffix appended to {@code "adjustBlock"} to form the name prefix
     * @param input        vertex to adjust; if {@code null}, {@code inputToMatch} is used instead
     * @param inputToMatch vertex whose spatial size must be matched; if {@code null}, {@code input}
     *                     is used (so only the channel projection can apply)
     * @return name of the vertex carrying the adjusted activations (may be {@code input} itself)
     * @throws IllegalArgumentException if both {@code input} and {@code inputToMatch} are
     *                                  {@code null}, or if either lacks a convolutional
     *                                  activation type
     */
    public static String adjustBlock(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, String blockId, String input, String inputToMatch) {
        // Mirrors Keras `_adjust_block`: with no previous-cell input, fall back to the current one.
        if (input == null) {
            input = inputToMatch;
        }
        if (inputToMatch == null) {
            inputToMatch = input;
        }
        if (input == null) {
            throw new IllegalArgumentException("adjustBlock requires a non-null input or inputToMatch");
        }

        String prefix = "adjustBlock" + blockId;
        Map<String, InputType> activationTypes = graphBuilder.getLayerActivationTypes();
        int[] shapeToMatch = convShape(activationTypes, inputToMatch);
        int[] inputShape = convShape(activationTypes, input);

        String outputName = input;
        int outputChannels = channelsOf(inputShape);

        boolean spatialMismatch = heightOf(shapeToMatch) != heightOf(inputShape)
                || widthOf(shapeToMatch) != widthOf(inputShape);
        if (spatialMismatch) {
            outputName = addFactorizedReduction(graphBuilder, filters, prefix, input);
            // Two concatenated towers of filters / 2 channels each.
            outputChannels = 2 * (filters / 2);
        }

        // Decide on the projection from the channels of what we actually output now,
        // not from the original input's channel count.
        if (outputChannels != filters) {
            graphBuilder
                    .addLayer(prefix + "_projection_relu", new ActivationLayer(Activation.RELU), outputName)
                    .addLayer(prefix + "_projection_conv", conv1x1(filters), prefix + "_projection_relu")
                    .addLayer(prefix + "_projection_bn", batchNorm(), prefix + "_projection_conv");
            outputName = prefix + "_projection_bn";
        }
        return outputName;
    }

    /**
     * Appends a Normal-A cell. The cell's output vertex is named {@code "normalA" + blockId}.
     * It concatenates the adjusted previous input and the five branch sums.
     *
     * @param graphBuilder graph to append to
     * @param filters      channels of every branch
     * @param blockId      suffix appended to {@code "normalA"} to form the cell name
     * @param inputX       current input (output of the preceding cell)
     * @param inputP       previous input (output of the cell before that); may be {@code null}
     * @return pair of (this cell's output vertex, {@code inputX}), i.e. the (x, p) inputs for the
     *         next cell
     */
    public static Pair<String, String> normalA(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, String blockId, String inputX, String inputP) {
        String prefix = "normalA" + blockId;
        String p = adjustBlock(graphBuilder, filters, prefix, inputP, inputX);
        // Keras: h = BN(Conv1x1(ReLU(ip))) - built from the *current* input.
        String h = addCellInputHead(graphBuilder, filters, prefix, inputX);

        String left1 = sepConvBlock(graphBuilder, filters, 5, 1, prefix + "_left1", h);
        String right1 = sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_right1", p);
        graphBuilder.addVertex(prefix + "_add1", elementwiseAdd(), left1, right1);

        String left2 = sepConvBlock(graphBuilder, filters, 5, 1, prefix + "_left2", p);
        String right2 = sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_right2", p);
        graphBuilder.addVertex(prefix + "_add2", elementwiseAdd(), left2, right2);

        graphBuilder
                .addLayer(prefix + "_left3", pooling(PoolingType.AVG, 3, 1, ConvolutionMode.Same), h)
                .addVertex(prefix + "_add3", elementwiseAdd(), prefix + "_left3", p);

        graphBuilder
                .addLayer(prefix + "_left4", pooling(PoolingType.AVG, 3, 1, ConvolutionMode.Same), p)
                .addLayer(prefix + "_right4", pooling(PoolingType.AVG, 3, 1, ConvolutionMode.Same), p)
                .addVertex(prefix + "_add4", elementwiseAdd(), prefix + "_left4", prefix + "_right4");

        String left5 = sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_left5", h);
        graphBuilder.addVertex(prefix + "_add5", elementwiseAdd(), left5, h);

        graphBuilder.addVertex(prefix, new MergeVertex(),
                p, prefix + "_add1", prefix + "_add2", prefix + "_add3", prefix + "_add4", prefix + "_add5");
        return new Pair<>(prefix, inputX);
    }

    /**
     * Appends a Reduction-A cell, which halves the spatial size. The cell's output vertex is
     * named {@code "reductionA" + blockId}. It concatenates branch sums 2-5.
     *
     * @param graphBuilder graph to append to
     * @param filters      channels of every branch
     * @param blockId      suffix appended to {@code "reductionA"} to form the cell name
     * @param inputX       current input (output of the preceding cell)
     * @param inputP       previous input (output of the cell before that); may be {@code null}
     * @return pair of (this cell's output vertex, {@code inputX}), i.e. the (x, p) inputs for the
     *         next cell
     */
    public static Pair<String, String> reductionA(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, String blockId, String inputX, String inputP) {
        String prefix = "reductionA" + blockId;
        String p = adjustBlock(graphBuilder, filters, prefix, inputP, inputX);
        // Keras: h = BN(Conv1x1(ReLU(ip))) - built from the *current* input.
        String h = addCellInputHead(graphBuilder, filters, prefix, inputX);

        // Branches 1-3 and 5 (right) each downsample by 2, so every add sees matching shapes.
        String left1 = sepConvBlock(graphBuilder, filters, 5, 2, prefix + "_left1", h);
        String right1 = sepConvBlock(graphBuilder, filters, 7, 2, prefix + "_right1", p);
        graphBuilder.addVertex(prefix + "_add1", elementwiseAdd(), left1, right1);

        graphBuilder.addLayer(prefix + "_left2", pooling(PoolingType.MAX, 3, 2, ConvolutionMode.Same), h);
        String right2 = sepConvBlock(graphBuilder, filters, 7, 2, prefix + "_right2", p);
        graphBuilder.addVertex(prefix + "_add2", elementwiseAdd(), prefix + "_left2", right2);

        graphBuilder.addLayer(prefix + "_left3", pooling(PoolingType.AVG, 3, 2, ConvolutionMode.Same), h);
        String right3 = sepConvBlock(graphBuilder, filters, 5, 2, prefix + "_right3", p);
        graphBuilder.addVertex(prefix + "_add3", elementwiseAdd(), prefix + "_left3", right3);

        // Branches 4 and 5 (left) start from _add1, which is already reduced: no further stride.
        graphBuilder
                .addLayer(prefix + "_left4", pooling(PoolingType.AVG, 3, 1, ConvolutionMode.Same), prefix + "_add1")
                .addVertex(prefix + "_add4", elementwiseAdd(), prefix + "_add2", prefix + "_left4");

        String left5 = sepConvBlock(graphBuilder, filters, 3, 1, prefix + "_left5", prefix + "_add1");
        graphBuilder
                .addLayer(prefix + "_right5", pooling(PoolingType.MAX, 3, 2, ConvolutionMode.Same), h)
                .addVertex(prefix + "_add5", elementwiseAdd(), left5, prefix + "_right5");

        graphBuilder.addVertex(prefix, new MergeVertex(),
                prefix + "_add2", prefix + "_add3", prefix + "_add4", prefix + "_add5");
        return new Pair<>(prefix, inputX);
    }

    /**
     * Appends ReLU -> 1x1 conv ({@code filters}) -> BN on {@code input}, named
     * {@code prefix + "_relu1" / "_conv1" / "_bn1"}.
     *
     * @return name of the BN output vertex ({@code prefix + "_bn1"})
     */
    private static String addCellInputHead(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, String prefix, String input) {
        graphBuilder
                .addLayer(prefix + "_relu1", new ActivationLayer(Activation.RELU), input)
                .addLayer(prefix + "_conv1", conv1x1(filters), prefix + "_relu1")
                .addLayer(prefix + "_bn1", batchNorm(), prefix + "_conv1");
        return prefix + "_bn1";
    }

    /**
     * Appends the two-tower factorized reduction used by {@link #adjustBlock}. Tower 2 is shifted
     * by one pixel (pad bottom/right, then crop top/left), so after the stride-2 pooling it samples
     * the positions tower 1 skips while keeping the same output size.
     *
     * @return name of the output vertex ({@code prefix + "_bn1"})
     */
    private static String addFactorizedReduction(ComputationGraphConfiguration.GraphBuilder graphBuilder, int filters, String prefix, String input) {
        int halfFilters = filters / 2;
        graphBuilder
                .addLayer(prefix + "_relu1", new ActivationLayer(Activation.RELU), input)
                // tower 1
                .addLayer(prefix + "_avgpool1", pooling(PoolingType.AVG, 1, 2, ConvolutionMode.Truncate), prefix + "_relu1")
                .addLayer(prefix + "_conv1", conv1x1(halfFilters), prefix + "_avgpool1")
                // tower 2: (top, bottom, left, right) padding / cropping
                .addLayer(prefix + "_zeropad1", new ZeroPaddingLayer(0, 1, 0, 1), prefix + "_relu1")
                .addLayer(prefix + "_crop1", new Cropping2D(1, 0, 1, 0), prefix + "_zeropad1")
                .addLayer(prefix + "_avgpool2", pooling(PoolingType.AVG, 1, 2, ConvolutionMode.Truncate), prefix + "_crop1")
                .addLayer(prefix + "_conv2", conv1x1(halfFilters), prefix + "_avgpool2")
                .addVertex(prefix + "_concat1", new MergeVertex(), prefix + "_conv1", prefix + "_conv2")
                .addLayer(prefix + "_bn1", batchNorm(), prefix + "_concat1");
        return prefix + "_bn1";
    }

    /**
     * Looks up the activation shape of {@code vertexName}.
     *
     * @throws IllegalArgumentException if the vertex is unknown or its shape has fewer than 3 dims
     */
    private static int[] convShape(Map<String, InputType> activationTypes, String vertexName) {
        InputType type = activationTypes.get(vertexName);
        if (type == null) {
            throw new IllegalArgumentException("No activation type known for vertex \"" + vertexName + "\"");
        }
        int[] shape = type.getShape();
        if (shape.length < 3) {
            throw new IllegalArgumentException("Expected a convolutional activation for vertex \""
                    + vertexName + "\" but got " + type);
        }
        return shape;
    }

    // NOTE(review): these accessors assume a channels-first shape, either [c, h, w] or
    // [minibatch, c, h, w]. Indexing from the end works for both forms - confirm the layout
    // against the DL4J version in use.

    /** Channel count of a channels-first convolutional shape. */
    private static int channelsOf(int[] shape) {
        return shape[shape.length - 3];
    }

    /** Height of a channels-first convolutional shape. */
    private static int heightOf(int[] shape) {
        return shape[shape.length - 2];
    }

    /** Width of a channels-first convolutional shape. */
    private static int widthOf(int[] shape) {
        return shape[shape.length - 1];
    }

    /** Bias-free k x k separable convolution with the given stride, {@link ConvolutionMode#Same}. */
    private static Layer separableConv(int nOut, int kernelSize, int stride) {
        return new SeparableConvolution2D.Builder(new int[]{kernelSize, kernelSize})
                .stride(new int[]{stride, stride})
                .nOut(nOut)
                .hasBias(false)
                .convolutionMode(ConvolutionMode.Same)
                .build();
    }

    /** Bias-free 1x1, stride-1 convolution, {@link ConvolutionMode#Same}. */
    private static Layer conv1x1(int nOut) {
        return new ConvolutionLayer.Builder(new int[]{1, 1})
                .stride(new int[]{1, 1})
                .nOut(nOut)
                .hasBias(false)
                .convolutionMode(ConvolutionMode.Same)
                .build();
    }

    /** Square pooling layer with the given type, kernel, stride and convolution mode. */
    private static Layer pooling(PoolingType type, int kernelSize, int stride, ConvolutionMode mode) {
        return new SubsamplingLayer.Builder(type)
                .kernelSize(new int[]{kernelSize, kernelSize})
                .stride(new int[]{stride, stride})
                .convolutionMode(mode)
                .build();
    }

    /** BatchNormalization layer with the helper-wide settings. */
    private static Layer batchNorm() {
        return new BatchNormalization.Builder().eps(BN_EPSILON).gamma(BN_GAMMA).build();
    }

    /** Element-wise addition vertex. */
    private static GraphVertex elementwiseAdd() {
        return new ElementWiseVertex(ElementWiseVertex.Op.Add);
    }
}

