Skip to content

Commit 9cf4605

Browse files
authored
Merge pull request #4 from opendilab/release/0.2.1
release(hansbug): use version 0.2.1
2 parents fe5f681 + 716b4b9 commit 9cf4605

File tree

12 files changed

+200
-27
lines changed

12 files changed

+200
-27
lines changed

.github/workflows/badge.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ name: Badge Creation
33
on:
44
push:
55
branches: [ main, 'badge/*', 'doc/*' ]
6-
pull_request:
7-
branches: [ main, 'badge/*', 'doc/*' ]
86

97
jobs:
108
update-badges:

.github/workflows/test.yml

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ name: Code Test
22

33
on:
44
- push
5-
- pull_request
65

76
jobs:
87
unittest:
@@ -18,27 +17,31 @@ jobs:
1817
- '3.7'
1918
- '3.8'
2019
- '3.9'
20+
numpy-version:
21+
- '1.18.0'
22+
- '1.20.0'
23+
- '1.22.0'
2124
torch-version:
25+
- '1.2.0'
2226
- '1.4.0'
23-
- '1.5.0'
2427
- '1.6.0'
25-
- '1.7.0'
2628
- '1.8.0'
27-
- '1.9.0'
2829
- '1.10.0'
2930
exclude:
30-
- os: 'ubuntu-18.04'
31-
python-version: '3.9'
31+
- python-version: '3.6'
32+
numpy-version: '1.20.0'
33+
- python-version: '3.6'
34+
numpy-version: '1.22.0'
35+
- python-version: '3.7'
36+
numpy-version: '1.22.0'
37+
- python-version: '3.8'
38+
torch-version: '1.2.0'
39+
- python-version: '3.9'
40+
torch-version: '1.2.0'
41+
- python-version: '3.9'
3242
torch-version: '1.4.0'
33-
- os: 'ubuntu-18.04'
34-
python-version: '3.9'
35-
torch-version: '1.5.0'
36-
- os: 'ubuntu-18.04'
37-
python-version: '3.9'
43+
- python-version: '3.9'
3844
torch-version: '1.6.0'
39-
- os: 'ubuntu-18.04'
40-
python-version: '3.9'
41-
torch-version: '1.7.0'
4245

4346
steps:
4447
- name: Checkout code
@@ -60,6 +63,14 @@ jobs:
6063
run: |
6164
python -m pip install --upgrade pip
6265
pip install --upgrade flake8 setuptools wheel twine
66+
- name: Install latest numpy
67+
if: ${{ matrix.numpy-version == 'latest' }}
68+
run: |
69+
pip install 'numpy'
70+
- name: Install numpy v${{ matrix.numpy-version }}
71+
if: ${{ matrix.numpy-version != 'latest' }}
72+
run: |
73+
pip install 'numpy==${{ matrix.numpy-version }}'
6374
- name: Install latest pytorch
6475
if: ${{ matrix.torch-version == 'latest' }}
6576
run: |

docs/source/api_doc/numpy/funcs.rst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def _doc_process(doc: str) -> str:
7373
print_title(f"Description From Numpy v{_short_version}", levelc='-', file=p_func)
7474
current_module(np.__name__, file=p_func)
7575

76-
_origin_doc = _doc_process(_origin.__doc__ or "")
76+
_origin_doc = _doc_process(_origin.__doc__ or "").lstrip()
7777
_doc_lines = _origin_doc.splitlines()
7878
_first_line, _other_lines = _doc_lines[0], _doc_lines[1:]
7979
if _first_line.strip():

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
treevalue>=1.2.0
1+
treevalue>=1.3.0
22
torch>=1.1.0,<=1.10.0
33
hbutils>=0.0.1
44
numpy

test/numpy/test_array.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import numpy as np
22
import pytest
3+
import torch
34

45
import treetensor.numpy as tnp
6+
import treetensor.torch as ttorch
57
from treetensor.common import Object
68

79

@@ -233,3 +235,22 @@ def test_tolist(self):
233235
'd': [0, 0, 0.0],
234236
}
235237
})
238+
239+
def test_tensor(self):
240+
assert ttorch.isclose(self._DEMO_1.tensor().double(), ttorch.Tensor({
241+
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]]),
242+
'b': ttorch.Tensor([1, 3, 5, 7]),
243+
'x': {
244+
'c': ttorch.Tensor([[11], [23]]),
245+
'd': ttorch.Tensor([3, 9, 11.0])
246+
}
247+
}).double()).all()
248+
249+
assert (self._DEMO_1.tensor(dtype=torch.float64) == ttorch.Tensor({
250+
'a': ttorch.Tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float64),
251+
'b': ttorch.Tensor([1, 3, 5, 7], dtype=torch.float64),
252+
'x': {
253+
'c': ttorch.Tensor([[11], [23]], dtype=torch.float64),
254+
'd': ttorch.Tensor([3, 9, 11.0], dtype=torch.float64),
255+
}
256+
})).all()

