/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff.transform;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.internal.Variable;

/**
 * Describes a region of a {@link SameDiff} graph: one root op plus an optional list of child ops.
 *
 * <p>Membership checks ({@link #inSubgraph(DifferentialFunction)}) compare ops by reference
 * identity. Variables and ops referenced by name are resolved against {@link #getSameDiff()}.
 *
 * <p>Instances are mutable (setters are exposed) and are not synchronized.
 */
public class SubGraph {
    /** Graph used to resolve the variables and ops referenced by this subgraph. */
    protected SameDiff sameDiff;
    /** Root op of the subgraph; always the first entry of {@link #allFunctionsInSubgraph()}. */
    protected DifferentialFunction rootNode;
    /** Ops in the subgraph other than the root; may be {@code null}. */
    protected List<DifferentialFunction> childNodes;

    /**
     * Returns the variables produced inside this subgraph that are consumed outside of it.
     *
     * <p>Every output variable of the root op and of each child op is considered. A variable is
     * reported when at least one op that takes it as input is not part of this subgraph. Variables
     * consumed only by ops inside the subgraph are internal and are omitted.
     *
     * <p>NOTE(review): a variable with no consuming ops at all is also omitted. The "every consumer
     * is inside the subgraph" check is vacuously true for it, so a subgraph result that no op
     * consumes is not reported. Confirm that this is intended.
     *
     * @return externally consumed outputs: root-op outputs first, then child-op outputs in
     *     {@link #getChildNodes()} order
     * @throws IllegalStateException if an output variable is not registered in
     *     {@link #getSameDiff()}
     */
    public List<SDVariable> outputs() {
        List<SDVariable> allOutputs = new ArrayList<>();
        for (DifferentialFunction df : allFunctionsInSubgraph()) {
            SDVariable[] outs = df.outputVariables();
            if (outs != null) {
                Collections.addAll(allOutputs, outs);
            }
        }

        List<SDVariable> filteredOutputs = new ArrayList<>(allOutputs.size());
        for (SDVariable v : allOutputs) {
            if (isConsumedOutsideSubgraph(v)) {
                filteredOutputs.add(v);
            }
        }
        return filteredOutputs;
    }

