/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.repository.zoo;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.modality.Input;
import ai.djl.modality.Output;
import ai.djl.ndarray.NDList;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.Resource;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelLoader;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoopTranslator;
import ai.djl.translate.ServingTranslatorFactory;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.lang.reflect.Type;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;

/**
 * Shared base for {@link ModelLoader} implementations that load artifacts described by an
 * {@link MRL} from a {@link Repository}.
 *
 * <p>Subclasses may register additional {@link TranslatorFactory} instances in {@link #factories}
 * and may override {@link #createModel} to customize how the {@link Model} is instantiated.
 *
 * @param <I> the loader's input type
 * @param <O> the loader's output type
 */
public abstract class BaseModelLoader<I, O>
implements ModelLoader<I, O> {

    /**
     * Default translator factories keyed by (input class, output class). This map is consulted
     * when the {@link Criteria} does not specify its own factory.
     */
    protected Map<Pair<Type, Type>, TranslatorFactory<?, ?>> factories;

    /** The model zoo this loader belongs to; may be {@code null}. */
    protected ModelZoo modelZoo;

    /** The repository resource (repository + MRL + version) this loader reads from. */
    protected Resource resource;

    /**
     * Creates a loader for the given repository location.
     *
     * <p>Two default translator factories are registered: a no-op translator for
     * {@code NDList -> NDList} and a {@link ServingTranslatorFactory} for {@code Input -> Output}.
     *
     * @param repository the repository that hosts the artifacts
     * @param mrl the resource locator within the repository
     * @param version the version to restrict to, or {@code null} for all versions
     * @param modelZoo the owning model zoo, or {@code null}
     */
    protected BaseModelLoader(Repository repository, MRL mrl, String version, ModelZoo modelZoo) {
        this.resource = new Resource(repository, mrl, version);
        this.modelZoo = modelZoo;
        this.factories = new ConcurrentHashMap<>();
        // The explicit cast gives the lambda a concrete target type. With the wildcard
        // TranslatorFactory<?, ?> target, the lambda would have to return Translator<Object, Object>.
        factories.put(
                new Pair<>(NDList.class, NDList.class),
                (TranslatorFactory<NDList, NDList>) (m, c) -> new NoopTranslator());
        factories.put(new Pair<>(Input.class, Output.class), new ServingTranslatorFactory());
    }

    /** {@inheritDoc} */
    @Override
    public String getArtifactId() {
        return resource.getMrl().getArtifactId();
    }

    /**
     * Loads the artifact that matches the criteria's filters and wraps it in a {@link ZooModel}.
     *
     * <p>If anything fails after the {@link Model} has been created, the model is closed before
     * the exception propagates, so its resources are not leaked.
     *
     * @param criteria the selection and loading options
     * @param <S> the model input type
     * @param <T> the model output type
     * @return the loaded model together with its translator
     * @throws IOException if preparing or reading the artifact fails
     * @throws ModelNotFoundException if no artifact, engine or translator matches
     * @throws MalformedModelException if the model files cannot be loaded
     */
    @Override
    public <S, T> ZooModel<S, T> loadModel(Criteria<S, T> criteria)
            throws IOException, ModelNotFoundException, MalformedModelException {
        Artifact artifact = resource.match(criteria.getFilters());
        if (artifact == null) {
            throw new ModelNotFoundException("Model not found.");
        }
        Map<String, Object> override = criteria.getArguments();
        Progress progress = criteria.getProgress();
        Map<String, Object> arguments = artifact.getArguments(override);
        try {
            TranslatorFactory<S, T> factory = criteria.getTranslatorFactory();
            if (factory == null) {
                factory = getTranslatorFactory(criteria);
                if (factory == null) {
                    throw new ModelNotFoundException("No matching default translator found.");
                }
            }
            resource.prepare(artifact, progress);
            if (progress != null) {
                progress.reset("Loading", 2L);
                progress.update(1L);
            }
            Path modelPath = resource.getRepository().getResourceDirectory(artifact);
            String engine = resolveEngine(criteria.getEngine());
            String modelName = criteria.getModelName();
            if (modelName == null) {
                modelName = artifact.getName();
            }
            Model model = createModel(modelName, criteria.getDevice(), artifact, arguments, engine);
            // Ownership passes to the ZooModel only on success. On any failure the model
            // must be closed here, because no caller ever receives a reference to it.
            boolean handedOff = false;
            try {
                if (criteria.getBlock() != null) {
                    model.setBlock(criteria.getBlock());
                }
                model.load(modelPath, null, criteria.getOptions());
                Application application = criteria.getApplication();
                if (application != null && application != Application.UNDEFINED) {
                    arguments.put("application", application.getPath());
                }
                Translator<S, T> translator = factory.newInstance(model, arguments);
                ZooModel<S, T> zooModel = new ZooModel<>(model, translator);
                handedOff = true;
                return zooModel;
            } finally {
                if (!handedOff) {
                    model.close();
                }
            }
        } catch (TranslateException e) {
            throw new ModelNotFoundException("No matching translator found.", e);
        } finally {
            if (progress != null) {
                progress.end();
            }
        }
    }

    /** {@inheritDoc} */
    @Override
    public List<Artifact> listModels() throws IOException {
        List<Artifact> list = resource.listArtifacts();
        String version = resource.getVersion();
        return list.stream()
                .filter(a -> version == null || version.equals(a.getVersion()))
                .collect(Collectors.toList());
    }

    /**
     * Creates the (not yet loaded) {@link Model} instance. Subclasses may override this hook.
     *
     * @param name the model name
     * @param device the device to load on, possibly {@code null}
     * @param artifact the matched artifact
     * @param arguments the merged artifact/criteria arguments
     * @param engine the engine name, possibly {@code null}
     * @return a new model instance
     * @throws IOException if the model cannot be created
     */
    protected Model createModel(String name, Device device, Artifact artifact, Map<String, Object> arguments, String engine) throws IOException {
        return Model.newInstance(name, device, engine);
    }

    /**
     * Returns {@code repository:groupId:artifactId} followed by the artifacts of this loader,
     * one per line.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        MRL mrl = resource.getMrl();
        sb.append(resource.getRepository().getName())
                .append(':')
                .append(mrl.getGroupId())
                .append(':')
                .append(mrl.getArtifactId())
                .append(" [\n");
        try {
            for (Artifact artifact : listModels()) {
                sb.append('\t').append(artifact).append('\n');
            }
        } catch (IOException e) {
            sb.append("\tFailed load metadata.");
        }
        sb.append("\n]");
        return sb.toString();
    }

    /**
     * Resolves which engine to load the model with.
     *
     * <p>An explicitly requested engine is used as-is. If none is requested and a model zoo is
     * attached, the zoo's supported engines are scanned. The current default engine wins if it
     * is supported; otherwise the last supported engine that is available is chosen.
     *
     * @param requested the engine requested by the criteria, or {@code null}
     * @return the engine name, or {@code null} when none was requested and no zoo is attached
     * @throws ModelNotFoundException if no usable engine is found or the chosen one is missing
     */
    private String resolveEngine(String requested) throws ModelNotFoundException {
        String engine = requested;
        if (engine == null && modelZoo != null) {
            String defaultEngine = Engine.getInstance().getEngineName();
            for (String supportedEngine : modelZoo.getSupportedEngines()) {
                if (supportedEngine.equals(defaultEngine)) {
                    engine = supportedEngine;
                    break;
                }
                if (Engine.hasEngine(supportedEngine)) {
                    engine = supportedEngine;
                }
            }
            if (engine == null) {
                throw new ModelNotFoundException("No supported engine available for model zoo: " + modelZoo.getGroupId());
            }
        }
        if (engine != null && !Engine.hasEngine(engine)) {
            throw new ModelNotFoundException(engine + " is not supported.");
        }
        return engine;
    }

    /**
     * Looks up a registered default factory for the criteria's input/output classes.
     *
     * @return the matching factory, or {@code null} if none is registered
     */
    @SuppressWarnings("unchecked") // the map key (input class, output class) fixes the factory's types
    private <S, T> TranslatorFactory<S, T> getTranslatorFactory(Criteria<S, T> criteria) {
        return (TranslatorFactory<S, T>)
                factories.get(new Pair<>(criteria.getInputClass(), criteria.getOutputClass()));
    }
}

