1
- from typing import Any , override
1
+ from typing import Any , Dict , Set , override
2
2
3
- TomlSettingsType = dict [str , dict [str , Any ]]
3
+ TomlSettingsType = Dict [str , Dict [str , Any ]]
4
4
5
5
6
6
class TomlOptions :
@@ -12,6 +12,38 @@ def as_dict(self) -> TomlSettingsType:
12
12
return {f"[{ self .toml_name } ]" : self .__dict__ }
13
13
14
14
15
+ class PoetryDependenciesExtras (TomlOptions ):
16
+ """
17
+ Stores extra information for poetry dependencies for the `pyproject.toml` file.
18
+
19
+ Used for installing cuda with `torch` and `torchvision`.
20
+ """
21
+
22
+ toml_name = "tool.poetry.dependencies"
23
+
24
+ def __init__ (self ) -> None :
25
+ self .items = {
26
+ "torch" : {"source" : "pytorch" },
27
+ "torchvision" : {"source" : "pytorch" },
28
+ }
29
+
30
+ def items_to_str_dict (self ) -> Dict [str , Set [str ]]:
31
+ """Converts the values of the `self.items` dictionary into a strings."""
32
+ str_dict = {}
33
+ for key , d in self .items .items ():
34
+ value_str = []
35
+ for k , v in d .items ():
36
+ value_str .append (f'{ k } = "{ v } "' )
37
+
38
+ str_dict [key ] = {", " .join (value_str )}
39
+
40
+ return str_dict
41
+
42
+ @override
43
+ def as_dict (self ) -> TomlSettingsType :
44
+ return {f"[{ self .toml_name } ]" : self .items_to_str_dict ()}
45
+
46
+
15
47
class PyTorchSource (TomlOptions ):
16
48
"""
17
49
Stores the PyTorch Poetry source details for the `pyproject.toml` file.
@@ -105,7 +137,9 @@ def settings_to_toml(d: TomlSettingsType) -> str:
105
137
toml_str += f"{ header } \n "
106
138
107
139
for key , value in settings .items ():
108
- if isinstance (value , str ):
140
+ if isinstance (value , set ):
141
+ value = str (value ).replace ("'" , " " )
142
+ elif isinstance (value , str ):
109
143
value = f'"{ value } "'
110
144
elif isinstance (value , bool ):
111
145
value = str (value ).lower ()
@@ -125,14 +159,17 @@ def multi_settings_to_toml(d_list: list[TomlOptions]) -> str:
125
159
return full
126
160
127
161
128
- def set_toml_settings (project_name : str ) -> str :
162
+ def set_toml_settings (project_name : str , dl_project : bool ) -> str :
129
163
"""Sets the extra toml settings to add to the end of the `pyproject.toml` file."""
130
164
items = [
131
- PyTorchSource (),
132
165
PytestOptions (project_name ),
133
166
MypyOptions (),
134
167
IsortOptions (),
135
168
BlackOptions (),
136
169
]
137
170
171
+ if dl_project :
172
+ items .insert (0 , PoetryDependenciesExtras ())
173
+ items .insert (1 , PyTorchSource ())
174
+
138
175
return multi_settings_to_toml (items )
0 commit comments