Skip to content

Commit 111d205

Browse files
committed
perf: faster LLM loading
Use attrs for faster class creation, as opposed to a metaclass.
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
1 parent ebcedc3 commit 111d205

File tree

8 files changed

+323
-374
lines changed

8 files changed

+323
-374
lines changed

src/openllm/_configuration.py

Lines changed: 28 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class GenerationConfig:
6666
import openllm
6767

6868
from .exceptions import ForbiddenAttributeError, GpuNotAvailableError, OpenLLMException
69-
from .utils import DEBUG, LazyType, bentoml_cattr, dantic, first_not_none, lenient_issubclass
69+
from .utils import DEBUG, LazyType, bentoml_cattr, codegen, dantic, first_not_none, lenient_issubclass
7070

7171
if hasattr(t, "Required"):
7272
from typing import Required
@@ -85,7 +85,7 @@ class GenerationConfig:
8585
import tensorflow as tf
8686
import torch
8787
import transformers
88-
from attr import _CountingAttr, _make_init, _make_method, _make_repr, _transform_attrs
88+
from attr import _CountingAttr, _make_init, _make_repr, _transform_attrs # type: ignore
8989
from transformers.generation.beam_constraints import Constraint
9090

9191
from ._types import ClickFunctionWrapper, F, O_co, P
@@ -103,7 +103,7 @@ class GenerationConfig:
103103
ItemgetterAny = itemgetter
104104
# NOTE: Using internal API from attr here, since we are actually
105105
# allowing subclasses of openllm.LLMConfig to become 'attrs'-ish
106-
from attr._make import _CountingAttr, _make_init, _make_method, _make_repr, _transform_attrs
106+
from attr._make import _CountingAttr, _make_init, _make_repr, _transform_attrs
107107

