|
import multiprocessing
from typing import List, Literal, Optional, Union

from pydantic import Field, model_validator, root_validator
from pydantic_settings import BaseSettings

import llama_cpp
|
@@ -67,12 +67,12 @@ class ModelSettings(BaseSettings):
|
67 | 67 | n_threads: int = Field(
|
68 | 68 | default=max(multiprocessing.cpu_count() // 2, 1),
|
69 | 69 | ge=1,
|
70 |
| - description="The number of threads to use.", |
| 70 | + description="The number of threads to use. Use -1 for max cpu threads", |
71 | 71 | )
|
72 | 72 | n_threads_batch: int = Field(
|
73 | 73 | default=max(multiprocessing.cpu_count(), 1),
|
74 | 74 | ge=0,
|
75 |
| - description="The number of threads to use when batch processing.", |
| 75 | + description="The number of threads to use when batch processing. Use -1 for max cpu threads", |
76 | 76 | )
|
77 | 77 | rope_scaling_type: int = Field(
|
78 | 78 | default=llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
|
@@ -173,6 +173,16 @@ class ModelSettings(BaseSettings):
|
173 | 173 | default=True, description="Whether to print debug information."
|
174 | 174 | )
|
175 | 175 |
|
| 176 | + @root_validator(pre=True) # pre=True to ensure this runs before any other validation |
| 177 | + def set_dynamic_defaults(cls, values): |
| 178 | + # If n_threads or n_threads_batch is -1, set it to multiprocessing.cpu_count() |
| 179 | + cpu_count = multiprocessing.cpu_count() |
| 180 | + if values.get('n_threads', 0) == -1: |
| 181 | + values['n_threads'] = cpu_count |
| 182 | + if values.get('n_threads_batch', 0) == -1: |
| 183 | + values['n_threads_batch'] = cpu_count |
| 184 | + return values |
| 185 | + |
176 | 186 |
|
177 | 187 | class ServerSettings(BaseSettings):
|
178 | 188 | """Server settings used to configure the FastAPI and Uvicorn server."""
|
|
0 commit comments