Skip to content

Commit baf89ba

Browse files
authored
Merge pull request #122 from ImogenBits/util_scripts
Modernize command line interaction
2 parents cf6ebb4 + 7a3a901 commit baf89ba

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+3357
-2052
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,5 @@ dist
88
site
99
docs/src/pairsum_solver/target
1010
docs/src/pairsum_solver/Cargo.lock
11+
.results
12+
.project

algobattle.ps1

Lines changed: 0 additions & 61 deletions
This file was deleted.

algobattle/battle.py

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from abc import abstractmethod
1010
from inspect import isclass
1111
from typing import (
12+
TYPE_CHECKING,
1213
Any,
1314
Awaitable,
1415
Callable,
@@ -22,9 +23,16 @@
2223
TypeVar,
2324
)
2425

25-
from pydantic import Field, GetCoreSchemaHandler
26+
from pydantic import (
27+
ConfigDict,
28+
Field,
29+
GetCoreSchemaHandler,
30+
ValidationError,
31+
ValidationInfo,
32+
ValidatorFunctionWrapHandler,
33+
)
2634
from pydantic_core import CoreSchema
27-
from pydantic_core.core_schema import tagged_union_schema
35+
from pydantic_core.core_schema import tagged_union_schema, general_wrap_validator_function
2836

2937
from algobattle.program import (
3038
Generator,
@@ -242,20 +250,63 @@ def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandle
242250
return handler(source)
243251
except NameError:
244252
return handler(source)
253+
245254
match len(Battle._battle_types):
246255
case 0:
247-
return handler(source)
256+
subclass_schema = handler(source)
248257
case 1:
249-
return handler(next(iter(Battle._battle_types.values())))
258+
subclass_schema = handler(next(iter(Battle._battle_types.values())))
250259
case _:
251-
return tagged_union_schema(
260+
subclass_schema = tagged_union_schema(
252261
choices={
253262
battle.Config.model_fields["type"].default: battle.Config.__pydantic_core_schema__
254263
for battle in Battle._battle_types.values()
255264
},
256265
discriminator="type",
257266
)
258267

268+
# we want to validate into the actual battle type's config, so we need to treat them as a tagged union
269+
# but if we're initializing a project the type might not be installed yet, so we want to also parse
270+
# into an unspecified dummy object. This wrap validator will efficiently and transparently act as a tagged
271+
# union when ignore_uninstalled is not set. If it is set it catches only the error of a missing tag, other
272+
# errors are passed through
273+
def check_installed(val: object, handler: ValidatorFunctionWrapHandler, info: ValidationInfo) -> object:
274+
try:
275+
return handler(val)
276+
except ValidationError as e:
277+
union_err = next(filter(lambda err: err["type"] == "union_tag_invalid", e.errors()), None)
278+
if union_err is None:
279+
raise
280+
if info.context is not None and info.context.get("ignore_uninstalled", False):
281+
if info.config is not None:
282+
settings: dict[str, Any] = {
283+
"strict": info.config.get("strict", None),
284+
"from_attributes": info.config.get("from_attributes"),
285+
}
286+
else:
287+
settings = {}
288+
return Battle.FallbackConfig.model_validate(val, context=info.context, **settings)
289+
else:
290+
passed = union_err["input"]["type"]
291+
installed = ", ".join(b.name() for b in Battle._battle_types.values())
292+
raise ValueError(
293+
f"The specified battle type '{passed}' is not installed. Installed types are: {installed}"
294+
)
295+
296+
return general_wrap_validator_function(check_installed, subclass_schema)
297+
298+
class FallbackConfig(Config):
299+
"""Fallback config object to parse into if the proper battle typ isn't installed and we're ignoring installs."""
300+
301+
type: str
302+
303+
model_config = ConfigDict(extra="allow")
304+
305+
if TYPE_CHECKING:
306+
# to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type
307+
def __getattr__(self, __attr: str) -> Any:
308+
...
309+
259310
class UiData(BaseModel):
260311
"""Object containing custom diplay data.
261312
@@ -280,11 +331,12 @@ def load_entrypoints(cls) -> None:
280331
if not (isclass(battle) and issubclass(battle, Battle)):
281332
raise ValueError(f"Entrypoint {entrypoint.name} targets something other than a Battle type")
282333

283-
def __init_subclass__(cls) -> None:
334+
@classmethod
335+
def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
284336
if cls.name() not in Battle._battle_types:
285337
Battle._battle_types[cls.name()] = cls
286338
Battle.Config.model_rebuild(force=True)
287-
return super().__init_subclass__()
339+
return super().__pydantic_init_subclass__(**kwargs)
288340

289341
@abstractmethod
290342
def score(self) -> float:
@@ -367,10 +419,11 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
367419
base_increment = 0
368420
alive = True
369421
reached = 0
422+
self.results.append(0)
370423
cap = config.maximum_size
371424
current = min_size
372425
while alive:
373-
ui.update_battle_data(self.UiData(reached=self.results + [reached], cap=cap))
426+
ui.update_battle_data(self.UiData(reached=self.results, cap=cap))
374427
result = await fight.run(current)
375428
score = result.score
376429
if score < config.minimum_score:
@@ -384,7 +437,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
384437
alive = True
385438
elif current > reached and alive:
386439
# We solved an instance of bigger size than before
387-
reached = current
440+
self.results[-1] = reached = current
388441

389442
if current + 1 > cap:
390443
alive = False
@@ -396,7 +449,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
396449
# We have failed at this value of n already, reset the step size!
397450
current -= base_increment**config.exponent - 1
398451
base_increment = 1
399-
self.results.append(reached)
452+
self.results[-1] = reached
400453

401454
def score(self) -> float:
402455
"""Averages the highest instance size reached in each round."""
@@ -416,7 +469,7 @@ class Config(Battle.Config):
416469

417470
type: Literal["Averaged"] = "Averaged"
418471

419-
instance_size: int = 10
472+
instance_size: int = 25
420473
"""Instance size that will be fought at."""
421474
num_fights: int = 10
422475
"""Number of iterations in each round."""

0 commit comments

Comments
 (0)