/*
 * Decompiled with CFR 0.152.
 */
package oracle.pgx.api.mllib;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.BiFunction;
import java.util.function.Supplier;
import oracle.pgx.api.PgxFuture;
import oracle.pgx.api.PgxGraph;
import oracle.pgx.api.PgxSession;
import oracle.pgx.api.PgxVertex;
import oracle.pgx.api.VertexProperty;
import oracle.pgx.api.frames.PgxFrame;
import oracle.pgx.api.internal.Core;
import oracle.pgx.api.internal.Graph;
import oracle.pgx.api.internal.mllib.GnnExplanationMetaData;
import oracle.pgx.api.internal.mllib.GraphWiseModelMetadata;
import oracle.pgx.api.mllib.GnnExplanation;
import oracle.pgx.api.mllib.Model;
import oracle.pgx.common.PgxId;
import oracle.pgx.config.mllib.GraphWiseBaseConvLayerConfig;
import oracle.pgx.config.mllib.GraphWiseModelConfig;

/**
 * Abstract base for GraphWise models.
 *
 * <p>An instance wraps a metadata object describing the model and exposes the values held
 * in the metadata's configuration through read-only getters. Subclasses supply the
 * asynchronous training and embedding-inference operations; blocking convenience wrappers
 * are provided here.
 *
 * @param <Config> concrete configuration type
 * @param <Metadata> concrete metadata type carrying a {@code Config}
 * @param <ModelType> the concrete model type (self type passed on to {@code Model})
 */
