Description
There's already a way to do this, so this is just a feature request. I'm curious whether you think it's worth adding and whether it makes sense for the constructor spec API. Consider the following setting:
```python
from typing import Annotated

import torch.nn as nn
import tyro

model_registry: dict[str, nn.Module] = {}


def register_model(name: str, model: nn.Module) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


def is_registered(name: str) -> bool:
    return name in model_registry


class Model1(nn.Module): ...


register_model("model1", Model1())

# Some other file
ModelName = Annotated[
    str,
    tyro.constructors.PrimitiveConstructorSpec(
        nargs=1,
        metavar="{" + ",".join(registered_models()[:3]) + ",...}",
        instance_from_str=lambda args: args[0],
        is_instance=lambda instance: isinstance(instance, str)
        and is_registered(instance),
        str_from_instance=lambda instance: [instance],
        choices=tuple(registered_models()),
    ),
    tyro.conf.arg(
        help_behavior_hint=lambda df: f"(default: {df}, run entry.py model_registry)"
    ),
]


# User defines a new model after ModelName type has been defined
class Model2(nn.Module): ...


register_model("model2", Model2())


# model2 will not show up as a choice
def main(model: ModelName) -> None:
    print(model)


if __name__ == "__main__":
    tyro.cli(main)
```
If a registry system builds the set of choices and users should also be able to add their own entries, `PrimitiveConstructorSpec` has a limitation: `choices=` is evaluated once, when `ModelName` is defined, so in the example above `model2` is not a valid choice.
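The underlying cause is ordinary Python evaluation order: `choices=tuple(registered_models())` is computed at the moment the `Annotated` alias is created, so it is a snapshot of the registry at that point. A minimal, dependency-free sketch of the difference between an eager snapshot and a lazily called function (plain strings stand in for the `nn.Module` instances; `SNAPSHOT` and `lazy_choices` are illustrative names, not tyro APIs):

```python
# Stand-in registry: plain strings replace the nn.Module instances.
model_registry: dict[str, str] = {}


def register_model(name: str, model: str) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


register_model("model1", "Model1")

# Eager: evaluated right now, like `choices=tuple(registered_models())`.
SNAPSHOT: tuple[str, ...] = tuple(registered_models())


# Lazy: the registry is only read when this function is called.
def lazy_choices() -> tuple[str, ...]:
    return tuple(registered_models())


# A later registration, e.g. from a user's own module.
register_model("model2", "Model2")

print(SNAPSHOT)        # ('model1',)
print(lazy_choices())  # ('model1', 'model2')
```

The snapshot never sees `model2`, while the function does, which is the same reason a factory-based approach can pick up late registrations.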
I can already accomplish this with a constructor_factory:
```python
from typing import Annotated

import torch.nn as nn
import tyro

model_registry: dict[str, nn.Module] = {}


def register_model(name: str, model: nn.Module) -> None:
    if name not in model_registry:
        model_registry[name] = model


def registered_models() -> list[str]:
    return list(model_registry.keys())


def is_registered(name: str) -> bool:
    return name in model_registry


class Model1(nn.Module): ...


register_model("model1", Model1())


# Some other file
def build_registry_literal() -> type[str]:
    return tyro.extras.literal_type_from_choices(registered_models())


ModelName = Annotated[str, tyro.conf.arg(constructor_factory=build_registry_literal)]


# User defines a new model
class Model2(nn.Module): ...


register_model("model2", Model2())


def main(model: ModelName) -> None:
    print(model)


if __name__ == "__main__":
    tyro.cli(main)
```
I think there is room for a `choices_factory: Callable[..., tuple[str, ...]] | None = None` option, which would make the spec compatible with this use case. Is this compatible with the purpose of the API, or is the static nature of `choices` intentional?
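To make the proposal concrete, here is a rough sketch of how such a field might be resolved. `ChoicesSpecSketch`, `choices_factory`, and `resolve_choices` are hypothetical names used only for illustration; they are not part of tyro, and the assumption that resolution would happen when `tyro.cli` builds the parser (rather than when the alias is defined) is mine:

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class ChoicesSpecSketch:
    """Hypothetical stand-in for the choices part of PrimitiveConstructorSpec.

    `choices_factory` is the proposed field; it does not exist in tyro today.
    """

    choices: Optional[tuple[str, ...]] = None
    choices_factory: Optional[Callable[[], tuple[str, ...]]] = None

    def __post_init__(self) -> None:
        # Design choice of this sketch: the two options are mutually exclusive.
        if self.choices is not None and self.choices_factory is not None:
            raise ValueError("pass either choices or choices_factory, not both")

    def resolve_choices(self) -> Optional[tuple[str, ...]]:
        # Assumed to be called at parse time (e.g. inside tyro.cli), so any
        # entries registered after the Annotated alias was created are included.
        if self.choices_factory is not None:
            return tuple(self.choices_factory())
        return self.choices


# Usage: the factory is re-evaluated on each resolve, so later additions appear.
available: list[str] = ["model1"]
spec = ChoicesSpecSketch(choices_factory=lambda: tuple(available))
available.append("model2")
print(spec.resolve_choices())  # ('model1', 'model2')
```

This sketch uses a zero-argument callable for simplicity; the `Callable[..., tuple[str, ...]]` signature proposed above would leave room for passing context to the factory, which is not modeled here.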