Skip to content

Commit 7c5c7c7

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
SpecDB: Add OutTensor specs for add.Tensor & add.Scalar (#5)
Summary: Pull Request resolved: #5 Added the out tensor spec for: - add.Tensor: promoted type of inputs must be castable to out dtype - add.Scalar: promoted type of inputs must be equal to out dtype Prompted by Jarvis's adoption of FACTO-based testing on D59247125 Reviewed By: zonglinpengmeta Differential Revision: D59402158 fbshipit-source-id: 784471e2b5494dbf82b5003c259ce75e9e63a323
1 parent 221e7e7 commit 7c5c7c7

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

specdb/db.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,16 @@
341341
),
342342
],
343343
outspec=[
344-
OutArg(ArgType.Tensor),
344+
OutArg(
345+
ArgType.Tensor,
346+
constraints=[
347+
cp.Dtype.In(
348+
lambda deps: dt.can_cast_from(
349+
torch.promote_types(deps[0].dtype, deps[1].dtype)
350+
)
351+
),
352+
],
353+
),
345354
],
346355
),
347356
Spec(
@@ -373,7 +382,16 @@
373382
),
374383
],
375384
outspec=[
376-
OutArg(ArgType.Tensor),
385+
OutArg(
386+
ArgType.Tensor,
387+
constraints=[
388+
cp.Dtype.Eq(
389+
lambda deps: (
390+
fn.promote_type_with_scalar(deps[0].dtype, deps[1])
391+
)
392+
),
393+
],
394+
)
377395
],
378396
),
379397
Spec(

0 commit comments

Comments
 (0)