Skip to content

[API] paddle.slogdet 返回值规范化 #72505

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 41 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
e0d50e0
add infermeta
aquagull Apr 16, 2025
d84c9d6
update code
aquagull Apr 16, 2025
9d82610
update code
aquagull Apr 17, 2025
f2c1fe4
fix
aquagull Apr 17, 2025
fe46071
update
aquagull Apr 18, 2025
f32850f
update
aquagull Apr 18, 2025
8140e54
fix
aquagull Apr 18, 2025
287a0ef
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Apr 20, 2025
83196ea
update
aquagull Apr 22, 2025
657dcc8
update
aquagull Apr 22, 2025
0b26e4b
fix opinfersymbolicshape
aquagull Apr 23, 2025
d4f0726
fix grad
aquagull Apr 24, 2025
01f3597
update
aquagull Apr 26, 2025
6dddf9e
fix
aquagull Apr 26, 2025
c57f89a
fix
aquagull Apr 26, 2025
f1ff397
fix
aquagull Apr 26, 2025
02f75e5
empty
aquagull Apr 26, 2025
33ddaaa
fix
aquagull Apr 27, 2025
f27db1f
empty
aquagull Apr 27, 2025
ae3c884
fix
aquagull May 1, 2025
7bc0fc7
update
aquagull May 6, 2025
ce9ad8a
fix
aquagull May 6, 2025
17cb8cc
fix
aquagull May 17, 2025
3bbe340
fix python
aquagull May 17, 2025
6900484
empty commit
aquagull May 27, 2025
26ac20e
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jun 6, 2025
61d0b6d
add patch
aquagull Jun 9, 2025
14b9c3b
fix doc
aquagull Jun 11, 2025
9f03fee
Merge branch 'develop' into slogdet
aquagull Jun 12, 2025
894d187
fix 0-size
aquagull Jun 13, 2025
8dcade8
fix test_determinant_op
aquagull Jun 13, 2025
adc0689
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jun 18, 2025
e427789
doc
aquagull Jun 20, 2025
3091017
doc
aquagull Jun 20, 2025
61003ef
fix
aquagull Jun 20, 2025
0a26a26
update patch yaml
aquagull Jun 23, 2025
76c6e42
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jun 23, 2025
be079c0
Merge branch 'develop' into slogdet
aquagull Jul 6, 2025
2745c98
code-style
aquagull Jul 8, 2025
9a2afab
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jul 15, 2025
c83b8d4
Update op_compat.yaml
HydrogenSulfate Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3325,15 +3325,17 @@ bool SlogdetOpInferSymbolicShape(
"greater than or equal to 2."));
infer_context->AddEqualCstr(x_shape[x_shape_size - 1],
x_shape[x_shape_size - 2]);
std::vector<symbol::DimExpr> out_shape = {2};
size_t additional_dims = x_shape.size() - 2;
for (size_t i = 0; i < additional_dims; i++) {
out_shape.push_back(x_shape[i]);
std::vector<symbol::DimExpr> out_dims;
if (x_shape_size > 2) {
out_dims.assign(x_shape.begin(), x_shape.end() - 2);
}
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});

return true;
}

Expand Down
6 changes: 6 additions & 0 deletions paddle/fluid/pir/serialize_deserialize/patch/2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,9 @@ op_patches:
type : pir::ArrayAttribute
data :
- type: pir::Int64Attribute
- op_name : pd_op.slogdet
actions:
- action : add_output
object : 1
type : pir::DenseTensorType
data : [pir::Float32Type,[-1],"NCHW",[],0]
26 changes: 26 additions & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6339,6 +6339,32 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
scale->set_dtype(x.dtype());
}

void SlogdetInferMeta(const MetaTensor& x,
MetaTensor* sign,
MetaTensor* logdet) {
DDim x_dims = x.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_GE(rank,
2,
errors::InvalidArgument(
"Input(X) should be at least a 2-D tensor, but got %u.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
x_dims[rank - 1],
x_dims[rank - 2],
errors::InvalidArgument("the input matrix should be square matrix."));
auto x_dtype = x.dtype();
auto x_layout = x.layout();
DDim out_dims = slice_ddim(x_dims, 0, rank - 2);
sign->set_dtype(x_dtype);
sign->set_layout(x_layout);
sign->set_dims(out_dims);

logdet->set_dtype(dtype::ToReal(x_dtype));
logdet->set_layout(x_layout);
logdet->set_dims(out_dims);
}

void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
Expand Down
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,6 +862,10 @@ void SvdInferMeta(const MetaTensor& x,
MetaTensor* s,
MetaTensor* vh);

void SlogdetInferMeta(const MetaTensor& x,
MetaTensor* sign,
MetaTensor* logdet);

void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
Expand Down
6 changes: 5 additions & 1 deletion paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ PD_REGISTER_KERNEL(slogdet_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
phi::DataType real_dtype = phi::dtype::ToReal(kernel_key.dtype());
kernel->InputAt(2).SetDataType(real_dtype);
kernel->InputAt(4).SetDataType(real_dtype);
}
10 changes: 10 additions & 0 deletions paddle/phi/kernels/elementwise_multiply_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,14 @@ DenseTensor Multiply(const Context& dev_ctx,
return dense_out;
}

template <typename T, typename Context>
void Multiply(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
MetaTensor meta_out(out);
ElementwiseInferMeta(x, y, &meta_out);
MultiplyKernel<T, Context>(dev_ctx, x, y, out);
}

} // namespace phi
6 changes: 5 additions & 1 deletion paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,8 @@ PD_REGISTER_KERNEL(slogdet_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
phi::DataType real_dtype = phi::dtype::ToReal(kernel_key.dtype());
kernel->InputAt(2).SetDataType(real_dtype);
kernel->InputAt(4).SetDataType(real_dtype);
}
Loading
Loading