Class OrtGptTranslator

java.lang.Object
ai.djl.onnxruntime.zoo.nlp.textgeneration.OrtGptTranslator
All Implemented Interfaces:
ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>, ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>, ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>, ai.djl.translate.Translator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>

public class OrtGptTranslator extends Object implements ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>
The Translator for the ONNX Runtime GPT-2 model.
  • Constructor Summary

    Constructors
    Constructor
    Description
    OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers)
    Constructs a new instance of OrtGptTranslator.
  • Method Summary

    Modifier and Type
    Method
    Description
    ai.djl.ndarray.NDList
    processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input)
    ai.djl.modality.nlp.generate.CausalLMOutput
    processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output)

    Methods inherited from class java.lang.Object

    clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait

    Methods inherited from interface ai.djl.translate.NoBatchifyTranslator

    getBatchifier

    Methods inherited from interface ai.djl.translate.Translator

    batchProcessInput, batchProcessOutput, getExpansions, prepare
  • Constructor Details

    • OrtGptTranslator

      public OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers)
      Constructs a new instance of OrtGptTranslator.
      Parameters:
      kvDim - the key/value (KV) dimension
      numAttentionHeads - the number of attention heads
      numLayers - the number of layers
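The three constructor arguments describe the model's key/value cache. The sketch below is a hypothetical helper, not part of DJL, and it assumes the common Hugging Face GPT-2 layout: kvDim is the per-head size (hidden size divided by the number of heads), each cached key or value tensor has shape [batch, heads, pastSeqLen, kvDim], and each layer contributes one key and one value tensor. The class and method names are illustrative only.

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not DJL API) showing how kvDim, numAttentionHeads and
 * numLayers relate to a GPT-2 style KV cache, under the layout assumed above.
 */
public class KvCacheShape {

    /** Per-head dimension, assuming the hidden size is split evenly across heads. */
    public static long kvDim(long hiddenSize, int numAttentionHeads) {
        if (numAttentionHeads <= 0 || hiddenSize % numAttentionHeads != 0) {
            throw new IllegalArgumentException(
                    "hiddenSize must be divisible by a positive numAttentionHeads");
        }
        return hiddenSize / numAttentionHeads;
    }

    /** Assumed shape of one cached key or value tensor: [batch, heads, pastSeqLen, kvDim]. */
    public static long[] tensorShape(
            long batchSize, int numAttentionHeads, long pastSeqLen, long kvDim) {
        return new long[] {batchSize, numAttentionHeads, pastSeqLen, kvDim};
    }

    /** One key tensor plus one value tensor for every transformer layer. */
    public static int tensorCount(int numLayers) {
        return 2 * numLayers;
    }

    public static void main(String[] args) {
        // GPT-2 small: hidden size 768, 12 attention heads, 12 layers.
        long kvDim = kvDim(768, 12);
        System.out.println("kvDim = " + kvDim); // prints "kvDim = 64"
        // Cache shape after 5 tokens have been processed for a batch of 1.
        System.out.println(Arrays.toString(tensorShape(1, 12, 5, kvDim))); // prints "[1, 12, 5, 64]"
        System.out.println("tensors = " + tensorCount(12)); // prints "tensors = 24"
    }
}
```

Under these assumptions, a GPT-2 small model would be paired with arguments along the lines of new OrtGptTranslator(64, 12, 12); check the actual values against the exported model's configuration.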
  • Method Details

    • processInput

      public ai.djl.ndarray.NDList processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input) throws Exception
      Specified by:
      processInput in interface ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>
      Throws:
      Exception
    • processOutput

      public ai.djl.modality.nlp.generate.CausalLMOutput processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output) throws Exception
      Specified by:
      processOutput in interface ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>
      Throws:
      Exception
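A CausalLMOutput returned by processOutput carries the model's logits, which a caller typically turns into the next token. The sketch below illustrates greedy decoding in plain Java. It assumes the logits for one sequence have already been copied into a hypothetical float[seqLen][vocabSize] array. It does not use DJL types, and the names are illustrative only. DJL's own text-generation utilities may select tokens differently (for example, with sampling or beam search).

```java
/**
 * Hypothetical plain-Java illustration (not DJL code) of greedy next-token
 * selection: take the highest-scoring vocabulary index at the last position.
 */
public class GreedyNextToken {

    /** logits is [seqLen][vocabSize]; returns the argmax of the last row. */
    public static int nextToken(float[][] logits) {
        if (logits.length == 0 || logits[logits.length - 1].length == 0) {
            throw new IllegalArgumentException("logits must be non-empty");
        }
        float[] last = logits[logits.length - 1];
        int best = 0;
        for (int i = 1; i < last.length; i++) {
            if (last[i] > last[best]) {
                best = i; // keep the first index on ties
            }
        }
        return best;
    }

    public static void main(String[] args) {
        float[][] logits = {
            {0.1f, 2.0f, -1.0f, 0.5f}, // position 0 (ignored for next-token choice)
            {0.3f, -0.2f, 1.7f, 1.1f}  // last position
        };
        System.out.println(nextToken(logits)); // prints 2
    }
}
```

Only the last row matters for the next token because, in a causal LM, position t's logits predict token t + 1.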