Skip to content

Commit 8bd2fa4

Browse files
committed
fix(dl): Fixed dl dependencies.
- Added source information for PyTorch dependencies. Now installs CUDA with PyTorch. - Added ignore logic to stop PyTorch toml info appearing in non-dl projects.
1 parent 6f75290 commit 8bd2fa4

File tree

3 files changed

+54
-6
lines changed

3 files changed

+54
-6
lines changed

zenforge/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,14 @@ class ProjectType(StrEnum):
2020
API_DEEP_LEARNING = "api-dl"
2121
ALL = "all"
2222

23+
@staticmethod
24+
def dl_project(value: str) -> bool:
25+
"""Checks if the value is a Deep Learning project."""
26+
if "dl" in value or value == "all":
27+
return True
28+
29+
return False
30+
2331

2432
__all__ = [
2533
"ProjectType",

zenforge/config/toml.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Any, override
1+
from typing import Any, Dict, Set, override
22

3-
TomlSettingsType = dict[str, dict[str, Any]]
3+
TomlSettingsType = Dict[str, Dict[str, Any]]
44

55

66
class TomlOptions:
@@ -12,6 +12,38 @@ def as_dict(self) -> TomlSettingsType:
1212
return {f"[{self.toml_name}]": self.__dict__}
1313

1414

15+
class PoetryDependenciesExtras(TomlOptions):
16+
"""
17+
Stores extra information for poetry dependencies for the `pyproject.toml` file.
18+
19+
Used for installing cuda with `torch` and `torchvision`.
20+
"""
21+
22+
toml_name = "tool.poetry.dependencies"
23+
24+
def __init__(self) -> None:
25+
self.items = {
26+
"torch": {"source": "pytorch"},
27+
"torchvision": {"source": "pytorch"},
28+
}
29+
30+
def items_to_str_dict(self) -> Dict[str, Set[str]]:
31+
"""Converts the values of the `self.items` dictionary into a strings."""
32+
str_dict = {}
33+
for key, d in self.items.items():
34+
value_str = []
35+
for k, v in d.items():
36+
value_str.append(f'{k} = "{v}"')
37+
38+
str_dict[key] = {", ".join(value_str)}
39+
40+
return str_dict
41+
42+
@override
43+
def as_dict(self) -> TomlSettingsType:
44+
return {f"[{self.toml_name}]": self.items_to_str_dict()}
45+
46+
1547
class PyTorchSource(TomlOptions):
1648
"""
1749
Stores the PyTorch Poetry source details for the `pyproject.toml` file.
@@ -105,7 +137,9 @@ def settings_to_toml(d: TomlSettingsType) -> str:
105137
toml_str += f"{header}\n"
106138

107139
for key, value in settings.items():
108-
if isinstance(value, str):
140+
if isinstance(value, set):
141+
value = str(value).replace("'", " ")
142+
elif isinstance(value, str):
109143
value = f'"{value}"'
110144
elif isinstance(value, bool):
111145
value = str(value).lower()
@@ -125,14 +159,17 @@ def multi_settings_to_toml(d_list: list[TomlOptions]) -> str:
125159
return full
126160

127161

128-
def set_toml_settings(project_name: str) -> str:
162+
def set_toml_settings(project_name: str, dl_project: bool) -> str:
129163
"""Sets the extra toml settings to add to the end of the `pyproject.toml` file."""
130164
items = [
131-
PyTorchSource(),
132165
PytestOptions(project_name),
133166
MypyOptions(),
134167
IsortOptions(),
135168
BlackOptions(),
136169
]
137170

171+
if dl_project:
172+
items.insert(0, PoetryDependenciesExtras())
173+
items.insert(1, PyTorchSource())
174+
138175
return multi_settings_to_toml(items)

zenforge/create/method.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,10 @@ def __init__(
108108
self.path = path
109109
self.deps = deps
110110

111-
self.toml_extras = set_toml_settings(project_name)
111+
self.toml_extras = set_toml_settings(
112+
project_name,
113+
ProjectType.dl_project(project_type),
114+
)
112115

113116
self.env_name = "venv"
114117
self.python_path = (

0 commit comments

Comments
 (0)