public abstract class GraphWiseModel<Config extends GraphWiseModelConfig, Metadata extends GraphWiseModelMetadata<Config>, ModelType extends GraphWiseModel<Config, Metadata, ModelType>>
extends Model<ModelType> {
    /** Metadata describing this model; source of the model name and configuration. */
    Metadata modelMetadata;
    /** Factory used to wrap an internal {@code Graph} into a {@link PgxGraph} for a session. */
    protected final BiFunction<PgxSession, Graph, PgxGraph> graphConstructor;

    /**
     * Creates the model wrapper.
     *
     * @param session session forwarded to the {@code Model} superclass
     * @param core core handle forwarded to the {@code Model} superclass
     * @param keystorePathSupplier forwarded to the {@code Model} superclass
     * @param keystorePasswordSupplier forwarded to the {@code Model} superclass
     * @param modelMetadata metadata describing this model
     * @param graphConstructor factory that builds a {@link PgxGraph} from a session and an internal graph
     */
    public GraphWiseModel(PgxSession session, Core core, Supplier<String> keystorePathSupplier, Supplier<char[]> keystorePasswordSupplier, Metadata modelMetadata, BiFunction<PgxSession, Graph, PgxGraph> graphConstructor) {
        super(session, core, keystorePathSupplier, keystorePasswordSupplier);
        this.modelMetadata = modelMetadata;
        this.graphConstructor = graphConstructor;
    }

    /**
     * Returns the model name recorded in the metadata.
     *
     * @return the model name
     */
    @Override
    String getModelName() {
        return this.modelMetadata.getModelName();
    }

    /**
     * Requests destruction of this model through the core, identified by the session
     * context and the model name.
     *
     * @return a future for the destroy request
     */
    @Override
    public PgxFuture<Void> destroyAsync() {
        return this.core.destroyMlModel(this.session.getSessionContext(), this.getModelName());
    }

    /**
     * Blocking variant of {@link #destroyAsync()}.
     *
     * @throws ExecutionException if the asynchronous request fails
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public void destroy() throws ExecutionException, InterruptedException {
        this.destroyAsync().get();
    }

    /**
     * Returns the number of epochs stored in the configuration.
     *
     * @return the configured number of epochs
     */
    public int getNumEpochs() {
        return this.getConfig().getNumEpochs();
    }

    /**
     * Returns the learning rate stored in the configuration.
     *
     * @return the configured learning rate
     */
    public double getLearningRate() {
        return this.getConfig().getLearningRate();
    }

    /**
     * Returns the batch size stored in the configuration.
     *
     * @return the configured batch size
     */
    public int getBatchSize() {
        return this.getConfig().getBatchSize();
    }

    /**
     * Returns the embedding dimension stored in the configuration.
     *
     * @return the configured embedding dimension
     */
    public int getEmbeddingDim() {
        return this.getConfig().getEmbeddingDim();
    }

    /**
     * Returns the seed stored in the configuration.
     *
     * @return the configured seed, as a boxed value
     */
    public Integer getSeed() {
        return this.getConfig().getSeed();
    }

    /**
     * Returns the convolution layer configurations stored in the configuration.
     *
     * @return the configured convolution layer configs
     */
    public GraphWiseBaseConvLayerConfig[] getConvLayerConfigs() {
        return this.getConfig().getConvLayerConfigs();
    }

    /**
     * Returns the vertex input property names stored in the configuration.
     *
     * @return the configured vertex input property names
     */
    public List<String> getVertexInputPropertyNames() {
        return this.getConfig().getVertexInputPropertyNames();
    }

    /**
     * Returns the edge input property names stored in the configuration.
     *
     * @return the configured edge input property names
     */
    public List<String> getEdgeInputPropertyNames() {
        return this.getConfig().getEdgeInputPropertyNames();
    }

    /**
     * Reports the fitted flag stored in the configuration.
     *
     * @return the configuration's fitted flag
     */
    public boolean isFitted() {
        return this.getConfig().isFitted();
    }

    /**
     * Returns the training loss stored in the configuration.
     *
     * @return the recorded training loss
     */
    public double getTrainingLoss() {
        return this.getConfig().getTrainingLoss();
    }

    /**
     * Returns the input feature dimension stored in the configuration.
     *
     * @return the recorded input feature dimension
     */
    public int getInputFeatureDim() {
        return this.getConfig().getInputFeatureDim();
    }

    /**
     * Returns the edge input feature dimension stored in the configuration.
     *
     * @return the recorded edge input feature dimension
     */
    public int getEdgeInputFeatureDim() {
        return this.getConfig().getEdgeInputFeatureDim();
    }

    /**
     * Returns the configuration held by the model metadata.
     *
     * @return the model configuration
     */
    @SuppressWarnings("unchecked")
    public Config getConfig() {
        // The cast is carried over from the original code; it is a no-op if the metadata
        // accessor is already declared to return Config.
        return (Config) this.modelMetadata.getConfig();
    }

    /**
     * Starts training on the given graph.
     *
     * @param var1 the graph to train on
     * @return a future yielding a {@code Double} result
     */
    public abstract PgxFuture<Double> fitAsync(PgxGraph var1);

    /**
     * Blocking variant of {@link #fitAsync(PgxGraph)}.
     *
     * @param graph the graph to train on
     * @return the value produced by the asynchronous fit
     * @throws ExecutionException if the asynchronous fit fails
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    public double fit(PgxGraph graph) throws ExecutionException, InterruptedException {
        return this.fitAsync(graph).get();
    }

    /**
     * Starts embedding inference for the given vertices of the given graph.
     *
     * @param var1 the graph containing the vertices
     * @param var2 the vertices to infer embeddings for
     * @param <ID> vertex ID type
     * @return a future yielding a frame with the result
     */
    public abstract <ID> PgxFuture<PgxFrame> inferEmbeddingsAsync(PgxGraph var1, Iterable<PgxVertex<ID>> var2);

    /**
     * Blocking variant of {@link #inferEmbeddingsAsync(PgxGraph, Iterable)}.
     *
     * <p>Waits with {@code join()}, so this method declares no checked exceptions, unlike
     * {@link #fit(PgxGraph)} and {@link #destroy()}.
     *
     * @param graph the graph containing the vertices
     * @param vertices the vertices to infer embeddings for
     * @param <ID> vertex ID type
     * @return the frame produced by the asynchronous inference
     */
    public <ID> PgxFrame inferEmbeddings(PgxGraph graph, Iterable<PgxVertex<ID>> vertices) {
        return this.inferEmbeddingsAsync(graph, vertices).join();
    }

    /**
     * Converts each vertex to its serialized form, preserving iteration order.
     *
     * @param vertices the vertices to serialize
     * @param <ID> vertex ID type
     * @return a new mutable list holding {@code serialize()} of every vertex
     */
    protected <ID> List<Object> serializeVertices(Iterable<PgxVertex<ID>> vertices) {
        List<Object> serializedVertices = new ArrayList<>();
        for (PgxVertex<ID> vertex : vertices) {
            serializedVertices.add(vertex.serialize());
        }
        return serializedVertices;
    }

    /**
     * Builds a {@link GnnExplanation} from explanation metadata.
     *
     * <p>The importance graph is wrapped via {@link #graphConstructor}; the vertex importance
     * property is the property of that graph whose ID equals the metadata's vertex importance
     * property ID. Feature importances are re-keyed from property IDs to the matching vertex
     * properties of {@code graph}; properties of {@code graph} without an entry are skipped.
     *
     * @param graph the graph whose vertex properties are matched against the feature importances
     * @param gnnExplanationMetaData metadata describing the explanation
     * @param <ID> vertex ID type
     * @return the assembled explanation
     * @throws IllegalStateException if the importance graph has no vertex property with the expected ID
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected <ID> GnnExplanation<ID> processExplanationResult(PgxGraph graph, GnnExplanationMetaData gnnExplanationMetaData) {
        PgxGraph importanceGraph = this.graphConstructor.apply(this.session, gnnExplanationMetaData.getImportanceGraph());
        PgxId importancePropId = gnnExplanationMetaData.getVertexImportancePropertyId();
        VertexProperty<?, ?> vertexImportanceProperty = importanceGraph.getVertexProperties().stream()
                // The explicit Object cast pins the equals(Object) overload, as in the original.
                .filter(p -> p.getPropertyId().equals((Object) importancePropId))
                .findAny()
                .orElseThrow(() -> new IllegalStateException(
                        "importance graph has no vertex property with ID " + importancePropId));

        Map<PgxId, Float> featureImportanceIdMap = gnnExplanationMetaData.getVertexFeatureImportances();
        HashMap<VertexProperty<?, ?>, Float> featureImportancePropertyMap = new HashMap<>();
        for (VertexProperty<?, ?> property : graph.getVertexProperties()) {
            PgxId propertyId = property.getPropertyId();
            if (featureImportanceIdMap.containsKey(propertyId)) {
                featureImportancePropertyMap.put(property, featureImportanceIdMap.get(propertyId));
            }
        }
        // Raw constructor call kept on purpose: the generic signature of GnnExplanation's
        // constructor is not visible from this file.
        return new GnnExplanation(featureImportancePropertyMap, importanceGraph, vertexImportanceProperty, gnnExplanationMetaData.getEmbedding());
    }
}

