[Bug - Translation server] - Missing tgt param in translator.translate method (allows some multilingual/seq2seq models to work properly) #2586

@medfreeman

Description

Some multilingual/seq2seq models such as M2M100 (see the Generation section in the linked page) require the bos_token to be set to the target language id, which is supplied through the sequence's tgt property.
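For context, when calling CTranslate2 directly this is done by passing the target language token as a target prefix (a minimal sketch; the model path and token sequence are illustrative):

import ctranslate2

# Hedged illustration: M2M100 decoding must start from the target language
# token, which CTranslate2 accepts via target_prefix.
translator = ctranslate2.Translator("m2m100_ct2/")  # hypothetical converted model
results = translator.translate_batch(
    [["__en__", "▁Brian", "▁is", "▁in", "▁the", "▁kitchen", ".", "</s>"]],
    target_prefix=[["__fr__"]],
)
print(results[0].hypotheses[0])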

In the case of the translation server, to be able to specify the requested target language, we need to manipulate the sequence's tgt property directly prior to translation.

But in its current state, the server has a disconnect between the sequence's ref/ref_tok (which, incidentally, can be manipulated through tokenizers/processors) and the tgt string sent to ctranslate2.

See:

"tgt": {"tgt": ref_tok} if ref_tok is not None else None,

As a result, the tgt parameter of the self.translator.translate method is never provided.

See:

scores, predictions = self.translator.translate(examples)

I implemented a one-line patch that properly passes the parameter through and allows me to run multilingual translation.
It should have no side effects on other types of models (for which the sequence ref is empty after tokenizing), since the parameter is set to an empty string in those cases.

Here’s the PR: #2585
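
For reference, a minimal sketch of what that pass-through could look like; this is hedged, not the literal diff (see the PR above), and the shape of each example's tgt entry is assumed from the snippet earlier:

# Hypothetical sketch, not the actual one-liner from the PR:
# forward each example's target prefix to translate(), defaulting to an
# empty string so models that produce no ref are unaffected.
tgt = [ex["tgt"]["tgt"] if ex.get("tgt") else "" for ex in examples]
scores, predictions = self.translator.translate(examples, tgt=tgt)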

Example of multilingual translation with an M2M100 model:

conf.json

{
    "models_root": "./available_models",
    "models": [
        {
            "id": 100,
            "model": "m2m-multi4-ft-ck945k/",
            "ct2_model": "m2m-multi4-ft-ck945k/",
            "load": true,
            "on_timeout": "unload",
            "ct2_translator_args": {
                "inter_threads": 4,
                "intra_threads": 2
            },
            "ct2_translate_batch_args": {},
            "opt": {
                "beam_size": 1,
                "batch_size": 8,
                "tgt_file_prefix": true
            },
            "preprocess": ["available_models.m2m-multi4-ft-ck945k.tokenizer.m2m100_tokenizer.preprocess"],
            "postprocess": ["available_models.m2m-multi4-ft-ck945k.tokenizer.m2m100_tokenizer.postprocess"]
        }
    ]
}

available_models/m2m-multi4-ft-ck945k/tokenizer/m2m100_tokenizer.py

import os
from pathlib import Path
from transformers import M2M100Tokenizer

cache = None

def loadTokenizer(model_root, logger):
    global cache
    if cache is not None:
        return cache

    model_path = os.path.join(model_root, "m2m-multi4-ft-ck945k/tokenizer/")
    logger.info("Loading m2m100 tokenizer from %s", model_path)
    cache = M2M100Tokenizer.from_pretrained(model_path)

    return cache

def preprocess(sequence, server_model):
    """Preprocess a single sequence.

    Args:
        sequence (dict[str, Any]): The sequence to preprocess.

    Returns:
        sequence (dict[str, Any]): The preprocessed sequence."""
    server_model.logger.info(f"Running preprocessor '{Path(__file__).stem}'")

    ref = sequence.get("ref", None)
    if ref is not None and ref[0] is not None:
        server_model.logger.debug(f"{ref[0]=}")
        tgt_lang = ref[0].get("tgt_lang", None)
        if tgt_lang is not None:
            server_model.logger.debug(f"{tgt_lang=}")

            tokenizer = loadTokenizer(server_model.model_root, server_model.logger)

            # Tokenize the source segment with the M2M100 sentencepiece model.
            seg = sequence.get("seg", None)
            tok = tokenizer.convert_ids_to_tokens(
                tokenizer.encode(seg[0])
            )
            sequence["seg"][0] = " ".join(tok)

            # Store the target language token in ref so the server can pass
            # it to the model as a target prefix.
            lang_prefix = f"__{tgt_lang}__"
            sequence["ref"][0] = lang_prefix
            server_model.logger.info(f"Added lang prefix to ref: '{lang_prefix}'")
            server_model.logger.debug(f"{sequence['ref'][0]=}")

    return sequence

def postprocess(sequence, server_model):
    """Postprocess a single sequence.

    Args:
        sequence (dict[str, Any]): The sequence to postprocess.

    Returns:
        detok (str): The detokenized translation."""
    server_model.logger.info(f"Running postprocessor '{Path(__file__).stem}'")

    tokenizer = loadTokenizer(server_model.model_root, server_model.logger)

    # Drop the leading target language token, then detokenize the rest.
    seg = sequence.get("seg", None)
    detok = tokenizer.decode(
        tokenizer.convert_tokens_to_ids(seg[0].split()[1:]),
        skip_special_tokens=True,
    )
    return detok
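
The two hooks can be sanity-checked outside the server with any M2M100 checkpoint (a sketch; the public facebook/m2m100_418M checkpoint stands in for the fine-tuned model above, and the hypothesis string is made up):

from transformers import M2M100Tokenizer

tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")

# preprocess side: tokenize the source; M2M100Tokenizer prepends the
# source language token (default "en") and appends </s>.
tok = " ".join(tokenizer.convert_ids_to_tokens(tokenizer.encode("Brian is in the kitchen.")))
print(tok)  # approximately: __en__ ▁Brian ▁is ▁in ▁the ▁kitchen . </s>

# postprocess side: drop the leading language token and detokenize.
hyp = "__fr__ ▁Brian ▁est ▁dans ▁la ▁cuisine ."
print(tokenizer.decode(tokenizer.convert_tokens_to_ids(hyp.split()[1:]),
                       skip_special_tokens=True))  # Brian est dans la cuisine.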

Sample request to server:

[
    {
        "src": "Brian is in the kitchen.",
        "id": 100,
        "ref": {
            "src_lang": "en",
            "tgt_lang": "fr"
        }
    },
    {
        "src": "By the way, do you like to eat pancakes?",
        "id": 100,
        "ref": {
            "src_lang": "en",
            "tgt_lang": "fr"
        }
    }
]
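
Assuming the server is started with the default --url_root "/translator" and port 5000, the request can be sent like this (a sketch; adjust host, port, and url root to your setup):

import requests

payload = [
    {
        "src": "Brian is in the kitchen.",
        "id": 100,
        "ref": {"src_lang": "en", "tgt_lang": "fr"},
    },
    {
        "src": "By the way, do you like to eat pancakes?",
        "id": 100,
        "ref": {"src_lang": "en", "tgt_lang": "fr"},
    },
]
resp = requests.post("http://localhost:5000/translator/translate", json=payload)
print(resp.json())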
