Class OrtGptTranslator
java.lang.Object
ai.djl.onnxruntime.zoo.nlp.textgeneration.OrtGptTranslator
- All Implemented Interfaces:
ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>, ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>, ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>, ai.djl.translate.Translator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>
public class OrtGptTranslator
extends Object
implements ai.djl.translate.NoBatchifyTranslator<ai.djl.ndarray.NDList,ai.djl.modality.nlp.generate.CausalLMOutput>
The Translator for the ONNX Runtime GPT2 model.
Constructor Summary
Constructors
Constructor: OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers)
Description: Constructs a new instance of OrtGptTranslator.
Method Summary
Modifier and Type: ai.djl.ndarray.NDList
Method: processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input)
Modifier and Type: ai.djl.modality.nlp.generate.CausalLMOutput
Method: processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output)
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface ai.djl.translate.NoBatchifyTranslator
getBatchifier
Methods inherited from interface ai.djl.translate.Translator
batchProcessInput, batchProcessOutput, getExpansions, prepare
Constructor Details
OrtGptTranslator
public OrtGptTranslator(long kvDim, int numAttentionHeads, int numLayers)
Constructs a new instance of OrtGptTranslator.
- Parameters:
kvDim - the kv dimension
numAttentionHeads - the number of attention heads
numLayers - the number of layers
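The parameter descriptions above are brief, so the following sketch shows one plausible reading of them. It assumes the common GPT-2 key/value cache layout, in which each layer contributes one key tensor and one value tensor of shape [batch, numAttentionHeads, pastSeqLen, kvDim], and kvDim is the per-head width (hidden size divided by the number of heads). This page does not confirm that layout. The class GptKvCacheSketch and its methods are hypothetical helpers for illustration and are not part of DJL. The values 768, 12 and 12 are those of the smallest GPT-2 configuration.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Hypothetical helper, not part of DJL: illustrates one reading of the constructor arguments. */
final class GptKvCacheSketch {

    private GptKvCacheSketch() {}

    /** Per-head key/value width, ASSUMING kvDim = hiddenSize / numAttentionHeads. */
    static long kvDim(long hiddenSize, int numAttentionHeads) {
        if (numAttentionHeads <= 0 || hiddenSize % numAttentionHeads != 0) {
            throw new IllegalArgumentException("hiddenSize must be divisible by numAttentionHeads");
        }
        return hiddenSize / numAttentionHeads;
    }

    /**
     * Shapes of the cached tensors, ASSUMING one key and one value tensor per layer,
     * each laid out as [batch, numAttentionHeads, pastSeqLen, kvDim].
     */
    static List<long[]> pastKeyValueShapes(
            long batch, long pastSeqLen, long kvDim, int numAttentionHeads, int numLayers) {
        List<long[]> shapes = new ArrayList<>();
        for (int i = 0; i < numLayers * 2; i++) {
            shapes.add(new long[] {batch, numAttentionHeads, pastSeqLen, kvDim});
        }
        return shapes;
    }

    public static void main(String[] args) {
        // Smallest GPT-2 configuration: hidden size 768, 12 heads, 12 layers.
        long kvDim = kvDim(768, 12);
        List<long[]> shapes = pastKeyValueShapes(1, 5, kvDim, 12, 12);
        System.out.println("kvDim = " + kvDim);                            // kvDim = 64
        System.out.println("tensors = " + shapes.size());                  // tensors = 24
        System.out.println("shape = " + Arrays.toString(shapes.get(0)));   // shape = [1, 12, 5, 64]
    }
}
```

Under these assumptions, the matching constructor call for that configuration would be new OrtGptTranslator(64, 12, 12). Check the exported model's input signature before relying on this layout.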
Method Details
processInput
public ai.djl.ndarray.NDList processInput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList input) throws Exception
- Specified by:
processInput in interface ai.djl.translate.PreProcessor<ai.djl.ndarray.NDList>
- Throws:
Exception
processOutput
public ai.djl.modality.nlp.generate.CausalLMOutput processOutput(ai.djl.translate.TranslatorContext ctx, ai.djl.ndarray.NDList output) throws Exception
- Specified by:
processOutput in interface ai.djl.translate.PostProcessor<ai.djl.modality.nlp.generate.CausalLMOutput>
- Throws:
Exception