@@ -187,27 +187,29 @@ def add_two(x, y):
187
187
如下代码样例中的 ` if label is not None ` , 此判断只依赖于 ` label ` 是否为 ` None ` (存在性),并不依赖 ` label ` 的 Tensor 值(数值性),因此属于** 不依赖 Tensor 的控制流** 。
188
188
189
189
``` python
190
+ from paddle.jit import to_static
191
+
190
192
def not_depend_tensor_if (x , label = None ):
191
193
out = x + 1
192
194
if label is not None : # <----- python bool 类型
193
195
out = paddle.nn.functional.cross_entropy(out, label)
194
196
return out
195
197
196
- print (to_static(not_depend_tensor_ifw ).code)
198
+ print (to_static(not_depend_tensor_if ).code)
197
199
# 转写后的代码:
198
200
"""
199
201
def not_depend_tensor_if(x, label=None):
200
202
out = x + 1
201
203
202
- def true_fn_1 (label, out): # true 分支
204
+ def true_fn_0 (label, out): # true 分支
203
205
out = paddle.nn.functional.cross_entropy(out, label)
204
206
return out
205
207
206
- def false_fn_1 (out): # false 分支
208
+ def false_fn_0 (out): # false 分支
207
209
return out
208
210
209
- out = paddle.jit.dy2static.convert_ifelse(label is not None, true_fn_1 ,
210
- false_fn_1 , (label, out), (out,), (out,))
211
+ out = paddle.jit.dy2static.convert_ifelse(label is not None, true_fn_0 ,
212
+ false_fn_0 , (label, out), (out,), (out,))
211
213
212
214
return out
213
215
"""
@@ -219,6 +221,8 @@ def not_depend_tensor_if(x, label=None):
219
221
如下代码样例中的 ` if paddle.mean(x) > 5 ` , 此判断直接依赖 ` paddle.mean(x) ` 返回的 Tensor 值(数值性),因此属于** 依赖 Tensor 的控制流** 。
220
222
221
223
``` python
224
+ from paddle.jit import to_static
225
+
222
226
def depend_tensor_if (x ):
223
227
if paddle.mean(x) > 5 .: # <---- Bool Tensor 类型
224
228
out = x - 1
@@ -230,7 +234,7 @@ print(to_static(depend_tensor_if).code)
230
234
# 转写后的代码:
231
235
"""
232
236
def depend_tensor_if(x):
233
- out = paddle.jit.dy2static.data_layer_not_check(name='out ', shape=[-1],
237
+ out = paddle.jit.dy2static.data_layer_not_check(name='out_0 ', shape=[-1],
234
238
dtype='float32')
235
239
236
240
def true_fn_0(x): # true 分支
@@ -280,6 +284,8 @@ def convert_ifelse(pred, true_fn, false_fn, true_args, false_args, return_vars):
280
284
如下代码样例中的 ` while a < 10 ` , 此循环条件中的 ` a ` 是一个 ` int ` 类型,并不是 Tensor 类型,因此属于** 不依赖 Tensor 的控制流** 。
281
285
282
286
``` python
287
+ from paddle.jit import to_static
288
+
283
289
def not_depend_tensor_while (x ):
284
290
a = 1
285
291
@@ -315,10 +321,12 @@ def not_depend_tensor_while(x):
315
321
如下代码样例中的 ` for i in range(bs) ` , 此循环条件中的 ` bs ` 是一个 ` paddle.shape ` 返回的 Tensor 类型,且将其 Tensor 值作为了循环的终止条件,因此属于** 依赖 Tensor 的控制流** 。
316
322
317
323
``` python
324
+ from paddle.jit import to_static
325
+
318
326
def depend_tensor_while (x ):
319
327
bs = paddle.shape(x)[0 ]
320
328
321
- for i in range (bs): # <---- bas is a Tensor
329
+ for i in range (bs): # <---- bs is a Tensor
322
330
x = x + 1
323
331
324
332
return x
0 commit comments