[API] Normalize the return values of paddle.slogdet #72505

Open — wants to merge 41 commits into base: develop
Commits (41 total — changes shown from 26 commits)
e0d50e0
add infermeta
aquagull Apr 16, 2025
d84c9d6
update code
aquagull Apr 16, 2025
9d82610
update code
aquagull Apr 17, 2025
f2c1fe4
fix
aquagull Apr 17, 2025
fe46071
update
aquagull Apr 18, 2025
f32850f
update
aquagull Apr 18, 2025
8140e54
fix
aquagull Apr 18, 2025
287a0ef
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Apr 20, 2025
83196ea
update
aquagull Apr 22, 2025
657dcc8
update
aquagull Apr 22, 2025
0b26e4b
fix opinfersymbolicshape
aquagull Apr 23, 2025
d4f0726
fix grad
aquagull Apr 24, 2025
01f3597
update
aquagull Apr 26, 2025
6dddf9e
fix
aquagull Apr 26, 2025
c57f89a
fix
aquagull Apr 26, 2025
f1ff397
fix
aquagull Apr 26, 2025
02f75e5
empty
aquagull Apr 26, 2025
33ddaaa
fix
aquagull Apr 27, 2025
f27db1f
empty
aquagull Apr 27, 2025
ae3c884
fix
aquagull May 1, 2025
7bc0fc7
update
aquagull May 6, 2025
ce9ad8a
fix
aquagull May 6, 2025
17cb8cc
fix
aquagull May 17, 2025
3bbe340
fix python
aquagull May 17, 2025
6900484
empty commit
aquagull May 27, 2025
26ac20e
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jun 6, 2025
61d0b6d
add patch
aquagull Jun 9, 2025
14b9c3b
fix doc
aquagull Jun 11, 2025
9f03fee
Merge branch 'develop' into slogdet
aquagull Jun 12, 2025
894d187
fix 0-size
aquagull Jun 13, 2025
8dcade8
fix test_determinant_op
aquagull Jun 13, 2025
adc0689
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jun 18, 2025
e427789
doc
aquagull Jun 20, 2025
3091017
doc
aquagull Jun 20, 2025
61003ef
fix
aquagull Jun 20, 2025
0a26a26
update patch yaml
aquagull Jun 23, 2025
76c6e42
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jun 23, 2025
be079c0
Merge branch 'develop' into slogdet
aquagull Jul 6, 2025
2745c98
code-style
aquagull Jul 8, 2025
9a2afab
Merge branch 'PaddlePaddle:develop' into slogdet
aquagull Jul 15, 2025
c83b8d4
Update op_compat.yaml
HydrogenSulfate Jul 17, 2025
@@ -3336,15 +3336,17 @@ bool SlogdetOpInferSymbolicShape(
"greater than or equal to 2."));
infer_context->AddEqualCstr(x_shape[x_shape_size - 1],
x_shape[x_shape_size - 2]);
std::vector<symbol::DimExpr> out_shape = {2};
size_t additional_dims = x_shape.size() - 2;
for (size_t i = 0; i < additional_dims; i++) {
out_shape.push_back(x_shape[i]);
std::vector<symbol::DimExpr> out_dims;
if (x_shape_size > 2) {
out_dims.assign(x_shape.begin(), x_shape.end() - 2);
}
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(out_shape)});
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});
infer_context->SetShapeOrDataForValue(
op->result(1),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});

return true;
}

26 changes: 26 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -6192,6 +6192,32 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
scale->set_dtype(x.dtype());
}

void SlogdetInferMeta(const MetaTensor& x,
MetaTensor* sign,
MetaTensor* logdet) {
DDim x_dims = x.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_GE(rank,
2,
errors::InvalidArgument(
"Input(X) should be at least a 2-D tensor, but got %u.",
x_dims.size()));
PADDLE_ENFORCE_EQ(
x_dims[rank - 1],
x_dims[rank - 2],
errors::InvalidArgument("the input matrix should be square matrix."));
auto x_dtype = x.dtype();
auto x_layout = x.layout();
DDim out_dims = slice_ddim(x_dims, 0, rank - 2);
sign->set_dtype(x_dtype);
sign->set_layout(x_layout);
sign->set_dims(out_dims);

logdet->set_dtype(dtype::ToReal(x_dtype));
logdet->set_layout(x_layout);
logdet->set_dims(out_dims);
}

void ChannelShuffleInferMeta(const MetaTensor& x,
int groups,
const std::string& data_format,
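For a complex input, SlogdetInferMeta above keeps the input dtype for sign and uses the corresponding real dtype (dtype::ToReal) for logdet, both shaped like the leading batch dims. A small hypothetical sketch of what that implies at the Python level:

import numpy as np
import paddle

a = np.random.rand(4, 3, 3) + 1j * np.random.rand(4, 3, 3)
x = paddle.to_tensor(a.astype(np.complex64))

sign, logabsdet = paddle.linalg.slogdet(x)
print(sign.dtype, sign.shape)            # complex64, [4]  (same dtype as x)
print(logabsdet.dtype, logabsdet.shape)  # float32, [4]    (real counterpart of x's dtype)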
4 changes: 4 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -849,6 +849,10 @@ void SvdInferMeta(const MetaTensor& x,
MetaTensor* s,
MetaTensor* vh);

void SlogdetInferMeta(const MetaTensor& x,
MetaTensor* sign,
MetaTensor* logdet);

void TemporalShiftInferMeta(const MetaTensor& x,
int seg_num,
float shift_ratio,
6 changes: 5 additions & 1 deletion paddle/phi/kernels/cpu/slogdeterminant_grad_kernel.cc
@@ -24,4 +24,8 @@ PD_REGISTER_KERNEL(slogdet_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
phi::DataType real_dtype = phi::dtype::ToReal(kernel_key.dtype());
kernel->InputAt(2).SetDataType(real_dtype);
kernel->InputAt(4).SetDataType(real_dtype);
}
10 changes: 10 additions & 0 deletions paddle/phi/kernels/elementwise_multiply_kernel.h
@@ -36,4 +36,14 @@ DenseTensor Multiply(const Context& dev_ctx,
return dense_out;
}

template <typename T, typename Context>
void Multiply(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
MetaTensor meta_out(out);
ElementwiseInferMeta(x, y, &meta_out);
MultiplyKernel<T, Context>(dev_ctx, x, y, out);
}

} // namespace phi
6 changes: 5 additions & 1 deletion paddle/phi/kernels/gpu/slogdeterminant_grad_kernel.cu
@@ -24,4 +24,8 @@ PD_REGISTER_KERNEL(slogdet_grad,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
phi::DataType real_dtype = phi::dtype::ToReal(kernel_key.dtype());
kernel->InputAt(2).SetDataType(real_dtype);
kernel->InputAt(4).SetDataType(real_dtype);
}
175 changes: 132 additions & 43 deletions paddle/phi/kernels/gpu/slogdeterminant_kernel.cu
@@ -43,17 +43,93 @@ T sign(T det, T modulus) {
return det / modulus;
}

template <typename T>
__global__ void GetSlogDetFromLU(const T* lu_data,
const int* ipiv,
int64_t n,
int64_t batch_size,
T* sign_data,
T* logdet_data) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < batch_size) {
int offset_lu = idx * n * n;
int offset_ipiv = idx * n;
T det_val = T(1.0);
for (int i = 0; i < n; i++) {
det_val *= lu_data[offset_lu + i * n + i];
if (ipiv[offset_ipiv + i] != i + 1) {
det_val = -det_val;
}
}
T abs_det = abs(det_val);
sign_data[idx] = static_cast<T>((T(0) < det_val) - (det_val < T(0)));
logdet_data[idx] = log(abs_det);
}
}

template <typename T, typename Context>
struct SlogDeterminantFunctor {
void operator()(const Context& dev_ctx,
const DenseTensor& input,
int64_t rank,
int64_t batch_count,
DenseTensor* output) {
DenseTensor* sign,
DenseTensor* logdet) {
#ifndef PADDLE_WITH_HIP
phi::Allocator::AllocationPtr tmp_gpu_mat_data;
const T* gpu_mat = input.data<T>();
tmp_gpu_mat_data = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
input.numel() * sizeof(T),
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
memory_utils::Copy(dev_ctx.GetPlace(),
tmp_gpu_mat_data->ptr(),
dev_ctx.GetPlace(),
input.data(),
input.numel() * sizeof(T),
dev_ctx.stream());
gpu_mat = reinterpret_cast<const T*>(tmp_gpu_mat_data->ptr());

std::vector<const T*> cpu_ptrs(batch_count);
for (int i = 0; i < batch_count; ++i) {
cpu_ptrs[i] = gpu_mat + i * rank * rank;
}

// num_ints is for pivot (rank * batch_count) and info (batch_count)
int num_ints = batch_count * (rank + 1);
size_t total_bytes = batch_count * sizeof(T*) + num_ints * sizeof(int);
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
dev_ctx.GetPlace(),
total_bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
memory_utils::Copy(dev_ctx.GetPlace(),
tmp_gpu_ptrs_data->ptr(),
phi::CPUPlace(),
static_cast<void*>(cpu_ptrs.data()),
cpu_ptrs.size() * sizeof(T*),
dev_ctx.stream());

T** gpu_mat_ptr = reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr());
int* gpu_info_ptr = reinterpret_cast<int*>(gpu_mat_ptr + cpu_ptrs.size());
int* pivot_data = gpu_info_ptr + batch_count;

auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
// This function performs the LU factorization of each matrix A by the
// equation P * A = L * U. L and U are written back to original matrix A,
// and diagonal elements of L are discarded.
blas.BatchedGETRF(rank, gpu_mat_ptr, pivot_data, gpu_info_ptr, batch_count);
T* sign_data = dev_ctx.template Alloc<T>(sign);
T* logdet_data = dev_ctx.template Alloc<T>(logdet);
int block_size = std::min(256, dev_ctx.GetMaxThreadsPerBlock());
dim3 dim_block(block_size);
dim3 num_blocks((batch_count + block_size - 1) / block_size);
GetSlogDetFromLU<T><<<num_blocks, dim_block>>>(
gpu_mat, pivot_data, rank, batch_count, sign_data, logdet_data);
Comment on lines +83 to +132
Reviewer (Contributor): The original real and complex branches are merged together here — is that OK?
Author (Contributor): This block only adds a GPU implementation for the real-valued branch; the complex branch is not involved.
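For reference, a NumPy/SciPy sketch (not part of the PR) of how sign and log|det| are recovered from a pivoted LU factorization P * A = L * U, which is what the new GetSlogDetFromLU kernel computes per batch element after BatchedGETRF:

import numpy as np
from scipy.linalg import lu_factor

A = np.random.rand(4, 4)
lu, piv = lu_factor(A)                   # piv is 0-based here; LAPACK's ipiv is 1-based
det = np.prod(np.diag(lu))               # product of U's diagonal entries
det *= (-1.0) ** np.count_nonzero(piv != np.arange(4))  # one sign flip per row swap
sign, logabsdet = np.sign(det), np.log(abs(det))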

#else
std::vector<T> input_vec;
std::vector<T> sign_vec;
std::vector<T> log_vec;
std::vector<T> output_vec;
DDim out_dims = sign->dims();
phi::TensorToVector(input, dev_ctx, &input_vec);
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
@@ -69,42 +145,56 @@ struct SlogDeterminantFunctor {
VLOG(2) << "det value: " << matrix.determinant();
VLOG(2) << "matrix val: " << matrix;
auto det_val = matrix.determinant();
sign_vec.push_back(sign(det_val));
sign_vec.push_back(phi::sign(det_val));
Reviewer (Contributor): Which function is phi::sign?
Author (Contributor): It is defined in this file:

namespace phi {

// T is not complex
template <typename T>
T sign(T val) {
  return static_cast<T>(T(0) < val) - (val < T(0));
}

// T is complex
template <typename T>
T sign(T det, T modulus) {
  return det / modulus;
}
det_val >= 0
? log_vec.push_back(std::log(det_val))
: log_vec.push_back(std::log(std::abs(
det_val))); // for computing log value of a negative value.
}
// merge sign_vec and log_vec as final output_vec
output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
phi::TensorFromVector(output_vec, dev_ctx, output);
phi::TensorFromVector(sign_vec, dev_ctx, sign);
phi::TensorFromVector(log_vec, dev_ctx, logdet);
if (out_dims == common::make_ddim({})) {
// TensorFromVector Converting inputTensor dimensions from () (scalar) to
// (1,)
sign->Resize(out_dims);
logdet->Resize(out_dims);
}
#endif
}
};

template <typename T>
__global__ void GetSlogDetFromLUComplex(const T* lu_data,
template <typename Complex_T, typename T>
__global__ void GetSlogDetFromLUComplex(const Complex_T* lu_data,
const int* ipiv,
int64_t n,
int64_t batch_size,
T* out_data) {
Complex_T* sign,
T* logdet) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < batch_size) {
int offset_lu = idx * n * n;
int offset_ipiv = idx * n;
T det_val = T(1.0, 0.0);
T negative = T(-1.0, 0.0);
Complex_T det_val = Complex_T(1.0, 0.0);
Complex_T negative = Complex_T(-1.0, 0.0);
for (int i = 0; i < n; ++i) {
det_val *= lu_data[offset_lu + i * n + i];
if (ipiv[offset_ipiv + i] != i + 1) {
det_val *= negative;
}
}
T abs_det = static_cast<T>(abs(det_val));
T sign = det_val / abs_det;
T log_abs_det = log(abs_det);
out_data[idx] = sign;
out_data[idx + batch_size] = log_abs_det;
T abs_det = abs(det_val);
T epsilon = std::numeric_limits<T>::epsilon();

if (abs_det <= epsilon) {
sign[idx] = Complex_T(0.0, 0.0);
logdet[idx] = -std::numeric_limits<T>::infinity();
} else {
Complex_T abs_det_complex = static_cast<Complex_T>(abs_det);
Complex_T s = det_val / abs_det_complex;
T log_abs_det = log(abs_det);
sign[idx] = s;
logdet[idx] = log_abs_det;
}
}
}
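The kernel above also pins down the singular case: when |det| is at or below machine epsilon it emits sign = 0 and log|det| = -inf, which matches the output convention numpy.linalg.slogdet uses for singular inputs (sketch below; note NumPy itself does not apply an epsilon threshold, so near-singular results may differ):

import numpy as np

a = np.zeros((3, 3), dtype=np.complex64)   # singular matrix
sign, logabsdet = np.linalg.slogdet(a)
print(sign, logabsdet)                     # 0j -inf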

@@ -114,7 +204,8 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
const DenseTensor& input,
int64_t rank,
int64_t batch_count,
DenseTensor* output) {
DenseTensor* sign,
DenseTensor* logdet) {
#ifndef PADDLE_WITH_HIP
phi::Allocator::AllocationPtr tmp_gpu_mat_data;
const phi::dtype::complex<T>* gpu_mat =
@@ -164,20 +255,22 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
// equation P * A = L * U. L and U are written back to original matrix A,
// and diagonal elements of L are discarded.
blas.BatchedGETRF(rank, gpu_mat_ptr, pivot_data, gpu_info_ptr, batch_count);
phi::dtype::complex<T>* out_data =
dev_ctx.template Alloc<phi::dtype::complex<T>>(output);
phi::dtype::complex<T>* sign_data =
dev_ctx.template Alloc<phi::dtype::complex<T>>(sign);
T* logdet_data = dev_ctx.template Alloc<T>(logdet);
int block_size = std::min(256, dev_ctx.GetMaxThreadsPerBlock());
dim3 dim_block(block_size);
dim3 num_blocks((batch_count + block_size - 1) / block_size);
GetSlogDetFromLUComplex<phi::dtype::complex<T>><<<num_blocks, dim_block>>>(
gpu_mat, pivot_data, rank, batch_count, out_data);
GetSlogDetFromLUComplex<phi::dtype::complex<T>, T>
<<<num_blocks, dim_block>>>(
gpu_mat, pivot_data, rank, batch_count, sign_data, logdet_data);
#else
using MatrixType =
Eigen::Matrix<std::complex<T>, Eigen::Dynamic, Eigen::Dynamic>;
std::vector<phi::dtype::complex<T>> input_vec;
std::vector<phi::dtype::complex<T>> sign_vec;
std::vector<phi::dtype::complex<T>> log_vec;
std::vector<phi::dtype::complex<T>> output_vec;
DDim out_dims = sign->dims();
phi::TensorToVector(input, dev_ctx, &input_vec);
for (int64_t i = 0; i < batch_count; ++i) { // maybe can be parallel
auto begin_iter = input_vec.begin() + i * rank * rank;
@@ -196,27 +289,31 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
std::complex<T> det_val = matrix.determinant();
T abs_det_val = std::abs(det_val);
sign_vec.push_back(static_cast<phi::dtype::complex<T>>(
sign(det_val, static_cast<std::complex<T>>(abs_det_val))));
log_vec.push_back(
static_cast<phi::dtype::complex<T>>(std::log(abs_det_val)));
phi::sign(det_val, static_cast<std::complex<T>>(abs_det_val))));
log_vec.push_back(std::log(abs_det_val));
}
phi::TensorFromVector(sign_vec, dev_ctx, sign);
phi::TensorFromVector(log_vec, dev_ctx, logdet);
if (out_dims == common::make_ddim({})) {
// TensorFromVector Converting inputTensor dimensions from () (scalar) to
// (1,)
sign->Resize(out_dims);
logdet->Resize(out_dims);
}
// merge sign_vec and log_vec as final output_vec
output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
phi::TensorFromVector(output_vec, dev_ctx, output);
#endif
}
};

template <typename T, typename Context>
void SlogDeterminantKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
DenseTensor* sign,
DenseTensor* logdet) {
auto input_dim = common::vectorize(x.dims());
auto input_dim_size = input_dim.size();

auto batch_count = detail::GetBatchCount(x.dims());
VLOG(2) << "input dim:" << x.dims();
VLOG(3) << "input dim:" << x.dims();
PADDLE_ENFORCE_GE(
input_dim_size,
2,
@@ -227,17 +324,9 @@ void SlogDeterminantKernel(const Context& dev_ctx,
input_dim[input_dim_size - 2],
errors::InvalidArgument("the input matrix should be square matrix."));
auto rank = input_dim[input_dim_size - 1]; // square matrix length
SlogDeterminantFunctor<T, Context>()(dev_ctx, x, rank, batch_count, out);
std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
if (input_dim.size() == static_cast<size_t>(2)) {
// when input is a two-dimension matrix, The det value is a number.
output_dim_vec = {};
}
output_dim_vec.insert(output_dim_vec.begin(),
2); // make the output dims as same as numpy
auto output_dims = common::make_ddim(output_dim_vec);
out->Resize(output_dims);
VLOG(2) << "output dim:" << out->dims();
SlogDeterminantFunctor<T, Context>()(
dev_ctx, x, rank, batch_count, sign, logdet);
VLOG(3) << "sign dim:" << sign->dims();
}

} // namespace phi