/*
 * Decompiled with CFR 0.152.
 */
package org.mlflow.tracking;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import org.mlflow.api.proto.Service;
import org.mlflow.tracking.ActiveRun;
import org.mlflow.tracking.MlflowClient;
import org.mlflow.tracking.utils.DatabricksContext;

/**
 * Convenience wrapper around an {@link MlflowClient} that remembers a current experiment ID and
 * starts runs in that experiment.
 *
 * <p>The experiment ID is initialized to the Databricks notebook ID when running inside a
 * Databricks notebook whose ID is available, and to {@code "0"} otherwise. It can be changed with
 * {@link #setExperimentName(String)} or {@link #setExperimentId(String)}.
 *
 * <p>NOTE(review): {@code experimentId} is a plain mutable field with no synchronization.
 */
public class MlflowContext {
    private final MlflowClient client;
    private String experimentId;

    /** Creates a context backed by a default-constructed {@link MlflowClient}. */
    public MlflowContext() {
        this(new MlflowClient());
    }

    /**
     * Creates a context backed by an {@link MlflowClient} built from the given tracking URI.
     *
     * @param trackingUri tracking URI handed to the {@link MlflowClient} constructor
     */
    public MlflowContext(String trackingUri) {
        this(new MlflowClient(trackingUri));
    }

    /**
     * Creates a context backed by the supplied client, using the default experiment ID.
     *
     * @param client client used for all tracking calls made through this context
     */
    public MlflowContext(MlflowClient client) {
        this.client = client;
        this.experimentId = MlflowContext.getDefaultExperimentId();
    }

    /** Returns the client this context delegates to. */
    public MlflowClient getClient() {
        return this.client;
    }

    /**
     * Looks up an experiment by name and makes it the current experiment.
     *
     * @param experimentName name of an existing experiment
     * @return this context, for call chaining
     * @throws IllegalArgumentException if the client finds no experiment with that name
     */
    public MlflowContext setExperimentName(String experimentName) throws IllegalArgumentException {
        Service.Experiment experiment = this.client.getExperimentByName(experimentName)
            .orElseThrow(() -> new IllegalArgumentException(
                String.format("%s is not a valid experiment", experimentName)));
        this.experimentId = experiment.getExperimentId();
        return this;
    }

    /**
     * Sets the current experiment ID without validating it.
     *
     * @param experimentId experiment ID used by subsequent {@code startRun} calls
     * @return this context, for call chaining
     */
    public MlflowContext setExperimentId(String experimentId) {
        this.experimentId = experimentId;
        return this;
    }

    /** Returns the experiment ID that new runs will be created in. */
    public String getExperimentId() {
        return this.experimentId;
    }

    /** Starts an unnamed, top-level run in the current experiment. */
    public ActiveRun startRun() {
        return this.startRun(null);
    }

    /**
     * Starts a top-level run in the current experiment.
     *
     * @param runName value for the {@code mlflow.runName} tag, or {@code null} to omit it
     */
    public ActiveRun startRun(String runName) {
        return this.startRun(runName, null);
    }

    /**
     * Creates a run in the current experiment through the client and wraps it in an
     * {@link ActiveRun}. The run's start time is the current wall-clock time in milliseconds.
     *
     * @param runName value for the {@code mlflow.runName} tag, or {@code null} to omit it
     * @param parentRunId value for the {@code mlflow.parentRunId} tag, or {@code null} to omit it
     * @return the newly created run
     */
    public ActiveRun startRun(String runName, String parentRunId) {
        Map<String, String> tags = buildRunTags(runName, parentRunId);
        Service.CreateRun.Builder createRunBuilder = Service.CreateRun.newBuilder()
            .setExperimentId(this.experimentId)
            .setStartTime(System.currentTimeMillis());
        for (Map.Entry<String, String> tag : tags.entrySet()) {
            createRunBuilder.addTags(Service.RunTag.newBuilder()
                .setKey(tag.getKey())
                .setValue(tag.getValue())
                .build());
        }
        Service.RunInfo runInfo = this.client.createRun(createRunBuilder.build());
        return new ActiveRun(runInfo, this.client);
    }

    /**
     * Starts an unnamed run, passes it to {@code activeRunFunction}, and then ends it.
     *
     * <p>The run ends as {@code FINISHED} if the function returns normally. If the function
     * throws, the run ends as {@code FAILED` and the exception is rethrown to the caller.
     *
     * @param activeRunFunction work to perform against the run
     */
    public void withActiveRun(Consumer<ActiveRun> activeRunFunction) {
        runThenEnd(this.startRun(), activeRunFunction);
    }

    /**
     * Starts a named run, passes it to {@code activeRunFunction}, and then ends it.
     *
     * <p>The run ends as {@code FINISHED} if the function returns normally. If the function
     * throws, the run ends as {@code FAILED} and the exception is rethrown to the caller.
     *
     * @param runName value for the {@code mlflow.runName} tag, or {@code null} to omit it
     * @param activeRunFunction work to perform against the run
     */
    public void withActiveRun(String runName, Consumer<ActiveRun> activeRunFunction) {
        runThenEnd(this.startRun(runName), activeRunFunction);
    }

    /** Runs {@code activeRunFunction} on {@code run} and ends the run with a matching status. */
    private static void runThenEnd(ActiveRun run, Consumer<ActiveRun> activeRunFunction) {
        try {
            activeRunFunction.accept(run);
        } catch (Throwable failure) {
            // Catch Throwable, not just Exception, so that Errors also close the run instead
            // of leaving it open.
            try {
                run.endRun(Service.RunStatus.FAILED);
            } catch (Throwable endFailure) {
                // Keep the caller's original failure primary and record the endRun failure too.
                failure.addSuppressed(endFailure);
            }
            // Rethrow so callers can observe the failure rather than having it swallowed.
            // Precise rethrow: Consumer.accept declares no checked exceptions, so no throws
            // clause is required here.
            throw failure;
        }
        run.endRun(Service.RunStatus.FINISHED);
    }

    /**
     * Builds the tag map attached to a new run.
     *
     * <p>Tags supplied by {@link DatabricksContext} are added last. On key collisions they
     * override the defaults set here.
     */
    private static Map<String, String> buildRunTags(String runName, String parentRunId) {
        Map<String, String> tags = new HashMap<>();
        if (runName != null) {
            tags.put("mlflow.runName", runName);
        }
        String userName = System.getProperty("user.name");
        // A null tag value would be rejected when building the RunTag, so skip it if unset.
        if (userName != null) {
            tags.put("mlflow.user", userName);
        }
        tags.put("mlflow.source.type", "LOCAL");
        if (parentRunId != null) {
            tags.put("mlflow.parentRunId", parentRunId);
        }
        DatabricksContext databricksContext = DatabricksContext.createIfAvailable();
        if (databricksContext != null) {
            tags.putAll(databricksContext.getTags());
        }
        return tags;
    }

    /**
     * Returns the notebook ID when running in a Databricks notebook that exposes one, otherwise
     * {@code "0"}.
     */
    private static String getDefaultExperimentId() {
        DatabricksContext databricksContext = DatabricksContext.createIfAvailable();
        if (databricksContext != null && databricksContext.isInDatabricksNotebook()) {
            String notebookId = databricksContext.getNotebookId();
            if (notebookId != null) {
                return notebookId;
            }
        }
        return "0";
    }
}

