Skip to content

Commit 331273c

Browse files
authored
Add Serializable Functionality and GuidanceReport Class for Enhanced Results Management (#22)
## Summary Introduces the `Serializable` class to provide serialization capabilities for core Pydantic classes, along with the new `GuidanceReport` class to manage multiple benchmarking reports. These additions add native support for loading and saving to disk in JSON and YAML formats. ## Details - **Serializable Class**: Adds the ability to serialize and deserialize objects to/from YAML and JSON formats, and save/load from files. - Implements methods for `to_yaml`, `to_json`, `from_yaml`, `from_json`, `save_file`, and `load_file`. - Introduces `SerializableFileType` enum to handle supported file types. - Includes validation and error handling for file operations. - **GuidanceReport Class**: Manages guidance reports containing benchmarking details across multiple runs. - Inherits from `Serializable` to leverage serialization capabilities. - Contains a list of `TextGenerationBenchmarkReport` objects. - **CLI Integration**: Updates the CLI to use `GuidanceReport` for saving benchmark reports. - Adds `--output-path` option for specifying the path to save the report. - **Tests**: Adds comprehensive unit tests for the new functionality. - Tests for `Serializable` class methods. - Tests for `GuidanceReport` class, including initialization, file operations, and serialization. ## Test Plan - **Automation Testing**: - Added unit tests for `Serializable` class covering YAML and JSON serialization/deserialization, file saving/loading, and error handling. - Added unit tests for `GuidanceReport` class covering initialization and file operations.
1 parent 75dac35 commit 331273c

File tree

8 files changed

+384
-16
lines changed

8 files changed

+384
-16
lines changed

src/guidellm/core/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .distribution import Distribution
2+
from .report import GuidanceReport
23
from .request import TextGenerationRequest
34
from .result import (
45
RequestConcurrencyMeasurement,
@@ -7,6 +8,7 @@
78
TextGenerationError,
89
TextGenerationResult,
910
)
11+
from .serializable import Serializable, SerializableFileType
1012

1113
__all__ = [
1214
"Distribution",
@@ -16,4 +18,7 @@
1618
"TextGenerationBenchmark",
1719
"TextGenerationBenchmarkReport",
1820
"RequestConcurrencyMeasurement",
21+
"Serializable",
22+
"SerializableFileType",
23+
"GuidanceReport",
1924
]

src/guidellm/core/report.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from typing import List
2+
3+
from pydantic import Field
4+
5+
from guidellm.core.serializable import Serializable
6+
from guidellm.core.result import TextGenerationBenchmarkReport
7+
8+
__all__ = [
9+
"GuidanceReport",
10+
]
11+
12+
13+
class GuidanceReport(Serializable):
    """
    Container for guidance reports holding benchmarking details, potentially
    spanning multiple runs, with save/load support inherited from the
    Serializable base class.
    """

    # Collected benchmark reports; defaults to an empty list per instance.
    benchmarks: List[TextGenerationBenchmarkReport] = Field(
        default_factory=list, description="The list of benchmark reports."
    )

src/guidellm/core/serializable.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,24 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

3+
import os
34
import yaml
45
from loguru import logger
56
from pydantic import BaseModel, ConfigDict
7+
from enum import Enum
8+
9+
from guidellm.utils import is_file_name
10+
11+
12+
__all__ = ["Serializable", "SerializableFileType"]
13+
14+
15+
class SerializableFileType(Enum):
    """Enumeration of the on-disk formats supported by Serializable."""

    YAML = "yaml"
    JSON = "json"
622

723

824
class Serializable(BaseModel):
@@ -73,3 +89,97 @@ def from_json(cls, data: str):
7389
obj = cls.model_validate_json(data)
7490

7591
return obj
92+
93+
def save_file(self, path: str, type_: Optional[SerializableFileType] = None) -> str:
    """
    Save the model to a file in either YAML or JSON format.

    :param path: Path to the exact file or the containing directory.
        If it is a directory, the file name will be inferred from the class name.
    :param type_: Optional type to save ('yaml' or 'json').
        If not provided and the path has an extension,
        it will be inferred to save in that format.
        If not provided and the path does not have an extension,
        it will save in YAML format.
    :return: The path to the saved file.
    :raises ValueError: If the file extension/format is not supported.
    """
    logger.debug("Saving to file... {} with format: {}", path, type_)

    if not is_file_name(path):
        # Path is a directory: derive the file name from the class name.
        file_name = f"{self.__class__.__name__.lower()}"
        if type_:
            file_name += f".{type_.value.lower()}"
        else:
            # No explicit format requested; default to YAML.
            file_name += ".yaml"
            type_ = SerializableFileType.YAML
        path = os.path.join(path, file_name)

    if not type_:
        # Infer the format from the file extension.
        extension = path.split(".")[-1].upper()

        if extension not in SerializableFileType.__members__:
            raise ValueError(
                f"Unsupported file extension: {extension}. "
                f"Expected one of {', '.join(SerializableFileType.__members__)} "
                f"for {path}"
            )

        type_ = SerializableFileType[extension]

    # Only create parent directories when the path actually has them:
    # os.makedirs("") raises FileNotFoundError for bare file names such as
    # the CLI default "benchmark_report.json".
    dir_path = os.path.dirname(path)
    if dir_path:
        os.makedirs(dir_path, exist_ok=True)

    with open(path, "w") as file:
        if type_ == SerializableFileType.YAML:
            file.write(self.to_yaml())
        elif type_ == SerializableFileType.JSON:
            file.write(self.to_json())
        else:
            raise ValueError(f"Unsupported file format: {type_}")

    logger.info("Successfully saved {} to {}", self.__class__.__name__, path)

    return path
148+
149+
@classmethod
def load_file(cls, path: str):
    """
    Load a model from a file in either YAML or JSON format.

    :param path: Path to the file.
    :return: An instance of the model.
    :raises FileNotFoundError: If the file does not exist.
    :raises ValueError: If the path is not a file or its extension is
        not a supported format.
    """
    logger.debug("Loading from file... {}", path)

    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    elif not os.path.isfile(path):
        raise ValueError(f"Path is not a file: {path}")

    # Determine the serialization format from the file extension.
    extension = path.split(".")[-1].upper()

    if extension not in SerializableFileType.__members__:
        raise ValueError(
            f"Unsupported file extension: {extension}. "
            f"Expected one of {', '.join(SerializableFileType.__members__)} "
            f"for {path}"
        )

    type_ = SerializableFileType[extension]

    with open(path, "r") as file:
        data = file.read()

    if type_ == SerializableFileType.YAML:
        obj = cls.from_yaml(data)
    elif type_ == SerializableFileType.JSON:
        obj = cls.from_json(data)
    else:
        # Defensive: unreachable while only YAML/JSON members exist.
        raise ValueError(f"Unsupported file format: {type_}")

    return obj

src/guidellm/main.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import click
22

33
from guidellm.backend import Backend
4-
from guidellm.core import TextGenerationBenchmarkReport
4+
from guidellm.core import GuidanceReport
55
from guidellm.executor import (
66
Executor,
77
rate_type_to_load_gen_mode,
@@ -65,6 +65,12 @@
6565
default=None,
6666
help="Number of requests to send for each rate",
6767
)
68+
@click.option(
69+
"--output-path",
70+
type=str,
71+
default="benchmark_report.json",
72+
help="Path to save benchmark report to",
73+
)
6874
def main(
6975
target,
7076
host,
@@ -80,6 +86,7 @@ def main(
8086
rate,
8187
num_seconds,
8288
num_requests,
89+
output_path,
8390
):
8491
# Create backend
8592
Backend.create(
@@ -127,18 +134,12 @@ def main(
127134
report = executor.run()
128135

129136
# Save or print results
130-
save_report(report, "benchmark_report.json")
131-
print_report(report)
132-
133-
134-
def save_report(report: TextGenerationBenchmarkReport, filename: str):
135-
with open(filename, "w") as file:
136-
file.write(report.to_json())
137-
137+
guidance_report = GuidanceReport()
138+
guidance_report.benchmarks.append(report)
139+
guidance_report.save_file(output_path)
138140

139-
def print_report(report: TextGenerationBenchmarkReport):
140-
for benchmark in report.benchmarks:
141-
print(f"Rate: {benchmark.completed_request_rate}, Results: {benchmark.results}")
141+
print("Guidance Report Complete:")
142+
print(guidance_report)
142143

143144

144145
if __name__ == "__main__":

src/guidellm/utils/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@
33
PREFERRED_DATA_SPLITS,
44
STANDARD_SLEEP_INTERVAL,
55
)
6+
from .functions import is_file_name, is_directory_name
67

7-
__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS", "STANDARD_SLEEP_INTERVAL"]
8+
__all__ = [
9+
"PREFERRED_DATA_COLUMNS",
10+
"PREFERRED_DATA_SPLITS",
11+
"STANDARD_SLEEP_INTERVAL",
12+
"is_file_name",
13+
"is_directory_name",
14+
]

src/guidellm/utils/functions.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
3+
4+
__all__ = [
5+
"is_file_name",
6+
"is_directory_name",
7+
]
8+
9+
10+
def is_file_name(path: str) -> bool:
    """
    Determine whether a path string follows a file naming convention:
    it carries an extension and does not end with a path separator.

    :param path: The path to check.
    :type path: str
    :return: True if the path is a file naming convention.
    """
    extension = os.path.splitext(path)[1]
    ends_with_separator = path.endswith(os.path.sep)
    return bool(extension) and not ends_with_separator
22+
23+
24+
def is_directory_name(path: str) -> bool:
    """
    Determine whether a path string follows a directory naming convention:
    it has no extension or ends with a path separator.

    :param path: The path to check.
    :type path: str
    :return: True if the path is a directory naming convention.
    """
    ends_with_separator = path.endswith(os.path.sep)
    extension = os.path.splitext(path)[1]
    return ends_with_separator or not extension

tests/unit/core/test_report.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import pytest
2+
import os
3+
import tempfile
4+
from guidellm.core import (
5+
TextGenerationBenchmark,
6+
TextGenerationBenchmarkReport,
7+
TextGenerationResult,
8+
TextGenerationRequest,
9+
TextGenerationError,
10+
Distribution,
11+
GuidanceReport,
12+
)
13+
14+
15+
@pytest.fixture
def sample_benchmark_report() -> TextGenerationBenchmarkReport:
    """Build a one-benchmark report fixture holding one result and one error."""
    request = TextGenerationRequest(prompt="sample prompt")
    distribution = Distribution()
    result = TextGenerationResult(
        request=request,
        prompt="sample prompt",
        prompt_word_count=2,
        prompt_token_count=2,
        output="sample output",
        output_word_count=2,
        output_token_count=2,
        last_time=None,
        first_token_set=False,
        start_time=None,
        end_time=None,
        first_token_time=None,
        decode_times=distribution,
    )
    error = TextGenerationError(request=request, message="sample error")
    benchmark = TextGenerationBenchmark(
        mode="async",
        rate=1.0,
        results=[result],
        errors=[error],
        concurrencies=[],
    )
    return TextGenerationBenchmarkReport(
        benchmarks=[benchmark], args=[{"arg1": "value1"}]
    )
45+
46+
47+
def compare_guidance_reports(report1: GuidanceReport, report2: GuidanceReport) -> bool:
    """Return True when the two guidance reports compare equal."""
    are_equal = report1 == report2
    return are_equal
49+
50+
51+
@pytest.mark.smoke
def test_guidance_report_initialization():
    """A freshly constructed GuidanceReport starts with an empty benchmark list."""
    empty_report = GuidanceReport()
    assert empty_report.benchmarks == []
55+
56+
57+
@pytest.mark.smoke
def test_guidance_report_initialization_with_params(sample_benchmark_report):
    """GuidanceReport keeps the benchmarks it is constructed with."""
    populated_report = GuidanceReport(benchmarks=[sample_benchmark_report])
    assert populated_report.benchmarks == [sample_benchmark_report]
61+
62+
63+
@pytest.mark.smoke
def test_guidance_report_file(sample_benchmark_report):
    """Round-trip a report through save_file/load_file on disk."""
    original = GuidanceReport(benchmarks=[sample_benchmark_report])
    with tempfile.TemporaryDirectory() as workdir:
        target = os.path.join(workdir, "report.yaml")
        original.save_file(target)
        restored = GuidanceReport.load_file(target)
        assert compare_guidance_reports(original, restored)
71+
72+
73+
@pytest.mark.regression
def test_guidance_report_json(sample_benchmark_report):
    """Round-trip a report through its JSON string form."""
    original = GuidanceReport(benchmarks=[sample_benchmark_report])
    serialized = original.to_json()
    restored = GuidanceReport.from_json(serialized)
    assert compare_guidance_reports(original, restored)
79+
80+
81+
@pytest.mark.regression
def test_guidance_report_yaml(sample_benchmark_report):
    """Round-trip a report through its YAML string form."""
    original = GuidanceReport(benchmarks=[sample_benchmark_report])
    serialized = original.to_yaml()
    restored = GuidanceReport.from_yaml(serialized)
    assert compare_guidance_reports(original, restored)

0 commit comments

Comments
 (0)