test/numpy/test_funcs.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,97 @@ def test_array_equal(self):
127127
'd': True,
128128
}
129129
})
130+
131+
def test_zeros(self):
132+
zs = tnp.zeros((2, 3))
133+
assert isinstance(zs, np.ndarray)
134+
assert np.allclose(zs, np.zeros((2, 3)))
135+
136+
zs = tnp.zeros({'a': (2, 3), 'c': {'x': (3, 4)}})
137+
assert tnp.allclose(zs, tnp.ndarray({
138+
'a': np.zeros((2, 3)),
139+
'c': {'x': np.zeros((3, 4))}
140+
}))
141+
142+
def test_ones(self):
143+
zs = tnp.ones((2, 3))
144+
assert isinstance(zs, np.ndarray)
145+
assert np.allclose(zs, np.ones((2, 3)))
146+
147+
zs = tnp.ones({'a': (2, 3), 'c': {'x': (3, 4)}})
148+
assert tnp.allclose(zs, tnp.ndarray({
149+
'a': np.ones((2, 3)),
150+
'c': {'x': np.zeros((3, 4))}
151+
}))
152+
153+
def test_stack(self):
154+
a = np.array([1, 2, 3])
155+
b = np.array([2, 3, 4])
156+
nd = tnp.stack((a, b))
157+
assert isinstance(nd, np.ndarray)
158+
assert np.allclose(nd, np.array([[1, 2, 3],
159+
[2, 3, 4]]))
160+
161+
a = tnp.array({
162+
'a': [1, 2, 3],
163+
'c': {'x': [11, 22, 33]},
164+
})
165+
b = tnp.array({
166+
'a': [2, 3, 4],
167+
'c': {'x': [22, 33, 44]},
168+
})
169+
nd = tnp.stack((a, b))
170+
assert tnp.allclose(nd, tnp.array({
171+
'a': [[1, 2, 3], [2, 3, 4]],
172+
'c': {'x': [[11, 22, 33], [22, 33, 44]]},
173+
}))
174+
175+
def test_concatenate(self):
176+
a = np.array([[1, 2], [3, 4]])
177+
b = np.array([[5, 6]])
178+
nd = tnp.concatenate((a, b), axis=0)
179+
assert isinstance(nd, np.ndarray)
180+
assert np.allclose(nd, np.array([[1, 2],
181+
[3, 4],
182+
[5, 6]]))
183+
184+
a = tnp.array({
185+
'a': [[1, 2], [3, 4]],
186+
'c': {'x': [[11, 22], [33, 44]]},
187+
})
188+
b = tnp.array({
189+
'a': [[5, 6]],
190+
'c': {'x': [[55, 66]]},
191+
})
192+
nd = tnp.concatenate((a, b), axis=0)
193+
assert tnp.allclose(nd, tnp.array({
194+
'a': [[1, 2], [3, 4], [5, 6]],
195+
'c': {'x': [[11, 22], [33, 44], [55, 66]]},
196+
}))
197+
198+
def test_split(self):
199+
x = np.arange(9.0)
200+
ns = tnp.split(x, 3)
201+
assert len(ns) == 3
202+
assert isinstance(ns[0], np.ndarray)
203+
assert np.allclose(ns[0], np.array([0.0, 1.0, 2.0]))
204+
assert isinstance(ns[1], np.ndarray)
205+
assert np.allclose(ns[1], np.array([3.0, 4.0, 5.0]))
206+
assert isinstance(ns[2], np.ndarray)
207+
assert np.allclose(ns[2], np.array([6.0, 7.0, 8.0]))
208+
209+
xx = tnp.arange(tnp.ndarray({'a': 9.0, 'c': {'x': 18.0}}))
210+
ns = tnp.split(xx, 3)
211+
assert len(ns) == 3
212+
assert tnp.allclose(ns[0], tnp.array({
213+
'a': [0.0, 1.0, 2.0],
214+
'c': {'x': [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]},
215+
}))
216+
assert tnp.allclose(ns[1], tnp.array({
217+
'a': [3.0, 4.0, 5.0],
218+
'c': {'x': [6.0, 7.0, 8.0, 9.0, 10.0, 11.0]},
219+
}))
220+
assert tnp.allclose(ns[2], tnp.array({
221+
'a': [6.0, 7.0, 8.0],
222+
'c': {'x': [12.0, 13.0, 14.0, 15.0, 16.0, 17.0]},
223+
}))

test/torch/tensor/test_reduction.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import treetensor.torch as ttorch
55
from .base import choose_mark
66

7+
bool_init_dtype = torch.tensor([True, False]).dtype
8+
79

810
# noinspection DuplicatedCode,PyUnresolvedReferences
911
class TestTorchTensorReduction:
@@ -14,15 +16,15 @@ def test_all(self):
1416
'b': {'x': [[True, True, ], [True, True, ]]}
1517
}).all()
1618
assert isinstance(t1, torch.Tensor)
17-
assert t1.dtype == torch.bool
19+
assert t1.dtype == bool_init_dtype
1820
assert t1
1921

