/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.model.application.validation.change;

import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.vespa.model.application.validation.Validation;
import com.yahoo.vespa.model.application.validation.change.ChangeValidator;
import com.yahoo.vespa.model.application.validation.change.VespaRestartAction;
import com.yahoo.vespa.model.container.ApplicationContainer;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Change validator that requires a restart of a container cluster when its set of
 * Onnx models, or any individual Onnx model, differs between the previous and the
 * next application model and the nodes do not have enough memory to hold both the
 * current and the next set of models at the same time.
 *
 * <p>The validator is a no-op unless the
 * {@code restartOnDeployWhenOnnxModelChanges} feature flag is enabled.
 */
public class RestartOnDeployForOnnxModelChangesValidator implements ChangeValidator {

    private static final Logger log = Logger.getLogger(RestartOnDeployForOnnxModelChangesValidator.class.getName());

    /**
     * Memory (in Gb) subtracted from a node's total memory before model cost is accounted for.
     * NOTE(review): presumably the fixed host/node memory overhead - confirm against the
     * constant this literal was inlined from.
     */
    private static final double RESERVED_MEMORY_GB = 0.7;

    /** A restart is required when the computed available memory percentage is below this value. */
    private static final int MIN_AVAILABLE_MEMORY_PERCENTAGE = 15;

    /** A restart is required when the computed available memory (in Gb) is below this value. */
    private static final double MIN_AVAILABLE_MEMORY_GB = 0.6;

    private static final double BYTES_PER_GB = 1024.0 * 1024.0 * 1024.0;

    /**
     * Compares every container cluster in the next model with the cluster of the same name
     * in the previous model and registers restart actions on the context where needed.
     *
     * <p>Clusters that have no counterpart in the previous model, and clusters that have
     * enough memory to avoid a restart, are skipped.
     *
     * @param context the change context giving access to the previous model, the next model
     *                and the deploy state
     */
    @Override
    public void validate(Validation.ChangeContext context) {
        if (!context.deployState().featureFlags().restartOnDeployWhenOnnxModelChanges()) {
            return;
        }
        for (ApplicationContainerCluster cluster : context.model().getContainerClusters().values()) {
            ApplicationContainerCluster clusterInCurrentModel =
                    context.previousModel().getContainerClusters().get(cluster.getName());
            if (clusterInCurrentModel == null) {
                continue; // new cluster, nothing to compare against
            }

            Map<String, OnnxModelCost.ModelInfo> currentModels = clusterInCurrentModel.onnxModelCostCalculator().models();
            Map<String, OnnxModelCost.ModelInfo> nextModels = cluster.onnxModelCostCalculator().models();

            if (enoughMemoryToAvoidRestart(clusterInCurrentModel, cluster, context.deployState().getDeployLogger())) {
                continue;
            }

            // Supplier overload: the message is only built when FINE logging is enabled
            log.log(Level.FINE, () -> "Validating %s, current Onnx models:%s, next Onnx models:%s"
                    .formatted(cluster, currentModels, nextModels));
            validateModelChanges(cluster, currentModels, nextModels).forEach(context::require);
            validateSetOfModels(cluster, currentModels, nextModels).forEach(context::require);
        }
    }

    /**
     * Returns one restart action for every model that exists in both maps and whose
     * estimated cost, hash or options differ.
     *
     * @param cluster       the cluster in the next model, marked for restart on any change
     * @param currentModels models of the cluster in the previous model
     * @param nextModels    models of the cluster in the next model
     * @return the restart actions to require, empty if no common model has changed
     */
    private List<ConfigChangeAction> validateModelChanges(ApplicationContainerCluster cluster,
                                                          Map<String, OnnxModelCost.ModelInfo> currentModels,
                                                          Map<String, OnnxModelCost.ModelInfo> nextModels) {
        List<ConfigChangeAction> actions = new ArrayList<>();
        for (OnnxModelCost.ModelInfo nextModelInfo : nextModels.values()) {
            if (!currentModels.containsKey(nextModelInfo.modelId())) {
                continue; // added models are handled by validateSetOfModels
            }
            modelChanged(nextModelInfo, currentModels.get(nextModelInfo.modelId())).ifPresent(change -> {
                String message = "Onnx model '%s' has changed (%s), need to restart services in %s"
                        .formatted(nextModelInfo.modelId(), change, cluster);
                setRestartOnDeployAndAddRestartAction(actions, cluster, message);
            });
        }
        return actions;
    }

