/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.zoo.model;

import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.graph.MergeVertex;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.zoo.ModelMetaData;
import org.deeplearning4j.zoo.PretrainedType;
import org.deeplearning4j.zoo.ZooModel;
import org.deeplearning4j.zoo.ZooType;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.linalg.schedule.ScheduleType;
import org.nd4j.linalg.schedule.StepSchedule;

public class GoogLeNet
extends ZooModel {
    private int[] inputShape = new int[]{3, 224, 224};
    private int numLabels;
    private long seed;
    private WorkspaceMode workspaceMode;
    private ConvolutionLayer.AlgoMode cudnnAlgoMode;

    public GoogLeNet(int numLabels, long seed) {
        this(numLabels, seed, WorkspaceMode.SEPARATE);
    }

    public GoogLeNet(int numLabels, long seed, WorkspaceMode workspaceMode) {
        this.numLabels = numLabels;
        this.seed = seed;
        this.workspaceMode = workspaceMode;
        this.cudnnAlgoMode = workspaceMode == WorkspaceMode.SINGLE ? ConvolutionLayer.AlgoMode.PREFER_FASTEST : ConvolutionLayer.AlgoMode.NO_WORKSPACE;
    }

    @Override
    public String pretrainedUrl(PretrainedType pretrainedType) {
        if (pretrainedType == PretrainedType.IMAGENET) {
            return "http://blob.deeplearning4j.org/models/googlenet_dl4j_inference.zip";
        }
        return null;
    }

    @Override
    public long pretrainedChecksum(PretrainedType pretrainedType) {
        if (pretrainedType == PretrainedType.IMAGENET) {
            return 3337733202L;
        }
        return 0L;
    }

    @Override
    public ZooType zooType() {
        return ZooType.GOOGLENET;
    }

    @Override
    public Class<? extends Model> modelType() {
        return ComputationGraph.class;
    }

    private ConvolutionLayer conv1x1(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{1, 1}, new int[]{1, 1}, new int[]{0, 0}).nIn(in)).nOut(out)).biasInit(bias)).build();
    }

    private ConvolutionLayer c3x3reduce(int in, int out, double bias) {
        return this.conv1x1(in, out, bias);
    }

    private ConvolutionLayer c5x5reduce(int in, int out, double bias) {
        return this.conv1x1(in, out, bias);
    }

    private ConvolutionLayer conv3x3(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{3, 3}, new int[]{1, 1}, new int[]{1, 1}).nIn(in)).nOut(out)).biasInit(bias)).build();
    }

    private ConvolutionLayer conv5x5(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{5, 5}, new int[]{1, 1}, new int[]{2, 2}).nIn(in)).nOut(out)).biasInit(bias)).build();
    }

    private ConvolutionLayer conv7x7(int in, int out, double bias) {
        return ((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)((ConvolutionLayer.Builder)new ConvolutionLayer.Builder(new int[]{7, 7}, new int[]{2, 2}, new int[]{3, 3}).nIn(in)).nOut(out)).biasInit(bias)).build();
    }

    private SubsamplingLayer avgPool7x7(int stride) {
        return new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.AVG, new int[]{7, 7}, new int[]{1, 1}).build();
    }

    private SubsamplingLayer maxPool3x3(int stride) {
        return new SubsamplingLayer.Builder(new int[]{3, 3}, new int[]{stride, stride}, new int[]{1, 1}).build();
    }

    private DenseLayer fullyConnected(int in, int out, double dropOut) {
        return ((DenseLayer.Builder)((DenseLayer.Builder)((DenseLayer.Builder)new DenseLayer.Builder().nIn(in)).nOut(out)).dropOut(dropOut)).build();
    }

    private ComputationGraphConfiguration.GraphBuilder inception(ComputationGraphConfiguration.GraphBuilder graph, String name, int inputSize, int[][] config, String inputLayer) {
        graph.addLayer(name + "-cnn1", (Layer)this.conv1x1(inputSize, config[0][0], 0.2), new String[]{inputLayer}).addLayer(name + "-cnn2", (Layer)this.c3x3reduce(inputSize, config[1][0], 0.2), new String[]{inputLayer}).addLayer(name + "-cnn3", (Layer)this.c5x5reduce(inputSize, config[2][0], 0.2), new String[]{inputLayer}).addLayer(name + "-max1", (Layer)this.maxPool3x3(1), new String[]{inputLayer}).addLayer(name + "-cnn4", (Layer)this.conv3x3(config[1][0], config[1][1], 0.2), new String[]{name + "-cnn2"}).addLayer(name + "-cnn5", (Layer)this.conv5x5(config[2][0], config[2][1], 0.2), new String[]{name + "-cnn3"}).addLayer(name + "-cnn6", (Layer)this.conv1x1(inputSize, config[3][0], 0.2), new String[]{name + "-max1"}).addVertex(name + "-depthconcat1", (GraphVertex)new MergeVertex(), new String[]{name + "-cnn1", name + "-cnn4", name + "-cnn5", name + "-cnn6"});
        return graph;
    }

    public ComputationGraphConfiguration conf() {
        ComputationGraphConfiguration.GraphBuilder graph = new NeuralNetConfiguration.Builder().seed(this.seed).activation(Activation.RELU).optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater((IUpdater)new Nesterovs((ISchedule)new StepSchedule(ScheduleType.ITERATION, 0.01, 0.96, 320000.0), 0.9)).biasUpdater((IUpdater)new Nesterovs((ISchedule)new StepSchedule(ScheduleType.ITERATION, 0.02, 0.96, 320000.0), 0.9)).weightInit(WeightInit.XAVIER).l2(2.0E-4).graphBuilder();
        graph.addInputs(new String[]{"input"}).addLayer("cnn1", (Layer)this.conv7x7(this.inputShape[0], 64, 0.2), new String[]{"input"}).addLayer("max1", (Layer)new SubsamplingLayer.Builder(new int[]{3, 3}, new int[]{2, 2}, new int[]{0, 0}).build(), new String[]{"cnn1"}).addLayer("lrn1", (Layer)new LocalResponseNormalization.Builder(5.0, 1.0E-4, 0.75).build(), new String[]{"max1"}).addLayer("cnn2", (Layer)this.conv1x1(64, 64, 0.2), new String[]{"lrn1"}).addLayer("cnn3", (Layer)this.conv3x3(64, 192, 0.2), new String[]{"cnn2"}).addLayer("lrn2", (Layer)new LocalResponseNormalization.Builder(5.0, 1.0E-4, 0.75).build(), new String[]{"cnn3"}).addLayer("max2", (Layer)new SubsamplingLayer.Builder(new int[]{3, 3}, new int[]{2, 2}, new int[]{0, 0}).build(), new String[]{"lrn2"});
        this.inception(graph, "3a", 192, new int[][]{{64}, {96, 128}, {16, 32}, {32}}, "max2");
        this.inception(graph, "3b", 256, new int[][]{{128}, {128, 192}, {32, 96}, {64}}, "3a-depthconcat1");
        graph.addLayer("max3", (Layer)new SubsamplingLayer.Builder(new int[]{3, 3}, new int[]{2, 2}, new int[]{0, 0}).build(), new String[]{"3b-depthconcat1"});
        this.inception(graph, "4a", 480, new int[][]{{192}, {96, 208}, {16, 48}, {64}}, "3b-depthconcat1");
        this.inception(graph, "4b", 512, new int[][]{{160}, {112, 224}, {24, 64}, {64}}, "4a-depthconcat1");
        this.inception(graph, "4c", 512, new int[][]{{128}, {128, 256}, {24, 64}, {64}}, "4b-depthconcat1");
        this.inception(graph, "4d", 512, new int[][]{{112}, {144, 288}, {32, 64}, {64}}, "4c-depthconcat1");
        this.inception(graph, "4e", 528, new int[][]{{256}, {160, 320}, {32, 128}, {128}}, "4d-depthconcat1");
        graph.addLayer("max4", (Layer)new SubsamplingLayer.Builder(new int[]{3, 3}, new int[]{2, 2}, new int[]{0, 0}).build(), new String[]{"4e-depthconcat1"});
        this.inception(graph, "5a", 832, new int[][]{{256}, {160, 320}, {32, 128}, {128}}, "max4");
        this.inception(graph, "5b", 832, new int[][]{{384}, {192, 384}, {48, 128}, {128}}, "5a-depthconcat1");
        graph.addLayer("avg3", (Layer)this.avgPool7x7(1), new String[]{"5b-depthconcat1"}).addLayer("fc1", (Layer)this.fullyConnected(1024, 1024, 0.4), new String[]{"avg3"}).addLayer("output", (Layer)((OutputLayer.Builder)((OutputLayer.Builder)((OutputLayer.Builder)new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD).nIn(1024)).nOut(this.numLabels)).activation(Activation.SOFTMAX)).build(), new String[]{"fc1"}).setOutputs(new String[]{"output"}).backprop(true).pretrain(false);
        return graph.build();
    }

    public ComputationGraph init() {
        ComputationGraph model = new ComputationGraph(this.conf());
        model.init();
        return model;
    }

    @Override
    public ModelMetaData metaData() {
        return new ModelMetaData(new int[][]{this.inputShape}, 1, ZooType.CNN);
    }

    @Override
    public void setInputShape(int[][] inputShape) {
        this.inputShape = inputShape[0];
    }

    public GoogLeNet() {
    }
}

