@@ -247,11 +247,14 @@ def __init__(
247247 dropout_rate : float = 0.0 ,
248248 classes : int = 1000 ,
249249 classifier_activation : str = "softmax" ,
250- weights : typing .Optional [str ] = None ,
250+ weights : typing .Optional [str ] = "imagenet" ,
251251 name : str = "ConvNeXtAtto" ,
252252 ** kwargs ,
253253 ):
254254 kwargs = self .fix_config (kwargs )
255+ if weights == "imagenet" :
256+ file_name = "convnextatto_convnext_atto.d2_in1k.keras"
257+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
255258 super ().__init__ (
256259 (2 , 2 , 6 , 2 ),
257260 (40 , 80 , 160 , 320 ),
@@ -284,11 +287,14 @@ def __init__(
284287 dropout_rate : float = 0.0 ,
285288 classes : int = 1000 ,
286289 classifier_activation : str = "softmax" ,
287- weights : typing .Optional [str ] = None ,
290+ weights : typing .Optional [str ] = "imagenet" ,
288291 name : str = "ConvNeXtFemto" ,
289292 ** kwargs ,
290293 ):
291294 kwargs = self .fix_config (kwargs )
295+ if weights == "imagenet" :
296+ file_name = "convnextfemto_convnext_femto.d1_in1k.keras"
297+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
292298 super ().__init__ (
293299 (2 , 2 , 6 , 2 ),
294300 (48 , 96 , 192 , 384 ),
@@ -321,11 +327,14 @@ def __init__(
321327 dropout_rate : float = 0.0 ,
322328 classes : int = 1000 ,
323329 classifier_activation : str = "softmax" ,
324- weights : typing .Optional [str ] = None ,
330+ weights : typing .Optional [str ] = "imagenet" ,
325331 name : str = "ConvNeXtPico" ,
326332 ** kwargs ,
327333 ):
328334 kwargs = self .fix_config (kwargs )
335+ if weights == "imagenet" :
336+ file_name = "convnextpico_convnext_pico.d1_in1k.keras"
337+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
329338 super ().__init__ (
330339 (2 , 2 , 6 , 2 ),
331340 (64 , 128 , 256 , 512 ),
@@ -358,11 +367,14 @@ def __init__(
358367 dropout_rate : float = 0.0 ,
359368 classes : int = 1000 ,
360369 classifier_activation : str = "softmax" ,
361- weights : typing .Optional [str ] = None ,
370+ weights : typing .Optional [str ] = "imagenet" ,
362371 name : str = "ConvNeXtNano" ,
363372 ** kwargs ,
364373 ):
365374 kwargs = self .fix_config (kwargs )
375+ if weights == "imagenet" :
376+ file_name = "convnextnano_convnext_nano.in12k_ft_in1k.keras"
377+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
366378 super ().__init__ (
367379 (2 , 2 , 8 , 2 ),
368380 (80 , 160 , 320 , 640 ),
@@ -395,11 +407,14 @@ def __init__(
395407 dropout_rate : float = 0.0 ,
396408 classes : int = 1000 ,
397409 classifier_activation : str = "softmax" ,
398- weights : typing .Optional [str ] = None ,
410+ weights : typing .Optional [str ] = "imagenet" ,
399411 name : str = "ConvNeXtTiny" ,
400412 ** kwargs ,
401413 ):
402414 kwargs = self .fix_config (kwargs )
415+ if weights == "imagenet" :
416+ file_name = "convnexttiny_convnext_tiny.in12k_ft_in1k.keras"
417+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
403418 super ().__init__ (
404419 (3 , 3 , 9 , 3 ),
405420 (96 , 192 , 384 , 768 ),
@@ -432,11 +447,14 @@ def __init__(
432447 dropout_rate : float = 0.0 ,
433448 classes : int = 1000 ,
434449 classifier_activation : str = "softmax" ,
435- weights : typing .Optional [str ] = None ,
450+ weights : typing .Optional [str ] = "imagenet" ,
436451 name : str = "ConvNeXtSmall" ,
437452 ** kwargs ,
438453 ):
439454 kwargs = self .fix_config (kwargs )
455+ if weights == "imagenet" :
456+ file_name = "convnextsmall_convnext_small.in12k_ft_in1k.keras"
457+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
440458 super ().__init__ (
441459 (3 , 3 , 27 , 3 ),
442460 (96 , 192 , 384 , 768 ),
@@ -469,11 +487,14 @@ def __init__(
469487 dropout_rate : float = 0.0 ,
470488 classes : int = 1000 ,
471489 classifier_activation : str = "softmax" ,
472- weights : typing .Optional [str ] = None ,
490+ weights : typing .Optional [str ] = "imagenet" ,
473491 name : str = "ConvNeXtBase" ,
474492 ** kwargs ,
475493 ):
476494 kwargs = self .fix_config (kwargs )
495+ if weights == "imagenet" :
496+ file_name = "convnextbase_convnext_base.fb_in22k_ft_in1k.keras"
497+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
477498 super ().__init__ (
478499 (3 , 3 , 27 , 3 ),
479500 (128 , 256 , 512 , 1024 ),
@@ -506,11 +527,14 @@ def __init__(
506527 dropout_rate : float = 0.0 ,
507528 classes : int = 1000 ,
508529 classifier_activation : str = "softmax" ,
509- weights : typing .Optional [str ] = None ,
530+ weights : typing .Optional [str ] = "imagenet" ,
510531 name : str = "ConvNeXtLarge" ,
511532 ** kwargs ,
512533 ):
513534 kwargs = self .fix_config (kwargs )
535+ if weights == "imagenet" :
536+ file_name = "convnextlarge_convnext_large.fb_in22k_ft_in1k.keras"
537+ kwargs ["weights_url" ] = f"{ self .default_origin } /{ file_name } "
514538 super ().__init__ (
515539 (3 , 3 , 27 , 3 ),
516540 (192 , 384 , 768 , 1536 ),
@@ -577,4 +601,4 @@ def __init__(
577601add_model_to_registry (ConvNeXtSmall , "imagenet" )
578602add_model_to_registry (ConvNeXtBase , "imagenet" )
579603add_model_to_registry (ConvNeXtLarge , "imagenet" )
580- add_model_to_registry (ConvNeXtXLarge , "imagenet" )
604+ add_model_to_registry (ConvNeXtXLarge )
0 commit comments