Skip to content

Commit 7830009

Browse files
committed
formatting
1 parent af77ece commit 7830009

File tree

5 files changed

+61
-32
lines changed

5 files changed

+61
-32
lines changed

README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ y = torch.randn(100, requires_grad=True, device='cuda')
5555
plt.plot(x, y, '.') # easy and simple
5656
```
5757

58+
## Requirements
59+
Tested using `torch>=1.6` and `matplotlib>=3.3.3` but should perfectly work with
60+
both earlier and later versions.
61+
62+
## Licence
63+
64+
Please observe the Apache 2.0 license that is listed in this repository.
65+
66+
## BibTeX
67+
If you want to cite the framework feel free to use this (but only if you loved it 😊):
68+
69+
```bibtex
70+
@article{detlefsen2021torchplot,
71+
title={TorchPlot},
72+
author={Detlefsen, Nicki S.},
73+
journal={GitHub. Note: https://github.com/CenterBioML/torchplot},
74+
year={2021}
75+
}
76+
```
77+
5878

5979

6080

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
torch>=1.3
2-
numpy>=1.16.4
2+
matplotlib>=3.3.3

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,14 @@ def load_readme(path_dir=PATH_ROOT):
5656
version=torchplot.__version__,
5757
description=torchplot.__docs__,
5858
long_description=load_readme(PATH_ROOT),
59+
long_description_content_type="text/markdown",
5960
author=torchplot.__author__,
6061
author_email=torchplot.__author_email__,
6162
license=torchplot.__license__,
6263
packages=find_packages(exclude=["tests", "tests/*"]),
6364
python_requires=">=3.8",
64-
install_requires=['torch>=1.3', 'matplotlib>=3.3.3'],
65-
download_url="https://github.com/CenterBioML/torchplot/archive/0.1.0.zip",
65+
install_requires=["torch>=1.6", "matplotlib>=3.3.3"],
66+
download_url="https://github.com/CenterBioML/torchplot/archive/0.1.1.zip",
6667
classifiers=[
6768
"Environment :: Console",
6869
"Natural Language :: English",

tests/test_torchplot.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,36 @@
2424

2525
Inputs = namedtuple("case", ["x", "y"])
2626

27-
_cpu_cases = [Inputs(x=torch.randn(100,), y=torch.randn(100,)),
28-
Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100,requires_grad=True)),
29-
# test that list/numpy arrays still works
30-
Inputs(x=[1,2,3,4], y=[1,2,3,4]),
31-
Inputs(x=np.random.randn(100,), y=np.random.randn(100,)),
32-
# test that we can mix
33-
Inputs(x=torch.randn(100,), y=torch.randn(100, requires_grad=True)),
34-
Inputs(x=np.random.randn(100,), y=torch.randn(100, requires_grad=True)),
35-
Inputs(x=torch.randn(5,), y=[1,2,3,4,5]),
36-
]
27+
_cpu_cases = [
28+
Inputs(x=torch.randn(100), y=torch.randn(100)),
29+
Inputs(x=torch.randn(100, requires_grad=True), y=torch.randn(100, requires_grad=True)),
30+
# test that list/numpy arrays still works
31+
Inputs(x=[1, 2, 3, 4], y=[1, 2, 3, 4]),
32+
Inputs(x=np.random.randn(100), y=np.random.randn(100)),
33+
# test that we can mix
34+
Inputs(x=torch.randn(100), y=torch.randn(100, requires_grad=True)),
35+
Inputs(x=np.random.randn(100), y=torch.randn(100, requires_grad=True)),
36+
Inputs(x=torch.randn(5), y=[1, 2, 3, 4, 5]),
37+
]
3738

38-
_gpu_cases = [Inputs(x=torch.randn(100, device='cuda'), y=torch.randn(100, device='cuda')),
39-
Inputs(x=torch.randn(100,requires_grad=True, device='cuda'), y=torch.randn(100,requires_grad=True, device='cuda')),
40-
]
39+
_gpu_cases = [
40+
Inputs(x=torch.randn(100, device="cuda"), y=torch.randn(100, device="cuda")),
41+
Inputs(
42+
x=torch.randn(100, requires_grad=True, device="cuda"), y=torch.randn(100, requires_grad=True, device="cuda")
43+
),
44+
]
4145

4246

43-
44-
_members_to_check = [name for name, member in getmembers(plt)
45-
if isfunction(member) and not name.startswith('_')]
47+
_members_to_check = [name for name, member in getmembers(plt) if isfunction(member) and not name.startswith("_")]
4648

4749

4850
def string_compare(text1, text2):
4951
if text1 is None and text2 is None:
5052
return True
5153
remove = string.punctuation + string.whitespace
52-
return text1.translate(str.maketrans(dict.fromkeys(remove))) == text2.translate(str.maketrans(dict.fromkeys(remove)))
54+
return text1.translate(str.maketrans(dict.fromkeys(remove))) == text2.translate(
55+
str.maketrans(dict.fromkeys(remove))
56+
)
5357

5458

5559
@pytest.mark.parametrize("member", _members_to_check)
@@ -59,15 +63,14 @@ def test_members(member):
5963
assert member in dir(tp)
6064

6165

62-
@pytest.mark.parametrize('test_case', _cpu_cases)
66+
@pytest.mark.parametrize("test_case", _cpu_cases)
6367
def test_cpu(test_case):
6468
""" test that it works on cpu """
65-
assert tp.plot(test_case.x, test_case.y, '.')
69+
assert tp.plot(test_case.x, test_case.y, ".")
6670

6771

68-
@pytest.mark.skipif(not torch.cuda.is_available(), reason='test requires cuda')
69-
@pytest.mark.parametrize('test_case', _gpu_cases)
72+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda")
73+
@pytest.mark.parametrize("test_case", _gpu_cases)
7074
def test_gpu(test_case):
7175
""" test that it works on gpu """
72-
assert tp.plot(test_case.x, test_case.y, '.')
73-
76+
assert tp.plot(test_case.x, test_case.y, ".")

torchplot/core.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,15 +31,20 @@ def convert(arg):
3131

3232
return outargs, kwargs
3333

34+
3435
# Iterate over all members of 'plt' in order to duplicate them
3536
for name, member in getmembers(plt):
3637
if isfunction(member):
3738
doc = getdoc(member)
3839
strdoc = "" if doc is None else doc
39-
exec(('def {name}(*args, **kwargs):\n' +
40-
'\t"""{doc}"""\n' +
41-
'\tnew_args, new_kwargs = _torch2np(*args, **kwargs)\n' +
42-
'\treturn plt.{name}(*new_args, **new_kwargs)').format(name=name, doc=strdoc))
40+
exec(
41+
(
42+
"def {name}(*args, **kwargs):\n"
43+
+ '\t"""{doc}"""\n'
44+
+ "\tnew_args, new_kwargs = _torch2np(*args, **kwargs)\n"
45+
+ "\treturn plt.{name}(*new_args, **new_kwargs)"
46+
).format(name=name, doc=strdoc)
47+
)
4348
else:
44-
exec('{name} = plt.{name}'.format(name=name))
45-
#break
49+
exec("{name} = plt.{name}".format(name=name))
50+
# break

0 commit comments

Comments
 (0)