Skip to content

Commit 4deefd6

Browse files
committed
extend test_eq_type
previously this would pass tests when comparting tensors with different dtypes example: `test_eq_type(torch.zeros(10), torch.zeros(10, dtype=torch.float64))` would pass now it does not
1 parent 86337ba commit 4deefd6

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

fastcore/_modidx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,4 +594,4 @@
594594
'fastcore.xtras.truncstr': ('xtras.html#truncstr', 'fastcore/xtras.py'),
595595
'fastcore.xtras.untar_dir': ('xtras.html#untar_dir', 'fastcore/xtras.py'),
596596
'fastcore.xtras.utc2local': ('xtras.html#utc2local', 'fastcore/xtras.py'),
597-
'fastcore.xtras.walk': ('xtras.html#walk', 'fastcore/xtras.py')}}}
597+
'fastcore.xtras.walk': ('xtras.html#walk', 'fastcore/xtras.py')}}}

fastcore/test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def test_eq_type(a,b):
4141
"`test` that `a==b` and are same type"
4242
test_eq(a,b)
4343
test_eq(type(a),type(b))
44-
if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b))
44+
if isinstance(a,(list,tuple)): test_eq(map(type,a),map(type,b)) # type of each element
45+
if isinstance(a, (torch.Tensor, pd.Series, np.ndarray)): test_eq(a.dtype, b.dtype) # dtypes of both tensors
4546

4647
# %% ../nbs/00_test.ipynb 27
4748
def test_ne(a,b):

nbs/00_test.ipynb

Lines changed: 9 additions & 5 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)