Skip to content

Commit 774d175

Browse files
committed
Remove uint16, uint32, and uint64 from the pytorch dtypes() output
1 parent be4fa68 commit 774d175

File tree

1 file changed

+4
-19
lines changed

1 file changed

+4
-19
lines changed

array_api_compat/torch/_info.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,10 @@ def _dtypes(self, kind):
164164
int16 = torch.int16
165165
int32 = torch.int32
166166
int64 = torch.int64
167-
uint8 = getattr(torch, "uint8", None)
168-
uint16 = getattr(torch, "uint16", None)
169-
uint32 = getattr(torch, "uint32", None)
170-
uint64 = getattr(torch, "uint64", None)
167+
uint8 = torch.uint8
168+
# uint16, uint32, and uint64 are present in newer versions of pytorch,
169+
# but they aren't generally supported by the array API functions, so
170+
# we omit them from this function.
171171
float32 = torch.float32
172172
float64 = torch.float64
173173
complex64 = torch.complex64
@@ -181,9 +181,6 @@ def _dtypes(self, kind):
181181
"int32": int32,
182182
"int64": int64,
183183
"uint8": uint8,
184-
"uint16": uint16,
185-
"uint32": uint32,
186-
"uint64": uint64,
187184
"float32": float32,
188185
"float64": float64,
189186
"complex64": complex64,
@@ -201,9 +198,6 @@ def _dtypes(self, kind):
201198
if kind == "unsigned integer":
202199
return {
203200
"uint8": uint8,
204-
"uint16": uint16,
205-
"uint32": uint32,
206-
"uint64": uint64,
207201
}
208202
if kind == "integral":
209203
return {
@@ -212,9 +206,6 @@ def _dtypes(self, kind):
212206
"int32": int32,
213207
"int64": int64,
214208
"uint8": uint8,
215-
"uint16": uint16,
216-
"uint32": uint32,
217-
"uint64": uint64,
218209
}
219210
if kind == "real floating":
220211
return {
@@ -233,9 +224,6 @@ def _dtypes(self, kind):
233224
"int32": int32,
234225
"int64": int64,
235226
"uint8": uint8,
236-
"uint16": uint16,
237-
"uint32": uint32,
238-
"uint64": uint64,
239227
"float32": float32,
240228
"float64": float64,
241229
"complex64": complex64,
@@ -305,9 +293,6 @@ def dtypes(self, *, device=None, kind=None):
305293
"""
306294
res = self._dtypes(kind)
307295
for k, v in res.copy().items():
308-
if v is None:
309-
del res[k]
310-
continue
311296
try:
312297
torch.empty((0,), dtype=v, device=device)
313298
except:

0 commit comments

Comments
 (0)