/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.vespa.model.container.component;

import com.yahoo.config.ModelReference;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.embedding.ColBertEmbedderConfig;
import com.yahoo.text.XML;
import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import com.yahoo.vespa.model.container.component.Model;
import com.yahoo.vespa.model.container.component.OnnxEmbedder;
import java.util.Set;
import org.w3c.dom.Element;

/**
 * Config model component for the ColBERT embedder. In the container it is instantiated as
 * {@code ai.vespa.embedding.ColBertEmbedder} from the {@code model-integration} bundle.
 *
 * <p>All settings are read from the component's XML element when this object is constructed.
 * The transformer model and the tokenizer reference are always resolved. Every other setting is
 * optional: when its child element is absent, the field is {@code null} and is not written to the
 * config builder in {@link #getConfig}.
 */
public class ColBertEmbedder extends OnnxEmbedder implements ColBertEmbedderConfig.Producer {

    private final ModelReference modelRef;
    private final ModelReference vocabRef;
    // Optional settings below: null means "not specified in XML".
    private final Integer maxQueryTokens;
    private final Integer maxDocumentTokens;
    private final Integer transformerStartSequenceToken;
    private final Integer transformerEndSequenceToken;
    private final Integer transformerMaskToken;
    private final Integer transformerPadToken;
    private final Integer maxTokens;
    private final String transformerInputIds;
    private final String transformerAttentionMask;
    private final Integer queryTokenId;
    private final Integer documentTokenId;
    private final String transformerOutput;

    /**
     * Creates the embedder component from its XML definition.
     *
     * @param cluster the container cluster that the ONNX model cost is registered with
     * @param xml the embedder's XML element
     * @param state the deploy state passed to the model-resolution helpers
     * @throws java.util.NoSuchElementException if {@code Model.fromXml} yields no
     *     {@code transformer-model}
     * @throws NumberFormatException if an integer-valued child element is not a valid integer
     */
    public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
        super("ai.vespa.embedding.ColBertEmbedder", "model-integration", xml, state);
        Model model = Model.fromXml(state, xml, "transformer-model", Set.of("onnx-model")).orElseThrow();
        this.modelRef = model.modelReference();
        // Judging by the helper's name, the tokenizer may be given explicitly or derived from the
        // ONNX model itself.
        this.vocabRef = Model.fromXmlOrImplicitlyFromOnnxModel(
                state, xml, model, "tokenizer-model", Set.of("huggingface-tokenizer")).modelReference();
        // The reads are kept in their original order, so the first malformed value still
        // determines which exception is thrown.
        this.maxTokens = intChildOrNull(xml, "max-tokens");
        this.maxQueryTokens = intChildOrNull(xml, "max-query-tokens");
        this.maxDocumentTokens = intChildOrNull(xml, "max-document-tokens");
        this.transformerStartSequenceToken = intChildOrNull(xml, "transformer-start-sequence-token");
        this.transformerEndSequenceToken = intChildOrNull(xml, "transformer-end-sequence-token");
        this.transformerMaskToken = intChildOrNull(xml, "transformer-mask-token");
        this.transformerPadToken = intChildOrNull(xml, "transformer-pad-token");
        this.queryTokenId = intChildOrNull(xml, "query-token-id");
        this.documentTokenId = intChildOrNull(xml, "document-token-id");
        this.transformerInputIds = childOrNull(xml, "transformer-input-ids");
        this.transformerAttentionMask = childOrNull(xml, "transformer-attention-mask");
        this.transformerOutput = childOrNull(xml, "transformer-output");
        model.registerOnnxModelCost(cluster, this.onnxModelOptions);
    }

    /**
     * Returns the named child element's value parsed with {@link Integer#parseInt}, or
     * {@code null} when {@code XML.getChildValue} yields no value.
     *
     * @throws NumberFormatException if the value is not a valid integer
     */
    private static Integer intChildOrNull(Element xml, String name) {
        return XML.getChildValue(xml, name).map(Integer::parseInt).orElse(null);
    }

    /**
     * Returns the named child element's text value, or {@code null} when
     * {@code XML.getChildValue} yields no value.
     */
    private static String childOrNull(Element xml, String name) {
        return XML.getChildValue(xml, name).orElse(null);
    }

    /**
     * Writes this embedder's settings to the given config builder. The transformer model and the
     * tokenizer path are always set. Every other setting is set only when it was specified in XML.
     *
     * @param b the builder to populate
     */
    public void getConfig(ColBertEmbedderConfig.Builder b) {
        b.transformerModel(modelRef).tokenizerPath(vocabRef);
        if (maxTokens != null) {
            b.transformerMaxTokens(maxTokens);
        }
        if (transformerInputIds != null) {
            b.transformerInputIds(transformerInputIds);
        }
        if (transformerAttentionMask != null) {
            b.transformerAttentionMask(transformerAttentionMask);
        }
        if (transformerOutput != null) {
            b.transformerOutput(transformerOutput);
        }
        if (maxQueryTokens != null) {
            b.maxQueryTokens(maxQueryTokens);
        }
        if (maxDocumentTokens != null) {
            b.maxDocumentTokens(maxDocumentTokens);
        }
        if (transformerStartSequenceToken != null) {
            b.transformerStartSequenceToken(transformerStartSequenceToken);
        }
        if (transformerEndSequenceToken != null) {
            b.transformerEndSequenceToken(transformerEndSequenceToken);
        }
        if (transformerMaskToken != null) {
            b.transformerMaskToken(transformerMaskToken);
        }
        if (transformerPadToken != null) {
            b.transformerPadToken(transformerPadToken);
        }
        if (queryTokenId != null) {
            b.queryTokenId(queryTokenId);
        }
        if (documentTokenId != null) {
            b.documentTokenId(documentTokenId);
        }
    }
}

