package ai.djl.translate;

import ai.djl.Model;
import ai.djl.util.Pair;
import java.lang.reflect.Type;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/translate/ExpansionTranslatorFactory.class */
/**
 * A {@link TranslatorFactory} that builds one base {@link Translator} (from {@code IbaseT} to
 * {@code ObaseT}) and adapts it to other input/output types through registered expansions.
 *
 * <p>Subclasses supply the base translator and its types, and may override {@link
 * #getExpansions()}, {@link #getPreprocessorExpansions()} and {@link #getPostprocessorExpansions()}
 * to register additional supported type combinations.
 *
 * @param <IbaseT> the input type of the base translator
 * @param <ObaseT> the output type of the base translator
 */
public abstract class ExpansionTranslatorFactory<IbaseT, ObaseT> implements TranslatorFactory {

    /**
     * {@link TranslatorOptions} that apply this factory's expansions to an already-built base
     * translator instead of building a new one from a {@link Model}.
     */
    public final class ExpandedTranslatorOptions implements TranslatorOptions {

        /** The base translator that every option is derived from; assigned once. */
        private final Translator<IbaseT, ObaseT> translator;

        private ExpandedTranslatorOptions(Translator<IbaseT, ObaseT> translator) {
            this.translator = translator;
        }

        /**
         * Returns the input/output type pairs supported by the enclosing factory.
         *
         * @return the result of the enclosing factory's {@code getSupportedTypes()}
         */
        @Override
        public Set<Pair<Type, Type>> getOptions() {
            return ExpansionTranslatorFactory.this.getSupportedTypes();
        }

        /**
         * Derives a translator for the requested types from the wrapped base translator.
         *
         * @param cls the requested input class
         * @param cls2 the requested output class
         * @param <I> the requested input type
         * @param <O> the requested output type
         * @return a translator for {@code cls} to {@code cls2}
         * @throws IllegalArgumentException if no expansion supports the requested types
         */
        @Override
        public <I, O> Translator<I, O> option(Class<I> cls, Class<O> cls2) {
            return ExpansionTranslatorFactory.this.newInstance(cls, cls2, translator);
        }
    }

    /**
     * A function that converts a base translator into a translator for a different input/output
     * pair.
     *
     * @param <IbaseT> the input type of the base translator
     * @param <ObaseT> the output type of the base translator
     */
    @FunctionalInterface
    public interface TranslatorExpansion<IbaseT, ObaseT>
            extends Function<Translator<IbaseT, ObaseT>, Translator<?, ?>> {}

    /**
     * Returns every (input, output) type pair this factory can produce.
     *
     * <p>The result contains each pair registered in {@link #getExpansions()}. It also contains
     * the cross product of two sets. The first is the preprocessor-expansion input types plus the
     * base input type. The second is the postprocessor-expansion output types plus the base output
     * type.
     *
     * @return a new mutable set of supported type pairs
     */
    @Override
    public Set<Pair<Type, Type>> getSupportedTypes() {
        Set<Pair<Type, Type>> supported = new HashSet<>(getExpansions().keySet());

        Set<Type> inputTypes = new HashSet<>(getPreprocessorExpansions().keySet());
        inputTypes.add(getBaseInputType());

        Set<Type> outputTypes = new HashSet<>(getPostprocessorExpansions().keySet());
        outputTypes.add(getBaseOutputType());

        for (Type inputType : inputTypes) {
            for (Type outputType : outputTypes) {
                supported.add(new Pair<>(inputType, outputType));
            }
        }
        return supported;
    }

    /**
     * Builds the base translator from {@code model} and {@code map}, then adapts it to the
     * requested input/output types.
     *
     * @param cls the requested input class
     * @param cls2 the requested output class
     * @param model the model passed to {@link #buildBaseTranslator(Model, Map)}
     * @param map the arguments passed to {@link #buildBaseTranslator(Model, Map)}
     * @param <I> the requested input type
     * @param <O> the requested output type
     * @return a translator for {@code cls} to {@code cls2}
     * @throws IllegalArgumentException if no expansion supports the requested types
     */
    @Override
    public <I, O> Translator<I, O> newInstance(
            Class<I> cls, Class<O> cls2, Model model, Map<String, ?> map) {
        Translator<IbaseT, ObaseT> baseTranslator = buildBaseTranslator(model, map);
        return newInstance(cls, cls2, baseTranslator);
    }

