/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.model;

import com.yahoo.config.FileReference;
import com.yahoo.config.ModelReference;
import com.yahoo.config.UrlReference;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.config.application.api.DeployLogger;
import com.yahoo.config.model.api.OnnxModelCost;
import com.yahoo.vespa.model.ml.OnnxModelProbe;
import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Level;

/**
 * Default {@link OnnxModelCost} implementation. It estimates the memory footprint of ONNX models
 * from probed memory statistics when those are available, and otherwise from the model size.
 */
public class DefaultOnnxModelCost implements OnnxModelCost {

    /**
     * Creates a new calculator for the given application package.
     *
     * @param appPkg the application package used to probe model files; may be {@code null}, in which
     *               case models registered as application files get the fallback size estimate
     * @param logger receives diagnostic messages about the estimates
     * @return a calculator with no models registered
     */
    public OnnxModelCost.Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) {
        return new CalculatorImpl(appPkg, logger);
    }

    /**
     * Accumulates one memory cost estimate per model source. A source that has already been
     * registered is not analyzed again.
     */
    private static class CalculatorImpl implements OnnxModelCost.Calculator {

        /** Model size assumed when the real size is unknown or non-positive: 1 GiB. */
        private static final long FALLBACK_MODEL_SIZE_BYTES = 1L << 30;
        /** Lower bound of a size-based estimate: 300 MiB. */
        private static final long MIN_SIZE_BASED_COST_BYTES = 300L * 1024 * 1024;
        /** Fixed overhead added to a size-based estimate: 100 MiB. */
        private static final long SIZE_BASED_OVERHEAD_BYTES = 100L * 1024 * 1024;
        /** Factor applied to the model size in a size-based estimate. */
        private static final double SIZE_FACTOR = 1.4;
        /** Factor applied to the probed virtual memory size in a stats-based estimate. */
        private static final double VM_SIZE_FACTOR = 1.1;
        /** Connect and request timeout used when asking a model URL for its size. */
        private static final Duration HTTP_TIMEOUT = Duration.ofSeconds(3);

        private final DeployLogger log;
        private final ApplicationPackage appPkg;
        /** Estimated cost in bytes, keyed by model source (relative path, file name or URL). */
        private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>();

        private CalculatorImpl(ApplicationPackage appPkg, DeployLogger log) {
            this.appPkg = appPkg;
            this.log = log;
        }

        /** Returns the sum of the estimated costs of all registered models, in bytes. */
        public long aggregatedModelCostInBytes() {
            return modelCost.values().stream().mapToLong(Long::longValue).sum();
        }

        /**
         * Registers a model file from the application package, unless its relative path is already
         * registered. If the file exists and an application package is available, the model is probed
         * for memory stats. The estimate is based on those stats when probing yields a result, and on
         * the file size otherwise. If the file is missing or there is no application package, the
         * fallback model size is used.
         */
        public void registerModel(ApplicationFile f) {
            String path = f.getPath().getRelative();
            if (alreadyAnalyzed(path)) {
                return;
            }
            log.log(Level.FINE, () -> "Register model '%s'".formatted(path));
            if (!f.exists() || appPkg == null) {
                deductJvmHeapSizeWithModelCost(0L, path); // unknown size -> fallback estimate
                return;
            }
            OnnxModelProbe.MemoryStats memoryStats = OnnxModelProbe.probeMemoryStats(appPkg, f.getPath()).orElse(null);
            if (memoryStats == null) {
                deductJvmHeapSizeWithModelCost(f.getSize(), path);
            } else {
                log.log(Level.FINE, () -> "Register model '%s' with memory stats: %s".formatted(path, memoryStats));
                deductJvmHeapSizeWithModelCost(f.getSize(), memoryStats, path);
            }
        }

        /**
         * Registers a model given by reference.
         * <ul>
         *   <li>A path reference is keyed by its file name. It is estimated from the file size, or
         *       from the fallback size if the file does not exist.</li>
         *   <li>A URL reference is estimated from the size reported by the server; only HTTP(S) is
         *       supported.</li>
         * </ul>
         *
         * @throws IllegalStateException if the reference has neither a path nor a URL
         */
        public void registerModel(ModelReference ref) {
            log.log(Level.FINE, () -> "Register model '%s'".formatted(ref.toString()));
            if (ref.path().isPresent()) {
                Path path = Paths.get(ref.path().get().value());
                String source = path.getFileName().toString();
                if (alreadyAnalyzed(source)) {
                    return;
                }
                long size = Exceptions.uncheck(() -> Files.exists(path) ? Files.size(path) : 0L);
                deductJvmHeapSizeWithModelCost(size, source);
            } else if (ref.url().isPresent()) {
                deductJvmHeapSizeWithModelCost(URI.create(ref.url().get().value()));
            } else {
                throw new IllegalStateException(ref.toString());
            }
        }

        /**
         * Registers the model at the given URL, unless it is already registered. The estimate is
         * size-based and uses the Content-Length of a HEAD response (redirects are followed).
         * <ul>
         *   <li>Only http and https URLs are handled; other URLs, including those without a scheme,
         *       are ignored.</li>
         *   <li>A non-2xx response or a missing Content-Length leaves the size unknown, so the
         *       fallback size is used.</li>
         *   <li>If the request fails, the model is not registered and the failure is logged.</li>
         * </ul>
         */
        private void deductJvmHeapSizeWithModelCost(URI uri) {
            if (alreadyAnalyzed(uri.toString())) {
                return;
            }
            String scheme = uri.getScheme(); // null for URIs without a scheme
            if (!"http".equalsIgnoreCase(scheme) && !"https".equalsIgnoreCase(scheme)) {
                return;
            }
            try {
                HttpClient httpClient = HttpClient.newBuilder()
                        .connectTimeout(HTTP_TIMEOUT)
                        .followRedirects(HttpClient.Redirect.NORMAL) // default is NEVER
                        .build();
                HttpRequest request = HttpRequest.newBuilder(uri)
                        .timeout(HTTP_TIMEOUT)
                        .method("HEAD", HttpRequest.BodyPublishers.noBody())
                        .build();
                HttpResponse<Void> response = httpClient.send(request, HttpResponse.BodyHandlers.discarding());
                int status = response.statusCode();
                if (status / 100 != 2) {
                    // The Content-Length of an error/redirect response says nothing about the model size
                    log.log(Level.FINE, () -> "Got HTTP status %d for '%s', model size unknown".formatted(status, uri));
                    deductJvmHeapSizeWithModelCost(0L, uri.toString());
                    return;
                }
                String contentLength = response.headers().firstValue("Content-Length").orElse("0");
                log.log(Level.FINE, () -> "Got content length '%s' for '%s'".formatted(contentLength, uri));
                deductJvmHeapSizeWithModelCost(Long.parseLong(contentLength), uri.toString());
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // restore the interrupt flag for callers up the stack
                logSizeLookupFailure(uri, e);
            } catch (IOException | IllegalArgumentException e) {
                // IllegalArgumentException also covers a malformed Content-Length (NumberFormatException)
                logSizeLookupFailure(uri, e);
            }
        }

        /** Logs that the size of the model at {@code uri} could not be determined. */
        private void logSizeLookupFailure(URI uri, Exception e) {
            log.log(Level.INFO, () -> "Failed to get model size for '%s': %s".formatted(uri, e.getMessage()), e);
        }

        /**
         * Records a size-based estimate for {@code source}, computed as
         * {@code max(300 MiB, 1.4 * size + 100 MiB)}. A non-positive {@code size} is replaced by the
         * 1 GiB fallback.
         */
        private void deductJvmHeapSizeWithModelCost(long size, String source) {
            long effectiveSize = size > 0L ? size : FALLBACK_MODEL_SIZE_BYTES;
            long estimatedCost = Math.max(MIN_SIZE_BASED_COST_BYTES,
                                          (long) (SIZE_FACTOR * effectiveSize + SIZE_BASED_OVERHEAD_BYTES));
            log.log(Level.FINE, () -> "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
            modelCost.put(source, estimatedCost);
        }

        /**
         * Records a stats-based estimate for {@code source}, computed as 1.1 times the probed virtual
         * memory size. {@code size} is only used in the log message.
         */
        private void deductJvmHeapSizeWithModelCost(long size, OnnxModelProbe.MemoryStats stats, String source) {
            long estimatedCost = (long) (VM_SIZE_FACTOR * stats.vmSize());
            log.log(Level.FINE, () -> "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source));
            modelCost.put(source, estimatedCost);
        }

        /** Returns whether an estimate has already been recorded for {@code source}. */
        private boolean alreadyAnalyzed(String source) {
            return modelCost.containsKey(source);
        }

        /** Formats a byte count as whole MiB (truncated), e.g. {@code "300MB"}. */
        private static String mb(long bytes) {
            return "%dMB".formatted(bytes / (1024L * 1024L));
        }
    }
}

