Skip to content

Commit 29e191b

Browse files
authored
Zero len addressed. some cleanups (#108)
* Zero len addressed. some cleanups * using 64bit long in MKL interface
1 parent 335b6db commit 29e191b

8 files changed

+87
-304
lines changed

dpnp/backend/custom_kernels_reduction.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,11 @@ class custom_sum_c_kernel;
3737
template <typename _DataType>
3838
void custom_sum_c(void* array1_in, void* result1, size_t size)
3939
{
40+
if (!size)
41+
{
42+
return;
43+
}
44+
4045
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
4146
_DataType* result = reinterpret_cast<_DataType*>(result1);
4247

@@ -170,6 +175,11 @@ class custom_prod_c_kernel;
170175
template <typename _DataType>
171176
void custom_prod_c(void* array1_in, void* result1, size_t size)
172177
{
178+
if (!size)
179+
{
180+
return;
181+
}
182+
173183
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
174184
_DataType* result = reinterpret_cast<_DataType*>(result1);
175185

dpnp/backend/custom_kernels_statistics.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,11 +191,16 @@ void custom_mean_c(void* array1_in, void* result1, const size_t* shape, size_t n
191191
size *= shape[i];
192192
}
193193

194+
if (!size)
195+
{
196+
return;
197+
}
198+
194199
_DataType* sum = reinterpret_cast<_DataType*>(dpnp_memory_alloc_c(1 * sizeof(_DataType)));
195200

196201
custom_sum_c<_DataType>(array1_in, sum, size);
197202

198-
result[0] = (_ResultType)(sum[0]) / size;
203+
result[0] = static_cast<_ResultType>(sum[0]) / static_cast<_ResultType>(size);
199204

200205
dpnp_memory_free_c(sum);
201206

@@ -232,7 +237,7 @@ void custom_median_c(void* array1_in, void* result1, const size_t* shape, size_t
232237

233238
if (size % 2 == 0)
234239
{
235-
result[0] = (_ResultType)(sorted[size / 2] + sorted[size / 2 - 1]) / 2;
240+
result[0] = static_cast<_ResultType>(sorted[size / 2] + sorted[size / 2 - 1]) / 2;
236241
}
237242
else
238243
{
@@ -352,7 +357,7 @@ void custom_var_c(
352357
{
353358
size_t i = global_id[0]; /*for (size_t i = 0; i < size; ++i)*/
354359
{
355-
_ResultType deviation = (_ResultType)array1[i] - mean_val;
360+
_ResultType deviation = static_cast<_ResultType>(array1[i]) - mean_val;
356361
squared_deviations[i] = deviation * deviation;
357362
}
358363
}); /* parallel_for */
@@ -363,7 +368,7 @@ void custom_var_c(
363368
custom_mean_c<_ResultType, _ResultType>(squared_deviations, mean, shape, ndim, axis, naxis);
364369
mean_val = mean[0];
365370

366-
result[0] = mean_val * size / (size - ddof);
371+
result[0] = mean_val * size / static_cast<_ResultType>(size - ddof);
367372

368373
dpnp_memory_free_c(mean);
369374
dpnp_memory_free_c(squared_deviations);

dpnp/backend/mkl_wrap_blas1.cpp

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -35,32 +35,25 @@ namespace mkl_blas = oneapi::mkl::blas;
3535
template <typename _DataType>
3636
void mkl_blas_dot_c(void* array1_in, void* array2_in, void* result1, size_t size)
3737
{
38+
if (!size)
39+
{
40+
return;
41+
}
42+
3843
cl::sycl::event status;
3944

4045
_DataType* array_1 = reinterpret_cast<_DataType*>(array1_in);
4146
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
4247
_DataType* result = reinterpret_cast<_DataType*>(result1);
4348

44-
try
45-
{
46-
status = mkl_blas::dot(DPNP_QUEUE,
47-
size,
48-
array_1,
49-
1, // array_1 stride
50-
array_2,
51-
1, // array_2 stride
52-
result);
53-
}
54-
catch (cl::sycl::exception const& e)
55-
{
56-
std::cerr << "Caught synchronous SYCL exception during mkl_blas_dot_c():\n"
57-
<< e.what() << "\nOpenCL status: " << e.get_cl_code() << std::endl;
58-
}
59-
49+
status = mkl_blas::dot(DPNP_QUEUE,
50+
size,
51+
array_1,
52+
1, // array_1 stride
53+
array_2,
54+
1, // array_2 stride
55+
result);
6056
status.wait();
61-
#if 0
62-
std::cout << "mkl_blas_dot_c res = " << result[0] << std::endl;
63-
#endif
6457
}
6558

6659
template void mkl_blas_dot_c<float>(void* array1_in, void* array2_in, void* result1, size_t size);

dpnp/backend/mkl_wrap_blas3.cpp

Lines changed: 20 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ namespace mkl_blas = oneapi::mkl::blas;
3535
template <typename _DataType>
3636
void dpnp_blas_gemm_c(void* array1_in, void* array2_in, void* result1, size_t size_m, size_t size_n, size_t size_k)
3737
{
38+
if (!size_m || !size_n || !size_k)
39+
{
40+
return;
41+
}
42+
3843
cl::sycl::event status;
44+
3945
// using std::max for these ldx variables is required by MKL
4046
const std::int64_t lda = std::max<size_t>(1UL, size_k); // First dimensions of array_1
4147
const std::int64_t ldb = std::max<size_t>(1UL, size_n); // First dimensions of array_2
@@ -45,85 +51,22 @@ void dpnp_blas_gemm_c(void* array1_in, void* array2_in, void* result1, size_t si
4551
_DataType* array_2 = reinterpret_cast<_DataType*>(array2_in);
4652
_DataType* result = reinterpret_cast<_DataType*>(result1);
4753

48-
#if 0
49-
std::cout << ">>>>>>>>>>>>>>>>MKL dpnp_blas_gemm_c parameters:"
50-
<< "\n"
51-
<< "lda=" << lda << "\n"
52-
<< "ldb=" << ldb << "\n"
53-
<< "ldc=" << ldc << "\n"
54-
<< "size_m=" << size_m << "\n"
55-
<< "size_n=" << size_n << "\n"
56-
<< "size_k=" << size_k << "\n"
57-
<< "alfa=" << _DataType(1) << "\n"
58-
<< "beta=" << _DataType(0) << "\n"
59-
<< std::endl;
60-
61-
std::cout << "array_1\n";
62-
for (size_t it1 = 0; it1 < size_m; ++it1)
63-
{
64-
for (size_t it2 = 0; it2 < size_k; ++it2)
65-
{
66-
std::cout << " , " << array_1[it1*size_k + it2];
67-
}
68-
std::cout << std::endl;
69-
}
70-
71-
std::cout << "array_2\n";
72-
for (size_t it1 = 0; it1 < size_k; ++it1)
73-
{
74-
for (size_t it2 = 0; it2 < size_n; ++it2)
75-
{
76-
std::cout << " , " << array_2[it1*size_n + it2];
77-
}
78-
std::cout << std::endl;
79-
}
80-
81-
std::cout << "result_1 before\n";
82-
for (size_t it1 = 0; it1 < size_m; ++it1)
83-
{
84-
for (size_t it2 = 0; it2 < size_n; ++it2)
85-
{
86-
result[it1*size_n + it2] = 0;
87-
std::cout << " , " << result[it1*size_n + it2];
88-
}
89-
std::cout << std::endl;
90-
}
91-
#endif
92-
try
93-
{
94-
status = mkl_blas::gemm(DPNP_QUEUE,
95-
oneapi::mkl::transpose::nontrans,
96-
oneapi::mkl::transpose::nontrans,
97-
size_n,
98-
size_m,
99-
size_k,
100-
_DataType(1),
101-
array_2,
102-
ldb,
103-
array_1,
104-
lda,
105-
_DataType(0),
106-
result,
107-
ldc);
108-
}
109-
catch (cl::sycl::exception const& e)
110-
{
111-
std::cerr << "Caught synchronous SYCL exception during dpnp_blas_gemm_c():\n"
112-
<< e.what() << "\nOpenCL status: " << e.get_cl_code() << std::endl;
113-
}
54+
status = mkl_blas::gemm(DPNP_QUEUE,
55+
oneapi::mkl::transpose::nontrans,
56+
oneapi::mkl::transpose::nontrans,
57+
size_n,
58+
size_m,
59+
size_k,
60+
_DataType(1),
61+
array_2,
62+
ldb,
63+
array_1,
64+
lda,
65+
_DataType(0),
66+
result,
67+
ldc);
11468

11569
status.wait();
116-
#if 0
117-
std::cout << "result_1 after\n";
118-
for (size_t it1 = 0; it1 < size_m; ++it1)
119-
{
120-
for (size_t it2 = 0; it2 < size_n; ++it2)
121-
{
122-
std::cout << " , " << result[it1*size_n + it2];
123-
}
124-
std::cout << std::endl;
125-
}
126-
#endif
12770
}
12871

12972
template void dpnp_blas_gemm_c<float>(

dpnp/backend/mkl_wrap_lapack.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ namespace mkl_lapack = oneapi::mkl::lapack;
3535
template <typename _DataType>
3636
void mkl_lapack_syevd_c(void* array_in, void* result1, size_t size)
3737
{
38+
if (!size)
39+
{
40+
return;
41+
}
42+
3843
cl::sycl::event status;
3944

4045
_DataType* array = reinterpret_cast<_DataType*>(array_in);

dpnp/backend/mkl_wrap_rng.cpp

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ namespace mkl_rng = oneapi::mkl::rng;
3838
template <typename _DataType>
3939
void mkl_rng_gaussian(void* result, size_t size)
4040
{
41+
if (!size)
42+
{
43+
return;
44+
}
45+
4146
_DataType* result1 = reinterpret_cast<_DataType*>(result);
4247

4348
// TODO:
@@ -50,22 +55,20 @@ void mkl_rng_gaussian(void* result, size_t size)
5055
const _DataType stddev = _DataType(1.0);
5156

5257
mkl_rng::gaussian<_DataType> distribution(mean, stddev);
53-
try
54-
{
55-
// perform generation
56-
mkl_rng::generate(distribution, engine, size, result1);
57-
DPNP_QUEUE.wait_and_throw();
58-
}
59-
catch (cl::sycl::exception const& e)
60-
{
61-
std::cerr << "Caught synchronous SYCL exception during mkl_rng_gaussian():\n"
62-
<< e.what() << "\nOpenCL status: " << e.get_cl_code() << std::endl;
63-
}
58+
// perform generation
59+
mkl_rng::generate(distribution, engine, size, result1);
60+
61+
DPNP_QUEUE.wait();
6462
}
6563

6664
template <typename _DataType>
6765
void mkl_rng_uniform(void* result, long low, long high, size_t size)
6866
{
67+
if (!size)
68+
{
69+
return;
70+
}
71+
6972
_DataType* result1 = reinterpret_cast<_DataType*>(result);
7073

7174
// TODO:
@@ -80,17 +83,10 @@ void mkl_rng_uniform(void* result, long low, long high, size_t size)
8083
const _DataType b = (_DataType(high));
8184

8285
mkl_rng::uniform<_DataType> distribution(a, b);
83-
try
84-
{
85-
// perform generation
86-
mkl_rng::generate(distribution, engine, size, result1);
87-
DPNP_QUEUE.wait_and_throw();
88-
}
89-
catch (cl::sycl::exception const& e)
90-
{
91-
std::cerr << "Caught synchronous SYCL exception during mkl_rng_uniform_mt19937():\n"
92-
<< e.what() << "\nOpenCL status: " << e.get_cl_code() << std::endl;
93-
}
86+
// perform generation
87+
mkl_rng::generate(distribution, engine, size, result1);
88+
89+
DPNP_QUEUE.wait();
9490
}
9591

9692
template void mkl_rng_gaussian<double>(void* result, size_t size);

0 commit comments

Comments
 (0)