Skip to content

Commit 407b339

Browse files
SSYernarfacebook-github-bot
authored andcommitted
Support for default_factory in cmd_conf (pytorch#3167)
Summary: Pull Request resolved: pytorch#3167 The cmd_conf decorator now properly handles dataclass fields with default_factory by calling the factory function to get the actual default value. This ensures dataclass fields recieve their intended default values when using argparse instead of missing type objects. Reviewed By: aliafzal Differential Revision: D77896269 fbshipit-source-id: 446ae84b1728d1786031c62527bf76bbf22059dc
1 parent 9e96a37 commit 407b339

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torchrec/distributed/benchmark/benchmark_utils.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import os
2121
import time
2222
import timeit
23-
from dataclasses import dataclass, fields, is_dataclass
23+
from dataclasses import dataclass, fields, is_dataclass, MISSING
2424
from enum import Enum
2525
from typing import (
2626
Any,
@@ -500,8 +500,15 @@ def wrapper() -> Any:
500500
ftype = non_none[0]
501501
origin = get_origin(ftype)
502502

503+
# Handle default_factory value
504+
default_value = (
505+
f.default_factory() # pyre-ignore [29]
506+
if f.default_factory is not MISSING
507+
else f.default
508+
)
509+
503510
arg_kwargs = {
504-
"default": f.default,
511+
"default": default_value,
505512
"help": f"({cls.__name__}) {arg_name}",
506513
}
507514

0 commit comments

Comments
 (0)