Skip to content

Commit 82220e5

Browse files
committed
dev(hansbug): add operation show
1 parent 71a6531 commit 82220e5

File tree

3 files changed

+101
-0
lines changed

3 files changed

+101
-0
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
Customized Operations For Different Fields
2+
============================================
3+
4+
Here is another example of the custom operations implemented \
5+
with both native torch API and treetensor API.
6+
7+
.. literalinclude:: operation.demo.py
8+
:language: python
9+
:linenos:
10+
11+
The output should be like below, and all the assertions can \
12+
be passed.
13+
14+
.. literalinclude:: operation.demo.py.txt
15+
:language: text
16+
:linenos:
17+
18+
The implement with treetensor API is much simpler and clearer.
19+
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import copy
2+
3+
import torch
4+
5+
import treetensor.torch as ttorch
6+
7+
T, B = 3, 4
8+
9+
10+
def with_nativetensor(batch_):
11+
mean_b_list = []
12+
even_index_a_list = []
13+
for i in range(len(batch_)):
14+
for k, v in batch_[i].items():
15+
if k == 'a':
16+
v = v.float()
17+
even_index_a_list.append(v[::2])
18+
elif k == 'b':
19+
v = v.float()
20+
transformed_v = torch.pow(v, 2) + 1.0
21+
mean_b_list.append(transformed_v.mean())
22+
elif k == 'c':
23+
for k1, v1 in v.items():
24+
if k1 == 'd':
25+
v1 = v1.float()
26+
else:
27+
print('ignore keys: {}'.format(k1))
28+
else:
29+
print('ignore keys: {}'.format(k))
30+
for i in range(len(batch_)):
31+
for k in batch_[i].keys():
32+
if k == 'd':
33+
batch_[i][k]['noise'] = torch.randn(size=(3, 4, 5))
34+
35+
mean_b = sum(mean_b_list) / len(mean_b_list)
36+
even_index_a = torch.stack(even_index_a_list, dim=0)
37+
return batch_, mean_b, even_index_a
38+
39+
40+
def with_treetensor(batch_):
41+
batch_ = [ttorch.tensor(b) for b in batch_]
42+
batch_ = ttorch.stack(batch_)
43+
batch_ = batch_.float()
44+
batch_.b = ttorch.pow(batch_.b, 2) + 1.0
45+
batch_.c.noise = ttorch.randn(size=(B, 3, 4, 5))
46+
mean_b = batch_.b.mean()
47+
even_index_a = batch_.a[:, ::2]
48+
batch_ = ttorch.split(batch_, split_size_or_sections=1, dim=0)
49+
return batch_, mean_b, even_index_a
50+
51+
52+
def get_data():
53+
return {
54+
'a': torch.rand(size=(T, 8)),
55+
'b': torch.rand(size=(6,)),
56+
'c': {
57+
'd': torch.randint(0, 10, size=(1,))
58+
}
59+
}
60+
61+
62+
if __name__ == "__main__":
63+
batch = [get_data() for _ in range(B)]
64+
batch0, mean0, even_index_a0 = with_nativetensor(copy.deepcopy(batch))
65+
batch1, mean1, even_index_a1 = with_treetensor(copy.deepcopy(batch))
66+
print(batch0)
67+
print('\n\n')
68+
print(batch1)
69+
70+
assert torch.abs(mean0 - mean1) < 1e-6
71+
print('mean0 & mean1:', mean0, mean1)
72+
print('\n')
73+
74+
assert torch.abs((even_index_a0 - even_index_a1).max()) < 1e-6
75+
print('even_index_a0:', even_index_a0)
76+
print('even_index_a1:', even_index_a1)
77+
78+
assert len(batch0) == B
79+
assert len(batch1) == B
80+
assert isinstance(batch1[0], ttorch.Tensor)
81+
print(batch1[0].shape)

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ module.
2020
:caption: Best Practice
2121

2222
best_practice/stack/index
23+
best_practice/operation/index
2324

2425
.. toctree::
2526
:maxdepth: 2

0 commit comments

Comments
 (0)