[API] paddle.slogdet return value normalization #72505
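This PR normalizes the return value of paddle.slogdet: instead of a single tensor that stacks the signs on top of the log-determinants, the kernel now writes two separate outputs, sign and logdet. For a square matrix A the two outputs satisfy the usual convention (a LaTeX summary of what the code below computes):

$$
\det(A) = \text{sign} \cdot e^{\text{logdet}}, \qquad
\text{sign} =
\begin{cases}
\operatorname{sgn}(\det A) \in \{-1, 0, 1\}, & A \text{ real},\\
\det A / \lvert \det A \rvert, & A \text{ complex},
\end{cases}
\qquad
\text{logdet} = \log \lvert \det A \rvert .
$$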
@@ -43,17 +43,93 @@ T sign(T det, T modulus) {
   return det / modulus;
 }

 template <typename T>
 __global__ void GetSlogDetFromLU(const T* lu_data,
                                  const int* ipiv,
                                  int64_t n,
                                  int64_t batch_size,
                                  T* sign_data,
                                  T* logdet_data) {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < batch_size) {
     int offset_lu = idx * n * n;
     int offset_ipiv = idx * n;
     T det_val = T(1.0);
     for (int i = 0; i < n; i++) {
       det_val *= lu_data[offset_lu + i * n + i];
       if (ipiv[offset_ipiv + i] != i + 1) {
         det_val = -det_val;
       }
     }
     T abs_det = abs(det_val);
     sign_data[idx] = static_cast<T>((T(0) < det_val) - (det_val < T(0)));
     logdet_data[idx] = log(abs_det);
   }
 }
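As a sanity check for the kernel above, here is a minimal host-side sketch of the same computation (a hypothetical helper, not part of the PR): after getrf, det(A) is the product of U's diagonal, negated once for every row the 1-based pivot array actually swapped, and slogdet is then (sgn(det), log|det|).

#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// Host-side reference (hypothetical): sign and log|det| of one LU-factored
// matrix. `lu` is the row-major n x n factor written back by getrf; `ipiv`
// holds 1-based pivot indices, so ipiv[i] != i + 1 means row i was swapped.
template <typename T>
std::pair<T, T> SlogdetFromLU(const std::vector<T>& lu,
                              const std::vector<int>& ipiv,
                              int64_t n) {
  T det = T(1);
  for (int64_t i = 0; i < n; ++i) {
    det *= lu[i * n + i];              // product of U's diagonal
    if (ipiv[i] != i + 1) det = -det;  // each row swap flips the sign
  }
  T sign = static_cast<T>((T(0) < det) - (det < T(0)));
  return {sign, std::log(std::abs(det))};
}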
 template <typename T, typename Context>
 struct SlogDeterminantFunctor {
   void operator()(const Context& dev_ctx,
                   const DenseTensor& input,
                   int64_t rank,
                   int64_t batch_count,
-                  DenseTensor* output) {
+                  DenseTensor* sign,
+                  DenseTensor* logdet) {
 #ifndef PADDLE_WITH_HIP
     phi::Allocator::AllocationPtr tmp_gpu_mat_data;
     const T* gpu_mat = input.data<T>();
     tmp_gpu_mat_data = phi::memory_utils::Alloc(
         dev_ctx.GetPlace(),
         input.numel() * sizeof(T),
         phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
     memory_utils::Copy(dev_ctx.GetPlace(),
                        tmp_gpu_mat_data->ptr(),
                        dev_ctx.GetPlace(),
                        input.data(),
                        input.numel() * sizeof(T),
                        dev_ctx.stream());
     gpu_mat = reinterpret_cast<const T*>(tmp_gpu_mat_data->ptr());

     std::vector<const T*> cpu_ptrs(batch_count);
     for (int i = 0; i < batch_count; ++i) {
       cpu_ptrs[i] = gpu_mat + i * rank * rank;
     }

     // num_ints is for pivot (rank * batch_count) and info (batch_count)
     int num_ints = batch_count * (rank + 1);
     size_t total_bytes = batch_count * sizeof(T*) + num_ints * sizeof(int);
     phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
         dev_ctx.GetPlace(),
         total_bytes,
         phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
     memory_utils::Copy(dev_ctx.GetPlace(),
                        tmp_gpu_ptrs_data->ptr(),
                        phi::CPUPlace(),
                        static_cast<void*>(cpu_ptrs.data()),
                        cpu_ptrs.size() * sizeof(T*),
                        dev_ctx.stream());

     T** gpu_mat_ptr = reinterpret_cast<T**>(tmp_gpu_ptrs_data->ptr());
     int* gpu_info_ptr = reinterpret_cast<int*>(gpu_mat_ptr + cpu_ptrs.size());
     int* pivot_data = gpu_info_ptr + batch_count;

     auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
     // This function performs the LU factorization of each matrix A by the
     // equation P * A = L * U. L and U are written back to original matrix A,
     // and diagonal elements of L are discarded.
     blas.BatchedGETRF(rank, gpu_mat_ptr, pivot_data, gpu_info_ptr, batch_count);
     T* sign_data = dev_ctx.template Alloc<T>(sign);
     T* logdet_data = dev_ctx.template Alloc<T>(logdet);
     int block_size = std::min(256, dev_ctx.GetMaxThreadsPerBlock());
     dim3 dim_block(block_size);
     dim3 num_blocks((batch_count + block_size - 1) / block_size);
     GetSlogDetFromLU<T><<<num_blocks, dim_block>>>(
         gpu_mat, pivot_data, rank, batch_count, sign_data, logdet_data);
Review thread on lines +83 to +132:

Reviewer: This merges the original real and complex branches together. Is that OK?

Author: This change only adds a GPU implementation to the real branch; the complex branch is not involved.
 #else
     std::vector<T> input_vec;
     std::vector<T> sign_vec;
     std::vector<T> log_vec;
-    std::vector<T> output_vec;
+    DDim out_dims = sign->dims();
     phi::TensorToVector(input, dev_ctx, &input_vec);
     for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
@@ -69,42 +145,56 @@ struct SlogDeterminantFunctor {
       VLOG(2) << "det value: " << matrix.determinant();
       VLOG(2) << "matrix val: " << matrix;
       auto det_val = matrix.determinant();
-      sign_vec.push_back(sign(det_val));
+      sign_vec.push_back(phi::sign(det_val));
Review thread:

Reviewer: Which function is phi::sign?

Author: It is the overload defined at the top of the current file:

namespace phi {
// T is not complex
template <typename T>
T sign(T val) {
  return static_cast<T>(T(0) < val) - (val < T(0));
}
// T is complex
template <typename T>
T sign(T det, T modulus) {
  return det / modulus;
}
       det_val >= 0
           ? log_vec.push_back(std::log(det_val))
           : log_vec.push_back(std::log(std::abs(
                 det_val)));  // for computing log value of a negative value.
     }
-    // merge sign_vec and log_vec as final output_vec
-    output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
-    output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
-    phi::TensorFromVector(output_vec, dev_ctx, output);
+    phi::TensorFromVector(sign_vec, dev_ctx, sign);
+    phi::TensorFromVector(log_vec, dev_ctx, logdet);
+    if (out_dims == common::make_ddim({})) {
+      // TensorFromVector promotes a 0-D () input to shape (1,); resize back
+      // to a scalar.
+      sign->Resize(out_dims);
+      logdet->Resize(out_dims);
+    }
 #endif
   }
 };
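A note on the single workspace allocation in the GPU branch above: one buffer packs the batch_count device matrix pointers consumed by BatchedGETRF, then batch_count getrf status ints, then batch_count * rank pivot ints. Carving the int regions directly after the pointer region is safe because sizeof(T*) is at least as large, and at least as strictly aligned, as sizeof(int). A minimal sketch of the same carving, with hypothetical names:

#include <cstddef>
#include <cstdint>

// Assumed layout: [batch T* pointers][batch int infos][batch * rank int pivots]
struct GetrfWorkspace {
  double** mat_ptrs;  // one device pointer per matrix in the batch
  int* infos;         // one factorization status per matrix
  int* pivots;        // rank pivot indices per matrix
};

inline size_t GetrfWorkspaceBytes(int64_t batch_count, int64_t rank) {
  return batch_count * sizeof(double*) +
         batch_count * (rank + 1) * sizeof(int);
}

inline GetrfWorkspace CarveGetrfWorkspace(void* base, int64_t batch_count) {
  GetrfWorkspace w;
  w.mat_ptrs = reinterpret_cast<double**>(base);
  w.infos = reinterpret_cast<int*>(w.mat_ptrs + batch_count);
  w.pivots = w.infos + batch_count;
  return w;
}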
-template <typename T>
-__global__ void GetSlogDetFromLUComplex(const T* lu_data,
+template <typename Complex_T, typename T>
+__global__ void GetSlogDetFromLUComplex(const Complex_T* lu_data,
                                         const int* ipiv,
                                         int64_t n,
                                         int64_t batch_size,
-                                        T* out_data) {
+                                        Complex_T* sign,
+                                        T* logdet) {
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx < batch_size) {
     int offset_lu = idx * n * n;
     int offset_ipiv = idx * n;
-    T det_val = T(1.0, 0.0);
-    T negative = T(-1.0, 0.0);
+    Complex_T det_val = Complex_T(1.0, 0.0);
+    Complex_T negative = Complex_T(-1.0, 0.0);
     for (int i = 0; i < n; ++i) {
       det_val *= lu_data[offset_lu + i * n + i];
       if (ipiv[offset_ipiv + i] != i + 1) {
         det_val *= negative;
       }
     }
-    T abs_det = static_cast<T>(abs(det_val));
-    T sign = det_val / abs_det;
-    T log_abs_det = log(abs_det);
-    out_data[idx] = sign;
-    out_data[idx + batch_size] = log_abs_det;
+    T abs_det = abs(det_val);
+    T epsilon = std::numeric_limits<T>::epsilon();
+
+    if (abs_det <= epsilon) {
+      sign[idx] = Complex_T(0.0, 0.0);
+      logdet[idx] = -std::numeric_limits<T>::infinity();
+    } else {
+      Complex_T abs_det_complex = static_cast<Complex_T>(abs_det);
+      Complex_T s = det_val / abs_det_complex;
+      T log_abs_det = log(abs_det);
+      sign[idx] = s;
+      logdet[idx] = log_abs_det;
+    }
   }
 }
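The new zero-determinant guard in the complex kernel mirrors the following host-side logic (a sketch using std::complex, where the kernel uses phi's complex type): when |det| is at or below machine epsilon, the sign is reported as 0 and logdet as -inf; otherwise the sign is the unit-modulus phase det/|det|.

#include <cmath>
#include <complex>
#include <limits>
#include <utility>

// Host sketch (hypothetical helper) of the tail step of the complex kernel.
template <typename T>
std::pair<std::complex<T>, T> ComplexSignAndLog(std::complex<T> det) {
  T abs_det = std::abs(det);
  if (abs_det <= std::numeric_limits<T>::epsilon()) {
    // Numerically singular matrix: sign 0, log|det| = -inf.
    return {std::complex<T>(0, 0), -std::numeric_limits<T>::infinity()};
  }
  return {det / abs_det, std::log(abs_det)};  // phase on the unit circle
}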
@@ -114,7 +204,8 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
                   const DenseTensor& input,
                   int64_t rank,
                   int64_t batch_count,
-                  DenseTensor* output) {
+                  DenseTensor* sign,
+                  DenseTensor* logdet) {
 #ifndef PADDLE_WITH_HIP
     phi::Allocator::AllocationPtr tmp_gpu_mat_data;
     const phi::dtype::complex<T>* gpu_mat =
@@ -164,20 +255,22 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
     // equation P * A = L * U. L and U are written back to original matrix A,
     // and diagonal elements of L are discarded.
     blas.BatchedGETRF(rank, gpu_mat_ptr, pivot_data, gpu_info_ptr, batch_count);
-    phi::dtype::complex<T>* out_data =
-        dev_ctx.template Alloc<phi::dtype::complex<T>>(output);
+    phi::dtype::complex<T>* sign_data =
+        dev_ctx.template Alloc<phi::dtype::complex<T>>(sign);
+    T* logdet_data = dev_ctx.template Alloc<T>(logdet);
     int block_size = std::min(256, dev_ctx.GetMaxThreadsPerBlock());
     dim3 dim_block(block_size);
     dim3 num_blocks((batch_count + block_size - 1) / block_size);
-    GetSlogDetFromLUComplex<phi::dtype::complex<T>><<<num_blocks, dim_block>>>(
-        gpu_mat, pivot_data, rank, batch_count, out_data);
+    GetSlogDetFromLUComplex<phi::dtype::complex<T>, T>
+        <<<num_blocks, dim_block>>>(
+            gpu_mat, pivot_data, rank, batch_count, sign_data, logdet_data);
 #else
     using MatrixType =
         Eigen::Matrix<std::complex<T>, Eigen::Dynamic, Eigen::Dynamic>;
     std::vector<phi::dtype::complex<T>> input_vec;
     std::vector<phi::dtype::complex<T>> sign_vec;
     std::vector<phi::dtype::complex<T>> log_vec;
-    std::vector<phi::dtype::complex<T>> output_vec;
+    DDim out_dims = sign->dims();
     phi::TensorToVector(input, dev_ctx, &input_vec);
     for (int64_t i = 0; i < batch_count; ++i) {  // maybe can be parallel
       auto begin_iter = input_vec.begin() + i * rank * rank;
@@ -196,27 +289,31 @@ struct SlogDeterminantFunctor<phi::dtype::complex<T>, Context> {
       std::complex<T> det_val = matrix.determinant();
       T abs_det_val = std::abs(det_val);
       sign_vec.push_back(static_cast<phi::dtype::complex<T>>(
-          sign(det_val, static_cast<std::complex<T>>(abs_det_val))));
-      log_vec.push_back(
-          static_cast<phi::dtype::complex<T>>(std::log(abs_det_val)));
+          phi::sign(det_val, static_cast<std::complex<T>>(abs_det_val))));
+      log_vec.push_back(std::log(abs_det_val));
     }
+    phi::TensorFromVector(sign_vec, dev_ctx, sign);
+    phi::TensorFromVector(log_vec, dev_ctx, logdet);
+    if (out_dims == common::make_ddim({})) {
+      // TensorFromVector promotes a 0-D () input to shape (1,); resize back
+      // to a scalar.
+      sign->Resize(out_dims);
+      logdet->Resize(out_dims);
+    }
-    // merge sign_vec and log_vec as final output_vec
-    output_vec.insert(output_vec.end(), sign_vec.begin(), sign_vec.end());
-    output_vec.insert(output_vec.end(), log_vec.begin(), log_vec.end());
-    phi::TensorFromVector(output_vec, dev_ctx, output);
 #endif
   }
 };
 template <typename T, typename Context>
 void SlogDeterminantKernel(const Context& dev_ctx,
                            const DenseTensor& x,
-                           DenseTensor* out) {
+                           DenseTensor* sign,
+                           DenseTensor* logdet) {
   auto input_dim = common::vectorize(x.dims());
   auto input_dim_size = input_dim.size();

   auto batch_count = detail::GetBatchCount(x.dims());
-  VLOG(2) << "input dim:" << x.dims();
+  VLOG(3) << "input dim:" << x.dims();
   PADDLE_ENFORCE_GE(
       input_dim_size,
       2,
@@ -227,17 +324,9 @@ void SlogDeterminantKernel(const Context& dev_ctx,
       input_dim[input_dim_size - 2],
       errors::InvalidArgument("the input matrix should be square matrix."));
   auto rank = input_dim[input_dim_size - 1];  // square matrix length
-  SlogDeterminantFunctor<T, Context>()(dev_ctx, x, rank, batch_count, out);
-  std::vector<int> output_dim_vec(input_dim.begin(), input_dim.end() - 2);
-  if (input_dim.size() == static_cast<size_t>(2)) {
-    // when input is a two-dimension matrix, the det value is a number.
-    output_dim_vec = {};
-  }
-  output_dim_vec.insert(output_dim_vec.begin(),
-                        2);  // make the output dims as same as numpy
-  auto output_dims = common::make_ddim(output_dim_vec);
-  out->Resize(output_dims);
-  VLOG(2) << "output dim:" << out->dims();
+  SlogDeterminantFunctor<T, Context>()(
+      dev_ctx, x, rank, batch_count, sign, logdet);
+  VLOG(3) << "sign dim:" << sign->dims();
 }

 }  // namespace phi
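With the output Resize logic removed from the kernel (the output shapes are presumably fixed by the op's InferMeta before the functor runs), the new shape contract is: an input of shape [d0, ..., dk, n, n] yields sign and logdet of shape [d0, ..., dk], and a plain 2-D matrix yields 0-D scalars, replacing the previous single output of shape [2, d0, ..., dk]. A sketch with a hypothetical helper:

#include <cstdint>
#include <vector>

// Hypothetical sketch of the new shape contract: strip the trailing [n, n].
// An empty result denotes the 0-D scalar outputs for a single 2-D input.
std::vector<int64_t> SlogdetOutDims(const std::vector<int64_t>& x_dims) {
  return {x_dims.begin(), x_dims.end() - 2};
}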