
choices_factory for PrimitiveConstructorSpec #308

@mirceamironenco

Description


There's already a way to do this (shown below), so this is only a feature request. I'm curious whether you think it's worth adding and whether it makes sense for the constructor spec API. Consider the following setup:

```python
from typing import Annotated

import torch.nn as nn
import tyro

model_registry: dict[str, nn.Module] = {}


def register_model(name: str, model: nn.Module) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


def is_registered(name: str) -> bool:
    return name in model_registry


class Model1(nn.Module): ...


register_model("model1", Model1())


# Some other file


ModelName = Annotated[
    str,
    tyro.constructors.PrimitiveConstructorSpec(
        nargs=1,
        metavar="{" + ",".join(registered_models()[:3]) + ",...}",
        instance_from_str=lambda args: args[0],
        is_instance=lambda instance: isinstance(instance, str)
        and is_registered(instance),
        str_from_instance=lambda instance: [instance],
        choices=tuple(registered_models()),
    ),
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)"
    ),
]


# User defines a new model after ModelName type has been defined
class Model2(nn.Module): ...


register_model("model2", Model2())


# model2 will not show up as a choice
def main(model: ModelName) -> None:
    print(model)


if __name__ == "__main__":
    tyro.cli(main)
```

If a registry system builds the set of choices and users should also be able to add to it, `PrimitiveConstructorSpec` is limited because `choices=` is fixed when the `ModelName` annotation is defined: `tuple(registered_models())` (like the `metavar` string) is evaluated only once, at that point. In the example above, `model2` is registered afterwards, so it is not an accepted choice.
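The root cause is ordinary Python evaluation order: `tuple(registered_models())` runs once, when the `Annotated[...]` alias is built, while the `is_instance` lambda calls `is_registered` every time it runs. A minimal stdlib-only sketch (no torch or tyro; the registry holds plain `object()` stand-ins instead of `nn.Module` instances, and `eager_choices` / `lazy_is_instance` are illustrative names) shows the two checks drifting apart:

```python
# Stand-in registry: plain objects instead of nn.Module instances (illustrative assumption).
model_registry: dict[str, object] = {}


def register_model(name: str, model: object) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


def is_registered(name: str) -> bool:
    return name in model_registry


register_model("model1", object())

# Evaluated once, right now: a snapshot, like `choices=tuple(registered_models())`.
eager_choices = tuple(registered_models())


# Evaluated on every call, like the `is_instance=` lambda in the spec.
def lazy_is_instance(instance: object) -> bool:
    return isinstance(instance, str) and is_registered(instance)


# A model registered after the "spec" values were computed.
register_model("model2", object())

print(eager_choices)               # ('model1',)
print(lazy_is_instance("model2"))  # True
print("model2" in eager_choices)   # False
```

So the spec in the first example is internally inconsistent once `model2` exists: `is_instance` accepts it, but the static `choices` tuple does not list it.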

I can already accomplish this with a `constructor_factory`, which defers building the choices until tyro processes the annotation:

```python
from typing import Annotated

import torch.nn as nn
import tyro

model_registry: dict[str, nn.Module] = {}


def register_model(name: str, model: nn.Module) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


def is_registered(name: str) -> bool:
    return name in model_registry


class Model1(nn.Module): ...


register_model("model1", Model1())


# Some other file


def build_registry_literal() -> type[str]:
    return tyro.extras.literal_type_from_choices(registered_models())


ModelName = Annotated[str, tyro.conf.arg(constructor_factory=build_registry_literal)]


# User defines a new model
class Model2(nn.Module): ...


register_model("model2", Model2())


def main(model: ModelName) -> None:
    print(model)


if __name__ == "__main__":
    tyro.cli(main)
```
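The only thing that changes between the two versions is *when* the set of allowed values is computed. Assuming `tyro.extras.literal_type_from_choices` returns the equivalent of `typing.Literal[...]` over the given values, that deferred step can be approximated with the standard library alone; `build_registry_literal` below mirrors the function in the example but is a stand-in, and `early` / `late` are illustrative names:

```python
import typing

# Stand-in registry holding plain objects instead of nn.Module instances.
model_registry: dict[str, object] = {}


def register_model(name: str, model: object) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


def build_registry_literal() -> typing.Any:
    # Literal[("a", "b")] is the same as Literal["a", "b"], so the current
    # registry keys become the allowed values. Assumes at least one entry.
    return typing.Literal[tuple(registered_models())]


register_model("model1", object())
early = build_registry_literal()  # built before model2 exists

register_model("model2", object())
late = build_registry_literal()  # built afterwards, as a deferred factory would be

print(typing.get_args(early))  # ('model1',)
print(typing.get_args(late))   # ('model1', 'model2')
```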

I think there is room for a `choices_factory: Callable[..., tuple[str, ...]] | None = None` option, which would make the spec compatible with this use case. Is this compatible with the purpose of the API, or is the static nature of `choices` intentional?
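To make the proposal concrete, here is a minimal stdlib-only sketch of how a spec could resolve such a factory lazily. `LazyChoicesSpec`, `resolve_choices`, the zero-argument `Callable[[], ...]` signature, and the rule that `choices` and `choices_factory` are mutually exclusive are all hypothetical design choices for illustration, not part of tyro's API:

```python
from __future__ import annotations

import dataclasses
from typing import Callable


@dataclasses.dataclass(frozen=True)
class LazyChoicesSpec:
    """Hypothetical subset of a constructor spec: only the choices-related fields."""

    choices: tuple[str, ...] | None = None
    choices_factory: Callable[[], tuple[str, ...]] | None = None

    def __post_init__(self) -> None:
        # Hypothetical rule: allow at most one source of choices.
        if self.choices is not None and self.choices_factory is not None:
            raise ValueError("Specify at most one of `choices` and `choices_factory`.")

    def resolve_choices(self) -> tuple[str, ...] | None:
        # Meant to be called when the parser is built, not when the annotation is defined.
        if self.choices_factory is not None:
            return tuple(self.choices_factory())
        return self.choices


# Usage: the factory reads the registry at resolution time.
registry: list[str] = ["model1"]
spec = LazyChoicesSpec(choices_factory=lambda: tuple(registry))
registry.append("model2")  # registered after the spec was created
print(spec.resolve_choices())  # ('model1', 'model2')
```

Keeping the static `choices` field alongside the factory would leave existing specs unchanged, while a registry-backed spec passes a factory instead.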

Metadata


Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
