/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.optimize;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import org.nd4j.autodiff.samediff.ArrayHolder;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.optimize.OptimizationConfig;
import org.nd4j.autodiff.samediff.optimize.OptimizationHelper;
import org.nd4j.autodiff.samediff.optimize.Optimizer;
import org.nd4j.autodiff.samediff.optimize.OptimizerSet;
import org.nd4j.autodiff.samediff.optimize.debug.OptimizationDebugger;
import org.nd4j.autodiff.samediff.optimize.optimizations.ConstantFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.CuDNNFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.IdentityFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.ShapeFunctionOptimizations;
import org.nd4j.autodiff.samediff.optimize.optimizations.UnusedFunctionOptimizations;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Applies a list of {@link OptimizerSet}s to a {@link SameDiff} graph.
 *
 * <p>The input graph is duplicated with {@link SameDiff#dup()}. Every optimizer is run against the
 * duplicate, and that duplicate is returned. A summary of variable and op counts before and after
 * optimization is logged at INFO level.
 */
public class GraphOptimizer {
    private static final Logger log = LoggerFactory.getLogger(GraphOptimizer.class);

    /**
     * Number of times the full list of optimizer sets is run when no explicit pass count is given.
     * Several passes are used, presumably because one optimization can expose opportunities for
     * another set that already ran earlier in the same pass.
     */
    public static final int DEFAULT_NUM_PASSES = 3;

    /**
     * Returns the default optimizer sets, in the order they are applied.
     *
     * <p>{@link UnusedFunctionOptimizations} appears twice: once first and once again after the
     * constant/identity/shape optimizations. This is presumably to clean up ops made redundant by
     * those intermediate optimizations.
     *
     * @return a new, mutable list; callers may append their own sets
     */
    public static List<OptimizerSet> defaultOptimizations() {
        return new ArrayList<>(Arrays.<OptimizerSet>asList(
                new UnusedFunctionOptimizations(),
                new ConstantFunctionOptimizations(),
                new IdentityFunctionOptimizations(),
                new ShapeFunctionOptimizations(),
                new UnusedFunctionOptimizations(),
                new CuDNNFunctionOptimizations()));
    }

    /**
     * Optimizes a copy of {@code graph} using {@link #defaultOptimizations()}.
     *
     * @param graph           graph to optimize; it is duplicated, and optimizers run on the copy
     * @param requiredOutputs names of outputs that must be preserved (currently unused, see
     *                        {@link #optimize(SameDiff, List, List, OptimizationDebugger, int)})
     * @return the optimized copy
     */
    public static SameDiff optimize(SameDiff graph, String ... requiredOutputs) {
        return optimize(graph, Arrays.asList(requiredOutputs));
    }

    /**
     * Optimizes a copy of {@code graph} using {@link #defaultOptimizations()}.
     *
     * @param graph           graph to optimize; it is duplicated, and optimizers run on the copy
     * @param requiredOutputs names of outputs that must be preserved (currently unused)
     * @return the optimized copy
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs) {
        return optimize(graph, requiredOutputs, defaultOptimizations());
    }

    /**
     * Optimizes a copy of {@code graph} using the given optimizer sets and no debugger.
     *
     * @param graph           graph to optimize; it is duplicated, and optimizers run on the copy
     * @param requiredOutputs names of outputs that must be preserved (currently unused)
     * @param optimizations   optimizer sets to apply, in order
     * @return the optimized copy
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations) {
        return optimize(graph, requiredOutputs, optimizations, null);
    }

    /**
     * Optimizes a copy of {@code graph}, running all optimizer sets {@link #DEFAULT_NUM_PASSES}
     * times.
     *
     * @param graph           graph to optimize; it is duplicated, and optimizers run on the copy
     * @param requiredOutputs names of outputs that must be preserved (currently unused)
     * @param optimizations   optimizer sets to apply, in order
     * @param debugger        optional; when non-null it is notified before and after every
     *                        optimizer check
     * @return the optimized copy
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations, OptimizationDebugger debugger) {
        return optimize(graph, requiredOutputs, optimizations, debugger, DEFAULT_NUM_PASSES);
    }

    /**
     * Optimizes a copy of {@code graph}, running every optimizer of every set over every op
     * {@code numPasses} times.
     *
     * <p>NOTE(review): {@code requiredOutputs} is not currently used. The graph is not pruned to
     * the requested outputs.
     *
     * @param graph           graph to optimize; it is duplicated, and optimizers run on the copy
     * @param requiredOutputs names of outputs that must be preserved (currently unused; may be null)
     * @param optimizations   optimizer sets to apply, in order
     * @param debugger        optional; when non-null it is notified before and after every
     *                        optimizer check
     * @param numPasses       number of full passes over {@code optimizations}; 0 returns an
     *                        unoptimized duplicate
     * @return the optimized copy
     * @throws NullPointerException     if {@code graph} or {@code optimizations} is null
     * @throws IllegalArgumentException if {@code numPasses} is negative
     */
    public static SameDiff optimize(SameDiff graph, List<String> requiredOutputs, List<OptimizerSet> optimizations,
                                    OptimizationDebugger debugger, int numPasses) {
        Objects.requireNonNull(graph, "graph");
        Objects.requireNonNull(optimizations, "optimizations");
        if (numPasses < 0) {
            throw new IllegalArgumentException("numPasses must be >= 0, got " + numPasses);
        }

        SameDiff sd = graph.dup();
        // The array holders are fetched once, up front, and shared by all passes.
        ArrayHolder constantArrays = sd.getConstantArrays();
        ArrayHolder variableArrays = sd.getVariablesArrays();
        // NOTE(review): the helper is built from the original graph, not the working copy sd.
        // Confirm this is intended.
        OptimizationHelper helper = new OptimizationHelper(graph, new OptimizationConfig());

        for (int pass = 0; pass < numPasses; pass++) {
            runPass(sd, helper, optimizations, constantArrays, variableArrays, debugger);
        }

        logSummary(graph, sd);
        return sd;
    }

    /** Runs each optimizer of each set once against every op currently present in {@code sd}. */
    private static void runPass(SameDiff sd, OptimizationHelper helper, List<OptimizerSet> optimizations,
                                ArrayHolder constantArrays, ArrayHolder variableArrays,
                                OptimizationDebugger debugger) {
        for (OptimizerSet set : optimizations) {
            for (Optimizer optimizer : set.getOptimizers()) {
                // Iterate over a snapshot. Optimizers may remove ops from sd while we loop.
                List<SameDiffOp> startingOps = new ArrayList<>(sd.getOps().values());
                for (SameDiffOp op : startingOps) {
                    // Skip ops that an earlier check in this loop already removed from the graph.
                    if (!sd.getOps().containsKey(op.getName())) {
                        continue;
                    }
                    if (debugger != null) {
                        debugger.beforeOptimizationCheck(sd, op, optimizer);
                    }
                    boolean applied = optimizer.checkAndApply(sd, helper, op, constantArrays, variableArrays);
                    if (applied) {
                        log.info("Operation was applied: {}", optimizer);
                    }
                    if (debugger != null) {
                        debugger.afterOptimizationsCheck(sd, op, optimizer, applied);
                    }
                }
            }
        }
    }

    /** Logs variable-type and op counts of the original graph versus the optimized copy. */
    private static void logSummary(SameDiff before, SameDiff after) {
        VariableTypeCounts countsBefore = VariableTypeCounts.of(before);
        VariableTypeCounts countsAfter = VariableTypeCounts.of(after);
        log.info("Total variables: {} before, {} after", before.getVariables().size(), after.getVariables().size());
        log.info("Constant variables: {} before, {} after", countsBefore.constants, countsAfter.constants);
        log.info("Array type variables: {} before, {} after", countsBefore.arrays, countsAfter.arrays);
        log.info("Variable type variables: {} before, {} after", countsBefore.variables, countsAfter.variables);
        log.info("Ops: {} before, {} after", before.getOps().size(), after.getOps().size());
    }

    /** Tallies of VARIABLE, CONSTANT and ARRAY variables in a graph, used only for the summary log. */
    private static final class VariableTypeCounts {
        int variables;
        int constants;
        int arrays;

        /** Counts the variables of {@code graph} by variable type; other types are not tallied. */
        static VariableTypeCounts of(SameDiff graph) {
            VariableTypeCounts counts = new VariableTypeCounts();
            for (SDVariable v : graph.variables()) {
                switch (v.getVariableType()) {
                    case VARIABLE:
                        counts.variables++;
                        break;
                    case CONSTANT:
                        counts.constants++;
                        break;
                    case ARRAY:
                        counts.arrays++;
                        break;
                    default:
                        // Other variable types are intentionally not reported.
                        break;
                }
            }
            return counts;
        }
    }
}