2022
t2 = ttorch.Tensor({
2123
'a': [True, False],
2224
'b': {'x': [[True, True, ], [True, True, ]]}
2325
}).all()
2426
assert isinstance(t2, torch.Tensor)
25-
assert t2.dtype == torch.bool
27+
assert t2.dtype == bool_init_dtype
2628
assert not t2
2729

2830
t3 = ttorch.tensor({
@@ -48,15 +50,15 @@ def test_any(self):
4850
'b': {'x': [[False, False, ], [False, False, ]]}
4951
}).any()
5052
assert isinstance(t1, torch.Tensor)
51-
assert t1.dtype == torch.bool
53+
assert t1.dtype == bool_init_dtype
5254
assert t1
5355

5456
t2 = ttorch.Tensor({
5557
'a': [False, False],
5658
'b': {'x': [[False, False, ], [False, False, ]]}
5759
}).any()
5860
assert isinstance(t2, torch.Tensor)
59-
assert t2.dtype == torch.bool
61+
assert t2.dtype == bool_init_dtype
6062
assert not t2
6163

6264
t3 = ttorch.Tensor({

treetensor/common/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _load_func(name):
3030
@doc_from_base()
3131
@return_self_dec
3232
@post_process(auto_tree_cls)
33-
@func_treelize(return_type=TreeValue, rise=True)
33+
@func_treelize(return_type=TreeValue, subside=True, rise=True)
3434
@wraps(func, assigned=('__name__',), updated=())
3535
def _new_func(*args, **kwargs):
3636
return func(*args, **kwargs)

treetensor/common/wrappers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from functools import wraps
2-
from operator import itemgetter
32

4-
from treevalue import TreeValue, walk
3+
from treevalue import TreeValue, flatten_values
54

65
__all__ = [
76
'ireduce',
@@ -17,7 +16,7 @@ def _decorator(func):
1716
def _new_func(*args, **kwargs):
1817
result = func(*args, **kwargs)
1918
if isinstance(result, TreeValue):
20-
it = map(itemgetter(1), walk(result, include_nodes=False))
19+
it = flatten_values(result)
2120
return rfunc(piter(it))
2221
else:
2322
return result

treetensor/config/meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
__TITLE__ = "DI-treetensor"
88

99
#: Version of this project.
10-
__VERSION__ = "0.2.0"
10+
__VERSION__ = "0.2.1"
1111

1212
#: Short description of the project, will be included in ``setup.py``.
1313
__DESCRIPTION__ = 'A flexible, generalized tree-based tensor structure.'

treetensor/numpy/array.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from functools import lru_cache
2+
13
import numpy
4+
import torch
25
from treevalue import method_treelize
36

47
from .base import TreeNumpy
@@ -12,6 +15,12 @@
1215
_ArrayProxy, _InstanceArrayProxy = get_tree_proxy(numpy.ndarray)
1316

1417

18+
@lru_cache()
19+
def _get_tensor_class(args0):
20+
from ..torch import Tensor
21+
return Tensor(args0)
22+
23+
1524
class _BaseArrayMeta(clsmeta(numpy.asarray, allow_dict=True)):
1625
pass
1726

@@ -92,6 +101,13 @@ def all(self: numpy.ndarray, *args, **kwargs):
92101
def any(self: numpy.ndarray, *args, **kwargs):
93102
return self.any(*args, **kwargs)
94103

104+
@method_treelize(return_type=_get_tensor_class)
105+
def tensor(self: numpy.ndarray, *args, **kwargs):
106+
tensor_: torch.Tensor = torch.from_numpy(self)
107+
if args or kwargs:
108+
tensor_ = tensor_.to(*args, **kwargs)
109+
return tensor_
110+
95111
@method_treelize()
96112
def __eq__(self, other):
97113
"""

treetensor/numpy/funcs.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
__all__ = [
1414
'all', 'any', 'array',
1515
'equal', 'array_equal',
16+
'stack', 'concatenate', 'split',
17+
'zeros', 'ones',
1618
]
1719

1820
func_treelize = post_process(post_process(args_mapping(
@@ -71,3 +73,33 @@ def array(p_object, *args, **kwargs):
7173
})
7274
"""
7375
return np.array(p_object, *args, **kwargs)
76+
77+
78+
@doc_from(np.stack)
79+
@func_treelize(subside=True)
80+
def stack(arrays, *args, **kwargs):
81+
return np.stack(arrays, *args, **kwargs)
82+
83+
84+
@doc_from(np.concatenate)
85+
@func_treelize(subside=True)
86+
def concatenate(arrays, *args, **kwargs):
87+
return np.concatenate(arrays, *args, **kwargs)
88+
89+
90+
@doc_from(np.split)
91+
@func_treelize(rise=True)
92+
def split(ary, *args, **kwargs):
93+
return np.split(ary, *args, **kwargs)
94+
95+
96+
@doc_from(np.zeros)
97+
@func_treelize()
98+
def zeros(shape, *args, **kwargs):
99+
return np.zeros(shape, *args, **kwargs)
100+
101+
102+
@doc_from(np.ones)
103+
@func_treelize()
104+
def ones(shape, *args, **kwargs):
105+
return np.ones(shape, *args, **kwargs)

0 commit comments

Comments
 (0)