    /**
     * Adapts {@code translator} to the requested input/output types.
     *
     * <p>Resolution order:
     *
     * <ol>
     *   <li>If both requested types equal the base types, the base translator itself is returned.
     *   <li>If {@link #getExpansions()} has an entry for the exact pair, that expansion is applied.
     *   <li>Otherwise a {@code BasicTranslator} is assembled from two parts. The pre-processor is
     *       the base translator when {@code cls} is the base input type, or else the matching
     *       preprocessor expansion. The post-processor is chosen the same way for {@code cls2}.
     *       The result reuses the base translator's batchifier.
     * </ol>
     *
     * @param cls the requested input class
     * @param cls2 the requested output class
     * @param translator the base translator to adapt
     * @param <I> the requested input type
     * @param <O> the requested output type
     * @return a translator for {@code cls} to {@code cls2}
     * @throws IllegalArgumentException if no pre-processor or no post-processor matches
     */
    // Casts are unchecked because the expansion maps are keyed by Type, so the compiler cannot
    // relate the registered functions' result types to I/O; correctness relies on the registered
    // expansions producing the type they are keyed under.
    @SuppressWarnings("unchecked")
    <I, O> Translator<I, O> newInstance(
            Class<I> cls, Class<O> cls2, Translator<IbaseT, ObaseT> translator) {
        if (cls.equals(getBaseInputType()) && cls2.equals(getBaseOutputType())) {
            // I is IbaseT and O is ObaseT here, so the base translator already fits.
            return (Translator<I, O>) translator;
        }

        TranslatorExpansion<IbaseT, ObaseT> expansion =
                getExpansions().get(new Pair<Type, Type>(cls, cls2));
        if (expansion != null) {
            // Whole-translator expansions take precedence over pre/post-processor combinations.
            return (Translator<I, O>) expansion.apply(translator);
        }

        PreProcessor<I> preProcessor = null;
        if (cls.equals(getBaseInputType())) {
            preProcessor = (PreProcessor<I>) translator;
        } else {
            Function<PreProcessor<IbaseT>, PreProcessor<?>> preExpansion =
                    getPreprocessorExpansions().get(cls);
            if (preExpansion != null) {
                preProcessor = (PreProcessor<I>) preExpansion.apply(translator);
            }
        }

        PostProcessor<O> postProcessor = null;
        if (cls2.equals(getBaseOutputType())) {
            postProcessor = (PostProcessor<O>) translator;
        } else {
            Function<PostProcessor<ObaseT>, PostProcessor<?>> postExpansion =
                    getPostprocessorExpansions().get(cls2);
            if (postExpansion != null) {
                postProcessor = (PostProcessor<O>) postExpansion.apply(translator);
            }
        }

        if (preProcessor == null || postProcessor == null) {
            throw new IllegalArgumentException("Unsupported expansion input/output types.");
        }
        return new BasicTranslator<>(preProcessor, postProcessor, translator.getBatchifier());
    }

    /**
     * Returns {@link TranslatorOptions} that apply this factory's expansions to an existing
     * translator.
     *
     * @param translator the base translator to expand
     * @return options backed by {@code translator}
     */
    public ExpansionTranslatorFactory<IbaseT, ObaseT>.ExpandedTranslatorOptions withTranslator(
            Translator<IbaseT, ObaseT> translator) {
        return new ExpandedTranslatorOptions(translator);
    }

    /**
     * Builds the base translator that all expansions are applied to.
     *
     * @param model the model the translator is created for
     * @param map the factory arguments
     * @return the base translator
     */
    protected abstract Translator<IbaseT, ObaseT> buildBaseTranslator(
            Model model, Map<String, ?> map);

    /**
     * Returns the input class of the base translator.
     *
     * @return the base input class
     */
    public abstract Class<IbaseT> getBaseInputType();

    /**
     * Returns the output class of the base translator.
     *
     * @return the base output class
     */
    public abstract Class<ObaseT> getBaseOutputType();

    /**
     * Returns whole-translator expansions keyed by (input, output) type pair.
     *
     * <p>The default implementation registers none.
     *
     * @return a map from type pair to expansion
     */
    protected Map<Pair<Type, Type>, TranslatorExpansion<IbaseT, ObaseT>> getExpansions() {
        return Collections.emptyMap();
    }

    /**
     * Returns pre-processor expansions keyed by the input type they produce a pre-processor for.
     *
     * <p>The default implementation registers only an identity mapping for the base input type.
     *
     * @return a map from input type to pre-processor expansion
     */
    protected Map<Type, Function<PreProcessor<IbaseT>, PreProcessor<?>>> getPreprocessorExpansions() {
        return Collections.singletonMap(getBaseInputType(), preProcessor -> preProcessor);
    }

    /**
     * Returns post-processor expansions keyed by the output type they produce a post-processor for.
     *
     * <p>The default implementation registers only an identity mapping for the base output type.
     *
     * @return a map from output type to post-processor expansion
     */
    protected Map<Type, Function<PostProcessor<ObaseT>, PostProcessor<?>>> getPostprocessorExpansions() {
        return Collections.singletonMap(getBaseOutputType(), postProcessor -> postProcessor);
    }
}