    /** Returns true if at least one op that consumes {@code v} is not part of this subgraph. */
    private boolean isConsumedOutsideSubgraph(SDVariable v) {
        Variable var = this.sameDiff.getVariables().get(v.getVarName());
        if (var == null) {
            // Fail with a descriptive message instead of an anonymous NullPointerException.
            throw new IllegalStateException("Variable \"" + v.getVarName()
                    + "\" is not registered in the SameDiff instance of this subgraph");
        }
        List<String> consumers = var.getInputsForOp();
        if (consumers == null) {
            return false;
        }
        for (String opOwnName : consumers) {
            if (!inSubgraph(this.sameDiff.getFunctionById(opOwnName))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the arguments of ops in this subgraph that are not produced by any op in it.
     *
     * <p>Arguments are listed in op order (root first, then children) and, within an op, in
     * {@code args()} order. NOTE(review): duplicates are not removed. The same external variable
     * appears once per use if several ops (or one op, several times) consume it.
     *
     * @return the external inputs of this subgraph, possibly with repeats
     */
    public List<SDVariable> inputs() {
        List<DifferentialFunction> functions = allFunctionsInSubgraph();

        Set<SDVariable> producedInside = new HashSet<>();
        for (DifferentialFunction df : functions) {
            SDVariable[] outs = df.outputVariables();
            if (outs != null) {
                Collections.addAll(producedInside, outs);
            }
        }

        List<SDVariable> inputs = new ArrayList<>();
        for (DifferentialFunction df : functions) {
            SDVariable[] args = df.args();
            if (args == null) {
                continue;
            }
            for (SDVariable arg : args) {
                if (!producedInside.contains(arg)) {
                    inputs.add(arg);
                }
            }
        }
        return inputs;
    }

    /**
     * Checks whether {@code df} is the root op or one of the child ops, compared by reference.
     *
     * @param df op to test
     * @return true if {@code df} is part of this subgraph
     */
    public boolean inSubgraph(DifferentialFunction df) {
        if (this.rootNode == df) {
            return true;
        }
        if (this.childNodes != null) {
            for (DifferentialFunction d : this.childNodes) {
                if (d == df) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Returns a new list holding the root op followed by all child ops (if any).
     *
     * @return a fresh, mutable list; changes to it do not affect this subgraph
     */
    public List<DifferentialFunction> allFunctionsInSubgraph() {
        int childCount = this.childNodes == null ? 0 : this.childNodes.size();
        List<DifferentialFunction> out = new ArrayList<>(1 + childCount);
        out.add(this.rootNode);
        if (this.childNodes != null) {
            out.addAll(this.childNodes);
        }
        return out;
    }

    /** Creates a builder for {@link SubGraph}. */
    public static SubGraphBuilder builder() {
        return new SubGraphBuilder();
    }

    /**
     * @param sameDiff graph used to resolve variables and ops
     * @param rootNode root op of the subgraph
     * @param childNodes additional ops in the subgraph; may be {@code null}
     */
    public SubGraph(SameDiff sameDiff, DifferentialFunction rootNode, List<DifferentialFunction> childNodes) {
        this.sameDiff = sameDiff;
        this.rootNode = rootNode;
        this.childNodes = childNodes;
    }

    /** Creates an empty subgraph; all fields are {@code null} until set. */
    public SubGraph() {
    }

    public SameDiff getSameDiff() {
        return this.sameDiff;
    }

    public DifferentialFunction getRootNode() {
        return this.rootNode;
    }

    /** Returns the child-op list itself (not a copy); may be {@code null}. */
    public List<DifferentialFunction> getChildNodes() {
        return this.childNodes;
    }

    public void setSameDiff(SameDiff sameDiff) {
        this.sameDiff = sameDiff;
    }

    public void setRootNode(DifferentialFunction rootNode) {
        this.rootNode = rootNode;
    }

    public void setChildNodes(List<DifferentialFunction> childNodes) {
        this.childNodes = childNodes;
    }

    /** Field-wise equality on {@code sameDiff}, {@code rootNode} and {@code childNodes}. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof SubGraph)) {
            return false;
        }
        SubGraph other = (SubGraph) o;
        // Symmetry hook: lets a subclass refuse equality with a plain SubGraph.
        if (!other.canEqual(this)) {
            return false;
        }
        return Objects.equals(this.getSameDiff(), other.getSameDiff())
                && Objects.equals(this.getRootNode(), other.getRootNode())
                && Objects.equals(this.getChildNodes(), other.getChildNodes());
    }

    /** Returns true if {@code other} may be considered equal to this class's instances. */
    protected boolean canEqual(Object other) {
        return other instanceof SubGraph;
    }

    @Override
    public int hashCode() {
        // Same constants as before (multiplier 59, null placeholder 43) so hash values are unchanged.
        final int prime = 59;
        int result = 1;
        result = result * prime + hashOrDefault(this.getSameDiff());
        result = result * prime + hashOrDefault(this.getRootNode());
        result = result * prime + hashOrDefault(this.getChildNodes());
        return result;
    }

    /** Hash of {@code o}, or 43 when {@code o} is null. */
    private static int hashOrDefault(Object o) {
        return o == null ? 43 : o.hashCode();
    }

    @Override
    public String toString() {
        return "SubGraph(sameDiff=" + this.getSameDiff() + ", rootNode=" + this.getRootNode() + ", childNodes=" + this.getChildNodes() + ")";
    }

    /** Fluent builder for {@link SubGraph}; any field left unset is {@code null}. */
    public static class SubGraphBuilder {
        private SameDiff sameDiff;
        private DifferentialFunction rootNode;
        private List<DifferentialFunction> childNodes;

        SubGraphBuilder() {
        }

        public SubGraphBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public SubGraphBuilder rootNode(DifferentialFunction rootNode) {
            this.rootNode = rootNode;
            return this;
        }

        public SubGraphBuilder childNodes(List<DifferentialFunction> childNodes) {
            this.childNodes = childNodes;
            return this;
        }

        /** Builds a {@link SubGraph} from the values set so far (the child list is not copied). */
        public SubGraph build() {
            return new SubGraph(this.sameDiff, this.rootNode, this.childNodes);
        }

        @Override
        public String toString() {
            return "SubGraph.SubGraphBuilder(sameDiff=" + this.sameDiff + ", rootNode=" + this.rootNode + ", childNodes=" + this.childNodes + ")";
        }
    }
}

