diff --git a/dataflow/gemma/custom_model_gemma.py b/dataflow/gemma/custom_model_gemma.py
index fbf0b97505..7b7168f1e5 100644
--- a/dataflow/gemma/custom_model_gemma.py
+++ b/dataflow/gemma/custom_model_gemma.py
@@ -74,13 +74,24 @@ def run_inference(
     Returns:
       An Iterable of type PredictionResult.
     """
+    if inference_args is None:
+      inference_args = {"max_length": 64}
     # Loop each text string, and use a tuple to store the inference results.
     predictions = []
     for one_text in batch:
-      result = model.generate(one_text, max_length=64)
+      result = model.generate(one_text, **inference_args)
       predictions.append(result)
     return utils._convert_to_result(batch, predictions, self._model_name)
 
+  def validate_inference_args(self, inference_args: Optional[dict[str,
+                                                                  Any]]):
+    if inference_args:
+      if len(inference_args
+             ) > 1 or "max_length" not in inference_args.keys():
+        raise ValueError(
+            "invalid inference args, only valid arg is max_length, got",
+            inference_args)
+
 
 class FormatOutput(beam.DoFn):
   def process(self, element, *args, **kwargs):
@@ -123,8 +134,10 @@ def process(self, element, *args, **kwargs):
         beam.io.ReadFromPubSub(subscription=args.messages_subscription)
         | "Parse" >> beam.Map(lambda x: x.decode("utf-8"))
         | "RunInference-Gemma" >> RunInference(
-            GemmaModelHandler(args.model_path)
-        )  # Send the prompts to the model and get responses.
+            GemmaModelHandler(args.model_path),
+            inference_args={
+                "max_length": 32
+            })  # Send the prompts to the model and get responses.
         | "Format Output" >> beam.ParDo(FormatOutput())  # Format the output.
         | "Publish Result" >>
         beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))
diff --git a/dataflow/gemma/requirements.txt b/dataflow/gemma/requirements.txt
index 76fc60632e..8d205dfc54 100644
--- a/dataflow/gemma/requirements.txt
+++ b/dataflow/gemma/requirements.txt
@@ -1,4 +1,5 @@
 apache_beam[gcp]==2.54.0
 protobuf==4.25.0
 keras_nlp==0.8.2
-keras==3.0.5
\ No newline at end of file
+keras==3.0.5
+tensorflow==2.16.1
\ No newline at end of file