    /**
     * Returns a restart action if the set of model ids differs between the two maps.
     *
     * @param cluster       the cluster in the next model, marked for restart on any change
     * @param currentModels models of the cluster in the previous model
     * @param nextModels    models of the cluster in the next model
     * @return a list with a single restart action, or an empty list if the id sets are equal
     */
    private List<ConfigChangeAction> validateSetOfModels(ApplicationContainerCluster cluster,
                                                         Map<String, OnnxModelCost.ModelInfo> currentModels,
                                                         Map<String, OnnxModelCost.ModelInfo> nextModels) {
        List<ConfigChangeAction> actions = new ArrayList<>();
        Set<String> currentModelIds = currentModels.keySet();
        Set<String> nextModelIds = nextModels.keySet();
        log.log(Level.FINE, () -> "Checking if Onnx model set has changed (%s) -> (%s)"
                .formatted(currentModelIds, nextModelIds));
        if (!currentModelIds.equals(nextModelIds)) {
            String message = "Onnx model set has changed from %s to %s, need to restart services in %s"
                    .formatted(currentModelIds, nextModelIds, cluster);
            setRestartOnDeployAndAddRestartAction(actions, cluster, message);
        }
        return actions;
    }

    /**
     * Compares two model infos and describes the first difference found, checking
     * estimated cost, then hash, then model options.
     *
     * @return a short description of what changed, or empty if nothing differs
     */
    private Optional<String> modelChanged(OnnxModelCost.ModelInfo a, OnnxModelCost.ModelInfo b) {
        log.log(Level.FINE, () -> "Checking if model has changed (%s) -> (%s)".formatted(a, b));
        if (a.estimatedCost() != b.estimatedCost()) {
            return Optional.of("estimated cost");
        }
        if (a.hash() != b.hash()) {
            return Optional.of("model hash");
        }
        if (!a.onnxModelOptions().equals(b.onnxModelOptions())) {
            return Optional.of("model option(s)");
        }
        return Optional.empty();
    }

    /**
     * Logs the message, flags the cluster's Onnx model cost calculator for restart on deploy,
     * stores the calculator state and appends a restart action for the cluster to {@code actions}.
     */
    private static void setRestartOnDeployAndAddRestartAction(List<ConfigChangeAction> actions,
                                                              ApplicationContainerCluster cluster,
                                                              String message) {
        log.log(Level.INFO, message);
        cluster.onnxModelCostCalculator().setRestartOnDeploy();
        cluster.onnxModelCostCalculator().store();
        actions.add(new VespaRestartAction(cluster.id(), message));
    }

    /**
     * Estimates whether a node in the cluster has enough memory to hold the models of both
     * the previous and the next model at the same time, in which case no restart is needed.
     *
     * <p>The estimate uses the memory of the first container in the cluster.
     *
     * @param clusterInCurrentModel the cluster as it is in the previous model
     * @param cluster               the cluster as it is in the next model
     * @param deployLogger          receives an explanation when memory is too low
     * @return true if a restart can be avoided, false otherwise
     */
    private static boolean enoughMemoryToAvoidRestart(ApplicationContainerCluster clusterInCurrentModel,
                                                      ApplicationContainerCluster cluster,
                                                      DeployLogger deployLogger) {
        List<ApplicationContainer> containers = cluster.getContainers();
        if (containers.isEmpty()) {
            // No container to read memory resources from (and nothing to restart);
            // avoids an IndexOutOfBoundsException from get(0) below.
            return true;
        }

        double currentModelCostInGb = onnxModelCostInGb(clusterInCurrentModel);
        double nextModelCostInGb = onnxModelCostInGb(cluster);
        double totalMemory = containers.get(0).getHostResource().realResources().memoryGb();
        double memoryUsedByModels = currentModelCostInGb + nextModelCostInGb;
        double availableMemory = Math.max(0.0, totalMemory - RESERVED_MEMORY_GB - memoryUsedByModels);
        int availableMemoryPercentage = cluster.availableMemoryPercentage();
        // Fraction of total memory still free, scaled by the cluster's configured percentage
        int memoryPercentage = (int) (availableMemory / totalMemory * availableMemoryPercentage);

        String prefix = "Validating Onnx models memory usage for %s".formatted(cluster);
        if (memoryPercentage < MIN_AVAILABLE_MEMORY_PERCENTAGE) {
            deployLogger.log(Level.INFO, "%s, percentage of available memory too low (%d < %d) to avoid restart, consider a flavor with more memory to avoid this".formatted(prefix, memoryPercentage, MIN_AVAILABLE_MEMORY_PERCENTAGE));
            return false;
        }
        if (availableMemory < MIN_AVAILABLE_MEMORY_GB) {
            deployLogger.log(Level.INFO, "%s, available memory too low (%.2f Gb < %.2f Gb) to avoid restart, consider a flavor with more memory to avoid this".formatted(prefix, availableMemory, MIN_AVAILABLE_MEMORY_GB));
            return false;
        }
        log.log(Level.FINE, () -> "%s, enough available memory (%.2f Gb) to avoid restart (models use %.2f Gb)"
                .formatted(prefix, availableMemory, memoryUsedByModels));
        return true;
    }

    /** Returns the aggregated Onnx model cost of the given cluster, converted from bytes to Gb. */
    private static double onnxModelCostInGb(ApplicationContainerCluster cluster) {
        return cluster.onnxModelCostCalculator().aggregatedModelCostInBytes() / BYTES_PER_GB;
    }
}

