/*
 * Decompiled with CFR 0.152.
 */
package org.tensorflow;

import com.google.protobuf.InvalidProtocolBufferException;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.TensorFlow;
import org.tensorflow.internal.c_api.TF_Buffer;
import org.tensorflow.internal.c_api.TF_Function;
import org.tensorflow.internal.c_api.TF_Status;
import org.tensorflow.internal.c_api.global.tensorflow;
import org.tensorflow.proto.framework.FunctionDef;
import org.tensorflow.proto.framework.NodeDef;

/**
 * Wraps a native {@code TF_Function} handle and lazily extracts, then caches, information
 * about it: its {@link FunctionDef} proto, its name, the names of the functions it references,
 * and whether it is stateful.
 *
 * <p>The lazy getters are {@code synchronized} on the instance, so once a value has been
 * computed successfully it is reused by all later calls on the same instance.
 */
class NativeFunction {
    /** The wrapped native function; fixed for the lifetime of this object. */
    private final TF_Function nativeHandle;
    /** Cached result of {@link #getFunctionDef()}; {@code null} until first computed. */
    private FunctionDef functionDef = null;
    /** Cached result of {@link #getDependencies()}; {@code null} until first computed. */
    private List<String> dependencies = null;
    /** Cached result of {@link #isStateful()}; {@code null} until first computed. */
    private Boolean stateful = null;
    /** Cached result of {@link #getName()}; {@code null} until first computed. */
    private String name = null;

    /**
     * Creates a wrapper for the given native function.
     *
     * @param nativeHandle the native function handle to wrap
     */
    public NativeFunction(TF_Function nativeHandle) {
        this.nativeHandle = nativeHandle;
    }

    /**
     * Returns the wrapped native function handle.
     *
     * @return the native handle passed to the constructor
     */
    public TF_Function getNativeHandle() {
        return nativeHandle;
    }

    /**
     * Returns this function's {@link FunctionDef}, exporting and parsing it on first use.
     *
     * @return the parsed function definition (cached after the first successful call)
     * @throws IllegalStateException if the exported bytes cannot be parsed as a
     *     {@code FunctionDef}
     */
    public synchronized FunctionDef getFunctionDef() {
        if (functionDef == null) {
            // Native objects allocated while the PointerScope is open are released when it closes.
            try (PointerScope scope = new PointerScope()) {
                TF_Buffer funcDefBuffer = TF_Buffer.newBuffer();
                TF_Status status = TF_Status.newStatus();
                tensorflow.TF_FunctionToFunctionDef(nativeHandle, funcDefBuffer, status);
                status.throwExceptionIfNotOK();
                // Parse before the scope closes so the buffer's data is still alive.
                try {
                    functionDef = FunctionDef.parseFrom(funcDefBuffer.dataAsByteBuffer());
                } catch (InvalidProtocolBufferException e) {
                    throw new IllegalStateException("Failed to parse FunctionDef proto", e);
                }
            }
        }
        return functionDef;
    }

    /**
     * Returns the names of the functions directly referenced by this function.
     *
     * <p>The names are collected from every node attribute that holds either a single function
     * or a list of functions. Each name appears once, in first-seen order.
     *
     * @return an unmodifiable list of referenced function names (cached after the first call)
     */
    public synchronized List<String> getDependencies() {
        if (dependencies == null) {
            Set<String> deps = new LinkedHashSet<>();
            for (NodeDef node : getFunctionDef().getNodeDefList()) {
                node.getAttrMap().values().forEach(attr -> {
                    if (attr.hasFunc()) {
                        deps.add(attr.getFunc().getName());
                    } else if (attr.hasList()) {
                        attr.getList().getFuncList().forEach(func -> deps.add(func.getName()));
                    }
                });
            }
            dependencies = Collections.unmodifiableList(new ArrayList<>(deps));
        }
        return dependencies;
    }

    /**
     * Returns whether this function is stateful.
     *
     * <p>A function is stateful if its signature says so, or if any of its nodes uses an op
     * that {@link TensorFlow#isOpStateful(String)} reports as stateful.
     *
     * @return {@code true} if the function is stateful (cached after the first call)
     */
    public synchronized boolean isStateful() {
        if (stateful == null) {
            FunctionDef def = getFunctionDef();
            stateful = def.getSignature().getIsStateful()
                    || def.getNodeDefList().stream().anyMatch(node -> TensorFlow.isOpStateful(node.getOp()));
        }
        return stateful;
    }

    /**
     * Returns the name of the native function, querying it on first use.
     *
     * @return the function name (cached after the first call)
     */
    public synchronized String getName() {
        if (name == null) {
            try (PointerScope scope = new PointerScope()) {
                // Store the result; the original returned it without caching, so every call
                // went back to native code.
                name = tensorflow.TF_FunctionName(nativeHandle).getString();
            }
        }
        return name;
    }

    /**
     * Returns the native handles of every function this function depends on, directly or
     * transitively. This function itself is not included.
     *
     * <p>Intentionally not {@code synchronized}. The method has no state of its own and reads
     * only through the synchronized getters. Holding this instance's lock while calling the
     * synchronized getters of other instances could deadlock: two threads resolving
     * dependencies of different functions in the same collection would each hold one lock and
     * wait for the other.
     *
     * @param availableFunctions the functions that dependency names may resolve to; each name
     *     must appear at most once
     * @return the native handles of all transitive dependencies
     * @throws IllegalStateException if a required function is not present in
     *     {@code availableFunctions}, or if two entries share the same name
     */
    Set<TF_Function> getAllDependencies(Collection<NativeFunction> availableFunctions) {
        // Collectors.toMap throws IllegalStateException on duplicate names.
        Map<String, NativeFunction> byName =
                availableFunctions.stream().collect(Collectors.toMap(NativeFunction::getName, fn -> fn));
        int expected = 1 + getDependencies().size();
        Set<String> visited = new LinkedHashSet<>(expected);
        ArrayDeque<NativeFunction> pending = new ArrayDeque<>(expected);
        pending.add(this);
        while (!pending.isEmpty()) {
            NativeFunction current = pending.remove();
            // A function may be queued several times before it is visited; skip repeats.
            if (!visited.add(current.getName())) {
                continue;
            }
            for (String dep : current.getDependencies()) {
                if (visited.contains(dep)) {
                    continue;
                }
                NativeFunction fn = byName.get(dep);
                if (fn == null) {
                    throw new IllegalStateException("Function " + dep + " is required, but not present in graph.");
                }
                pending.add(fn);
            }
        }
        // The traversal started from this function; it is not its own dependency.
        visited.remove(getName());
        return visited.stream()
                .map(byName::get)
                .map(NativeFunction::getNativeHandle)
                .collect(Collectors.toSet());
    }
}