108108
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
109109
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
@@ -389,89 +389,10 @@ def _populate_value_from_env_var(
389389
return os.environ.get(key, fallback)
390390

391391

392-
# sentinel object for unequivocal object() getattr
393-
_sentinel = object()
394-
395-
396392
def _field_env_key(model_name: str, key: str, suffix: str | None = None) -> str:
397393
return "_".join(filter(None, map(str.upper, ["OPENLLM", model_name, suffix.strip("_") if suffix else "", key])))
398394

399395

400-
def _has_own_attribute(cls: type[t.Any], attrib_name: t.Any):
401-
"""
402-
Check whether *cls* defines *attrib_name* (and doesn't just inherit it).
403-
"""
404-
attr = getattr(cls, attrib_name, _sentinel)
405-
if attr is _sentinel:
406-
return False
407-
408-
for base_cls in cls.__mro__[1:]:
409-
a = getattr(base_cls, attrib_name, None)
410-
if attr is a:
411-
return False
412-
413-
return True
414-
415-
416-
def _get_annotations(cls: type[t.Any]) -> DictStrAny:
417-
"""
418-
Get annotations for *cls*.
419-
"""
420-
if _has_own_attribute(cls, "__annotations__"):
421-
return cls.__annotations__
422-
423-
return DictStrAny()
424-
425-
426-
_classvar_prefixes = (
427-
"typing.ClassVar",
428-
"t.ClassVar",
429-
"ClassVar",
430-
"typing_extensions.ClassVar",
431-
)
432-
433-
434-
def _is_class_var(annot: str | t.Any) -> bool:
435-
"""
436-
Check whether *annot* is a typing.ClassVar.
437-
438-
The string comparison hack is used to avoid evaluating all string
439-
annotations which would put attrs-based classes at a performance
440-
disadvantage compared to plain old classes.
441-
"""
442-
annot = str(annot)
443-
444-
# Annotation can be quoted.
445-
if annot.startswith(("'", '"')) and annot.endswith(("'", '"')):
446-
annot = annot[1:-1]
447-
448-
return annot.startswith(_classvar_prefixes)
449-
450-
451-
def _add_method_dunders(cls: type[t.Any], method_or_cls: _T, _overwrite_doc: str | None = None) -> _T:
452-
"""
453-
Add __module__ and __qualname__ to a *method* if possible.
454-
"""
455-
try:
456-
method_or_cls.__module__ = cls.__module__
457-
except AttributeError:
458-
pass
459-
460-
try:
461-
method_or_cls.__qualname__ = ".".join((cls.__qualname__, method_or_cls.__name__))
462-
except AttributeError:
463-
pass
464-
465-
try:
466-
method_or_cls.__doc__ = (
467-
_overwrite_doc or "Method or class generated by LLMConfig for class " f"{cls.__qualname__}."
468-
)
469-
except AttributeError:
470-
pass
471-
472-
return method_or_cls
473-
474-
475396
# cache it here to save one lookup per assignment
476397
_object_getattribute = object.__getattribute__
477398

@@ -506,8 +427,8 @@ class ModelSettings(t.TypedDict, total=False):
506427
generation_class: t.Type[GenerationConfig]
507428

508429

509-
_ModelSettings: type[attr.AttrsInstance] = _add_method_dunders(
510-
type("__internal__", (ModelSettings,), {"__module__": "openllm._configuration"}),
430+
_ModelSettings: type[attr.AttrsInstance] = codegen.add_method_dunders(
431+
type("__openllm_internal__", (ModelSettings,), {"__module__": "openllm._configuration"}),
511432
attr.make_class(
512433
"ModelSettings",
513434
{
@@ -563,7 +484,7 @@ def structure_settings(cl_: type[LLMConfig], cls: type[t.Any]):
563484
partialed = functools.partial(_field_env_key, model_name=model_name, suffix="generation")
564485

565486
def auto_env_transformers(_: t.Any, fields: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
566-
_has_own_gen = _has_own_attribute(cl_, "GenerationConfig")
487+
_has_own_gen = codegen.has_own_attribute(cl_, "GenerationConfig")
567488
return [
568489
f.evolve(
569490
default=_populate_value_from_env_var(
@@ -576,20 +497,20 @@ def auto_env_transformers(_: t.Any, fields: list[attr.Attribute[t.Any]]) -> list
576497
for f in fields
577498
]
578499

579-
_target: DictStrAny = {
580-
"default_id": settings["default_id"],
581-
"model_ids": settings["model_ids"],
582-
"url": settings.get("url", ""),
583-
"requires_gpu": settings.get("requires_gpu", False),
584-
"trust_remote_code": settings.get("trust_remote_code", False),
585-
"requirements": settings.get("requirements", None),
586-
"name_type": name_type,
587-
"model_name": model_name,
588-
"start_name": start_name,
589-
"env": openllm.utils.ModelEnv(model_name),
590-
"timeout": settings.get("timeout", 3600),
591-
"workers_per_resource": settings.get("workers_per_resource", 1),
592-
"generation_class": attr.make_class(
500+
return cls(
501+
default_id=settings["default_id"],
502+
model_ids=settings["model_ids"],
503+
url=settings.get("url", ""),
504+
requires_gpu=settings.get("requires_gpu", False),
505+
trust_remote_code=settings.get("trust_remote_code", False),
506+
requirements=settings.get("requirements", None),
507+
name_type=name_type,
508+
model_name=model_name,
509+
start_name=start_name,
510+
env=openllm.utils.ModelEnv(model_name),
511+
timeout=settings.get("timeout", 3600),
512+
workers_per_resource=settings.get("workers_per_resource", 1),
513+
generation_class=attr.make_class(
593514
f"{_cl_name}GenerationConfig",
594515
[],
595516
bases=(GenerationConfig,),
@@ -599,18 +520,12 @@ def auto_env_transformers(_: t.Any, fields: list[attr.Attribute[t.Any]]) -> list
599520
repr=True,
600521
field_transformer=auto_env_transformers,
601522
),
602-
}
603-
604-
return cls(**_target)
523+
)
605524

606525

607526
bentoml_cattr.register_structure_hook(_ModelSettings, structure_settings)
608527

609528

610-
def _generate_unique_filename(cls: type[t.Any], func_name: str):
611-
return f"<LLMConfig generated {func_name} {cls.__module__}." f"{getattr(cls, '__qualname__', cls.__name__)}>"
612-
613-
614529
def _setattr_class(attr_name: str, value_var: t.Any, add_dunder: bool = False):
615530
"""
616531
Use the builtin setattr to set *attr_name* to *value_var*.
@@ -632,7 +547,7 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
632547
"cls": cls,
633548
"_cached_attribute": attributes,
634549
"_cached_getattribute_get": _object_getattribute.__get__,
635-
"__add_dunder": _add_method_dunders,
550+
"__add_dunder": codegen.add_method_dunders,
636551
}
637552
annotations: DictStrAny = {"return": None}
638553

@@ -643,19 +558,9 @@ def _make_assignment_script(cls: type[LLMConfig], attributes: attr.AttrsInstance
643558
lines.append(_setattr_class(arg_name, attr_name, add_dunder=attr_name in _dunder_add))
644559
annotations[attr_name] = field.type
645560

646-
script = "def __assign_attr(cls, %s):\n %s\n" % (", ".join(args), "\n ".join(lines) if lines else "pass")
647-
assign_method = _make_method(
648-
"__assign_attr",
649-
script,
650-
_generate_unique_filename(cls, "__assign_attr"),
651-
globs,
561+
return codegen.generate_function(
562+
cls, "__assign_attr", lines, args=("cls", *args), globs=globs, annotations=annotations
652563
)
653-
assign_method.__annotations__ = annotations
654-
655-
if DEBUG:
656-
logger.info("Generated script:\n%s", script)
657-
658-
return assign_method
659564

660565

661566
_reserved_namespace = {"__config__", "GenerationConfig"}
@@ -841,7 +746,7 @@ def __init_subclass__(cls):
841746
_make_assignment_script(cls, bentoml_cattr.structure(cls, _ModelSettings))(cls)
842747
# process the fields under cls.__dict__ and auto-convert them with dantic.Field
843748
cd = cls.__dict__
844-
anns = _get_annotations(cls)
749+
anns = codegen.get_annotations(cls)
845750
partialed = functools.partial(_field_env_key, model_name=cls.__openllm_model_name__)
846751

847752
def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> list[attr.Attribute[t.Any]]:
@@ -861,7 +766,7 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l
861766
these: dict[str, _CountingAttr[t.Any]] = {}
862767
annotated_names: set[str] = set()
863768
for attr_name, typ in anns.items():
864-
if _is_class_var(typ):
769+
if codegen.is_class_var(typ):
865770
continue
866771
annotated_names.add(attr_name)
867772
val = cd.get(attr_name, attr.NOTHING)
@@ -907,7 +812,7 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l
907812
cls.__attrs_attrs__ = attrs
908813
# generate an __attrs_init__ for the subclass, since we will
909814
# implement a custom __init__
910-
cls.__attrs_init__ = _add_method_dunders(
815+
cls.__attrs_init__ = codegen.add_method_dunders(
911816
cls,
912817
_make_init(
913818
cls, # cls (the attrs-decorated class)
@@ -924,7 +829,7 @@ def auto_config_env(_: type[LLMConfig], attrs: list[attr.Attribute[t.Any]]) -> l
924829
),
925830
)
926831
# __repr__ function with the updated fields.
927-
cls.__repr__ = _add_method_dunders(cls, _make_repr(cls.__attrs_attrs__, None, cls))
832+
cls.__repr__ = codegen.add_method_dunders(cls, _make_repr(cls.__attrs_attrs__, None, cls))
928833
# Traverse the MRO to collect existing slots
929834
# and check for an existing __weakref__.
930835
existing_slots: DictStrAny = dict()

0 commit comments

Comments
 (0)