9
9
from abc import abstractmethod
10
10
from inspect import isclass
11
11
from typing import (
12
+ TYPE_CHECKING ,
12
13
Any ,
13
14
Awaitable ,
14
15
Callable ,
22
23
TypeVar ,
23
24
)
24
25
25
- from pydantic import Field , GetCoreSchemaHandler
26
+ from pydantic import (
27
+ ConfigDict ,
28
+ Field ,
29
+ GetCoreSchemaHandler ,
30
+ ValidationError ,
31
+ ValidationInfo ,
32
+ ValidatorFunctionWrapHandler ,
33
+ )
26
34
from pydantic_core import CoreSchema
27
- from pydantic_core .core_schema import tagged_union_schema
35
+ from pydantic_core .core_schema import tagged_union_schema , general_wrap_validator_function
28
36
29
37
from algobattle .program import (
30
38
Generator ,
@@ -242,20 +250,63 @@ def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandle
242
250
return handler (source )
243
251
except NameError :
244
252
return handler (source )
253
+
245
254
match len (Battle ._battle_types ):
246
255
case 0 :
247
- return handler (source )
256
+ subclass_schema = handler (source )
248
257
case 1 :
249
- return handler (next (iter (Battle ._battle_types .values ())))
258
+ subclass_schema = handler (next (iter (Battle ._battle_types .values ())))
250
259
case _:
251
- return tagged_union_schema (
260
+ subclass_schema = tagged_union_schema (
252
261
choices = {
253
262
battle .Config .model_fields ["type" ].default : battle .Config .__pydantic_core_schema__
254
263
for battle in Battle ._battle_types .values ()
255
264
},
256
265
discriminator = "type" ,
257
266
)
258
267
268
+ # we want to validate into the actual battle type's config, so we need to treat them as a tagged union
269
+ # but if we're initializing a project the type might not be installed yet, so we want to also parse
270
+ # into an unspecified dummy object. This wrap validator will efficiently and transparently act as a tagged
271
+ # union when ignore_uninstalled is not set. If it is set it catches only the error of a missing tag, other
272
+ # errors are passed through
273
+ def check_installed (val : object , handler : ValidatorFunctionWrapHandler , info : ValidationInfo ) -> object :
274
+ try :
275
+ return handler (val )
276
+ except ValidationError as e :
277
+ union_err = next (filter (lambda err : err ["type" ] == "union_tag_invalid" , e .errors ()), None )
278
+ if union_err is None :
279
+ raise
280
+ if info .context is not None and info .context .get ("ignore_uninstalled" , False ):
281
+ if info .config is not None :
282
+ settings : dict [str , Any ] = {
283
+ "strict" : info .config .get ("strict" , None ),
284
+ "from_attributes" : info .config .get ("from_attributes" ),
285
+ }
286
+ else :
287
+ settings = {}
288
+ return Battle .FallbackConfig .model_validate (val , context = info .context , ** settings )
289
+ else :
290
+ passed = union_err ["input" ]["type" ]
291
+ installed = ", " .join (b .name () for b in Battle ._battle_types .values ())
292
+ raise ValueError (
293
+ f"The specified battle type '{ passed } ' is not installed. Installed types are: { installed } "
294
+ )
295
+
296
+ return general_wrap_validator_function (check_installed , subclass_schema )
297
+
298
+ class FallbackConfig (Config ):
299
+ """Fallback config object to parse into if the proper battle typ isn't installed and we're ignoring installs."""
300
+
301
+ type : str
302
+
303
+ model_config = ConfigDict (extra = "allow" )
304
+
305
+ if TYPE_CHECKING :
306
+ # to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type
307
+ def __getattr__ (self , __attr : str ) -> Any :
308
+ ...
309
+
259
310
class UiData (BaseModel ):
260
311
"""Object containing custom diplay data.
261
312
@@ -280,11 +331,12 @@ def load_entrypoints(cls) -> None:
280
331
if not (isclass (battle ) and issubclass (battle , Battle )):
281
332
raise ValueError (f"Entrypoint { entrypoint .name } targets something other than a Battle type" )
282
333
283
- def __init_subclass__ (cls ) -> None :
334
+ @classmethod
335
+ def __pydantic_init_subclass__ (cls , ** kwargs : Any ) -> None :
284
336
if cls .name () not in Battle ._battle_types :
285
337
Battle ._battle_types [cls .name ()] = cls
286
338
Battle .Config .model_rebuild (force = True )
287
- return super ().__init_subclass__ ( )
339
+ return super ().__pydantic_init_subclass__ ( ** kwargs )
288
340
289
341
@abstractmethod
290
342
def score (self ) -> float :
@@ -367,10 +419,11 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
367
419
base_increment = 0
368
420
alive = True
369
421
reached = 0
422
+ self .results .append (0 )
370
423
cap = config .maximum_size
371
424
current = min_size
372
425
while alive :
373
- ui .update_battle_data (self .UiData (reached = self .results + [ reached ] , cap = cap ))
426
+ ui .update_battle_data (self .UiData (reached = self .results , cap = cap ))
374
427
result = await fight .run (current )
375
428
score = result .score
376
429
if score < config .minimum_score :
@@ -384,7 +437,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
384
437
alive = True
385
438
elif current > reached and alive :
386
439
# We solved an instance of bigger size than before
387
- reached = current
440
+ self . results [ - 1 ] = reached = current
388
441
389
442
if current + 1 > cap :
390
443
alive = False
@@ -396,7 +449,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
396
449
# We have failed at this value of n already, reset the step size!
397
450
current -= base_increment ** config .exponent - 1
398
451
base_increment = 1
399
- self .results . append ( reached )
452
+ self .results [ - 1 ] = reached
400
453
401
454
def score (self ) -> float :
402
455
"""Averages the highest instance size reached in each round."""
@@ -416,7 +469,7 @@ class Config(Battle.Config):
416
469
417
470
type : Literal ["Averaged" ] = "Averaged"
418
471
419
- instance_size : int = 10
472
+ instance_size : int = 25
420
473
"""Instance size that will be fought at."""
421
474
num_fights : int = 10
422
475
"""Number of iterations in each round."""
0 commit comments