Source code for construe.models.tflite

"""
Utilities for model loading and conversion.
"""

import tensorflow as tf


class TFLiteGenerateModel(tf.Module):
    """
    Defines a model whose serving function is the generation call.

    Wraps ``model`` in a ``tf.Module`` so that its ``generate`` call is exposed
    through a single traced ``tf.function`` (``serving``) with a fixed input
    signature, presumably so it can be exported/converted (e.g. to TFLite).

    Parameters
    ----------
    model : object
        The wrapped model. It must provide
        ``generate(input_features, max_new_tokens=..., return_dict_in_generate=True)``
        returning an object that can be indexed with ``"sequences"``.
    max_new_tokens : int, optional
        Maximum number of new tokens requested from ``model.generate``.
        Defaults to 255, the value previously hard-coded. The value is read as
        a Python constant when ``serving`` is traced, so it must be set before
        tracing/export; changing it afterwards does not affect an
        already-traced function.
    """

    def __init__(self, model, *, max_new_tokens=255):
        super().__init__()
        self.model = model
        self.max_new_tokens = max_new_tokens

    @tf.function(
        # NOTE: shouldn't need a static batch size, but tracing throws an
        # exception without it (needs to be fixed).
        input_signature=[
            # NOTE(review): presumably (batch, mel_bins, frames) audio features,
            # i.e. Whisper-style log-mel input -- confirm against the caller.
            tf.TensorSpec((1, 80, 3000), tf.float32, name="input_features"),
        ],
    )
    def serving(self, input_features):
        """
        Run generation on ``input_features`` and return the generated tokens.

        Parameters
        ----------
        input_features : tf.Tensor
            float32 tensor of shape ``(1, 80, 3000)``, fixed by the
            ``input_signature`` above.

        Returns
        -------
        dict
            ``{"sequences": ...}``, i.e. the ``"sequences"`` entry of the
            output returned by ``self.model.generate``.
        """
        outputs = self.model.generate(
            input_features,
            max_new_tokens=self.max_new_tokens,
            # Ask for a dict-like output so "sequences" can be selected by key.
            return_dict_in_generate=True,
        )
        return {"sequences": outputs["sequences"]}