Skip to content

Commit c6d23b6

Browse files
committed
[Test] Add func_version parameter to specify the API version of Prune/Compress's function.
1 parent 19f5ba2 commit c6d23b6

13 files changed

+38
-114
lines changed

clients/benchmarks/client.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (c) 2022-2023 Advanced Micro Devices, Inc.
5+
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -456,6 +456,10 @@ try
456456
bool_switch(&arg.sparse_b)->default_value(false),
457457
"Structurted Sparsity Matrix B (A is Dense Matrix)")
458458

459+
("func_version",
460+
value<int32_t>(&arg.func_version)->default_value(1),
461+
"Specify the API version of Prune/Compress's function")
462+
459463
("log_function_name",
460464
bool_switch(&log_function_name)->default_value(false),
461465
"Function name precedes other itmes.")

clients/gtest/compress_gtest.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (c) 2022-2023 Advanced Micro Devices, Inc.
5+
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -113,6 +113,9 @@ namespace
113113

114114
if(strstr(arg.function, "_strided_batched") != nullptr)
115115
name << '_' << (arg.sparse_b ? arg.stride_b : arg.stride_a);
116+
117+
if(arg.func_version > 1)
118+
name << "_v" << arg.func_version;
116119
}
117120
return std::move(name);
118121
}

clients/gtest/compress_gtest.yaml

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,7 @@ Tests:
328328
alpha_beta: *alpha_beta_range
329329
sparse_b: [ true, false]
330330
transA_transB: *transA_transB_range
331-
332-
- name: compress2_256_8_16
333-
category: pre_checkin
334-
function:
335-
compress: *real_precisions_2b
336-
M: 256
337-
N: 8
338-
K: 16
339-
alpha_beta: *alpha_beta_range
340-
sparse_b: [ true, false]
341-
transA_transB: *transA_transB_range
331+
func_version: [1, 2]
342332

343333
- name: compress_16_256_8
344334
category: pre_checkin
@@ -350,18 +340,7 @@ Tests:
350340
alpha_beta: *alpha_beta_range
351341
sparse_b: [ true, false]
352342
transA_transB: *transA_transB_range
353-
354-
- name: compress2_16_256_8
355-
category: pre_checkin
356-
function:
357-
compress: *real_precisions_2b
358-
M: 16
359-
N: 256
360-
K: 8
361-
alpha_beta: *alpha_beta_range
362-
sparse_b: [ true, false]
363-
transA_transB: *transA_transB_range
364-
343+
func_version: [1, 2]
365344

366345
- name: compress_8_16_256
367346
category: pre_checkin

clients/gtest/compress_strided_batched_gtest.yaml

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,7 @@ Tests:
4141
alpha_beta: *alpha_beta_range
4242
batch_count: [ 3 ]
4343
sparse_b: [ true, false]
44-
45-
- name: compress2_strided_batched_medium
46-
category: pre_checkin
47-
function:
48-
compress_strided_batched: *real_precisions_2b
49-
matrix_size: *strided_batched_medium_matrix_size_range
50-
transA_transB: *transA_transB_range
51-
alpha_beta: *alpha_beta_range
52-
batch_count: [ 3 ]
53-
sparse_b: [ true, false]
44+
func_version: [1, 2]
5445

5546
- name: compress_strided_batched_medium_alt
5647
category: pre_checkin

clients/gtest/compress_strided_batched_gtest_1b.yaml

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,7 @@ Tests:
4141
alpha_beta: *alpha_beta_range
4242
batch_count: [ 3 ]
4343
sparse_b: [ true, false]
44-
45-
- name: compress2_strided_batched_medium
46-
category: pre_checkin
47-
function:
48-
compress_strided_batched: *real_precisions_1b_input
49-
matrix_size: *strided_batched_medium_matrix_size_range
50-
transA_transB: *transA_transB_range
51-
alpha_beta: *alpha_beta_range
52-
batch_count: [ 3 ]
53-
sparse_b: [ true, false]
44+
func_version: [1, 2]
5445

5546
- name: compress_strided_batched_medium_stride_zero
5647
category: nightly

clients/gtest/prune_gtest.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (c) 2022-2023 Advanced Micro Devices, Inc.
5+
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -126,6 +126,9 @@ namespace
126126

127127
if(strstr(arg.function, "_strided_batched") != nullptr)
128128
name << '_' << (arg.sparse_b ? arg.stride_b : arg.stride_a);
129+
130+
if(arg.func_version > 1)
131+
name << "_v" << arg.func_version;
129132
}
130133
return std::move(name);
131134
}

clients/gtest/prune_gtest.yaml

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -352,18 +352,7 @@ Tests:
352352
transA_transB: *transA_transB_range
353353
prune_algo: [ 0, 1 ]
354354
sparse_b: [ true, false]
355-
356-
- name: prune2_256_8_16
357-
category: pre_checkin
358-
function:
359-
prune: *real_precisions_2b
360-
M: 256
361-
N: 8
362-
K: 16
363-
alpha_beta: *alpha_beta_range
364-
transA_transB: *transA_transB_range
365-
prune_algo: [ 0, 1 ]
366-
sparse_b: [ true, false]
355+
func_version: [1, 2]
367356

368357
- name: prune_16_256_8
369358
category: pre_checkin
@@ -376,18 +365,7 @@ Tests:
376365
transA_transB: *transA_transB_range
377366
prune_algo: [ 0, 1 ]
378367
sparse_b: [ true, false]
379-
380-
- name: prune2_16_256_8
381-
category: pre_checkin
382-
function:
383-
prune: *real_precisions_2b
384-
M: 16
385-
N: 256
386-
K: 8
387-
alpha_beta: *alpha_beta_range
388-
transA_transB: *transA_transB_range
389-
prune_algo: [ 0, 1 ]
390-
sparse_b: [ true, false]
368+
func_version: [1, 2]
391369

392370
- name: prune_8_16_256
393371
category: pre_checkin

clients/gtest/prune_strided_batched_gtest.yaml

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,7 @@ Tests:
4444
batch_count: [ 3 ]
4545
prune_algo: [ 0, 1 ]
4646
sparse_b: [true, false]
47-
48-
- name: prune2_strided_batched_medium
49-
category: pre_checkin
50-
function:
51-
prune_strided_batched: *real_precisions_2b
52-
matrix_size: *strided_batched_medium_matrix_size_range
53-
transA_transB: *transA_transB_range
54-
alpha_beta: *alpha_beta_range
55-
batch_count: [ 3 ]
56-
prune_algo: [ 0, 1 ]
57-
sparse_b: [ true, false]
47+
func_version: [1, 2]
5848

5949
- name: prune_strided_batched_medium_alt
6050
category: pre_checkin

clients/gtest/prune_strided_batched_gtest_1b.yaml

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,7 @@ Tests:
4444
batch_count: [ 3 ]
4545
prune_algo: [ 0, 1 ]
4646
sparse_b: [ true, false]
47-
48-
- name: prune2_strided_batched_medium
49-
category: pre_checkin
50-
function:
51-
prune_strided_batched: *real_precisions_1b_input
52-
matrix_size: *strided_batched_medium_matrix_size_range
53-
transA_transB: *transA_transB_range
54-
alpha_beta: *alpha_beta_range
55-
batch_count: [ 3 ]
56-
prune_algo: [ 0, 1 ]
57-
sparse_b: [ true, false]
47+
func_version: [1, 2]
5848

5949
- name: prune_strided_batched_medium_stride_zero
6050
category: nightly

clients/include/hipsparselt_arguments.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
*
33
* MIT License
44
*
5-
* Copyright (c) 2022-2023 Advanced Micro Devices, Inc.
5+
* Copyright (c) 2022-2024 Advanced Micro Devices, Inc.
66
*
77
* Permission is hereby granted, free of charge, to any person obtaining a copy
88
* of this software and associated documentation files (the "Software"), to deal
@@ -124,6 +124,7 @@ struct Arguments
124124
int32_t search_iters;
125125

126126
bool sparse_b;
127+
int func_version;
127128
/*************************************************************************
128129
* End Of Arguments *
129130
*************************************************************************/
@@ -184,8 +185,9 @@ struct Arguments
184185
OPER(c_noalias_d) SEP \
185186
OPER(HMM) SEP \
186187
OPER(search) SEP \
187-
OPER(search_iters) SEP \
188-
OPER(sparse_b) SEP
188+
OPER(search_iters) SEP \
189+
OPER(sparse_b) SEP \
190+
OPER(func_version) SEP
189191

190192
// clang-format on
191193

clients/include/hipsparselt_common.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ Arguments:
124124
- search: c_bool
125125
- search_iters: c_int32
126126
- sparse_b: c_bool
127+
- func_version: c_int32
127128

128129
# These named dictionary lists [ {dict1}, {dict2}, etc. ] supply subsets of
129130
# test arguments in a structured way. The dictionaries are applied to the test
@@ -199,3 +200,4 @@ Defaults:
199200
search: false
200201
search_iters: 10
201202
sparse_b: false
203+
func_version: 1

clients/include/spmm/testing_compress.hpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -311,10 +311,6 @@ template <typename Ti,
311311
hipsparselt_batch_type btype = hipsparselt_batch_type::none>
312312
void testing_compress(const Arguments& arg)
313313
{
314-
int run_version = 1;
315-
if(strstr(arg.name, "compress2") != nullptr)
316-
run_version = 2;
317-
318314
hipsparseOperation_t transA = char_to_hipsparselt_operation(arg.transA);
319315
hipsparseOperation_t transB = char_to_hipsparselt_operation(arg.transB);
320316

@@ -477,13 +473,13 @@ void testing_compress(const Arguments& arg)
477473

478474
hipsparseLtMatmulGetWorkspace(handle, plan, &workspace_size);
479475

480-
if(run_version == 1)
476+
if(arg.func_version == 1)
481477
{
482478
EXPECT_HIPSPARSE_STATUS(
483479
hipsparseLtSpMMACompressedSize(handle, plan, &compressed_size, &compress_buffer_size),
484480
HIPSPARSE_STATUS_SUCCESS);
485481
}
486-
else if(run_version == 2)
482+
else if(arg.func_version == 2)
487483
{
488484
EXPECT_HIPSPARSE_STATUS(
489485
hipsparseLtSpMMACompressedSize2(
@@ -552,14 +548,14 @@ void testing_compress(const Arguments& arg)
552548
// copy data from CPU to device
553549
CHECK_HIP_ERROR(dT.transfer_from(hT));
554550

555-
if(run_version == 1)
551+
if(arg.func_version == 1)
556552
{
557553
EXPECT_HIPSPARSE_STATUS(
558554
hipsparseLtSpMMAPrune(
559555
handle, matmul, dT, dT, hipsparseLtPruneAlg_t(arg.prune_algo), stream),
560556
HIPSPARSE_STATUS_SUCCESS);
561557
}
562-
else if(run_version == 2)
558+
else if(arg.func_version == 2)
563559
{
564560
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPrune2(handle,
565561
arg.sparse_b ? matB : matA,
@@ -619,11 +615,11 @@ void testing_compress(const Arguments& arg)
619615
CHECK_HIP_ERROR(hipStreamSynchronize(stream));
620616
CHECK_HIP_ERROR(hT_pruned.transfer_from(dT));
621617

622-
if(run_version == 1)
618+
if(arg.func_version == 1)
623619
EXPECT_HIPSPARSE_STATUS(
624620
hipsparseLtSpMMACompress(handle, plan, dT, dT_compressd, dT_compressBuffer, stream),
625621
HIPSPARSE_STATUS_SUCCESS);
626-
else if(run_version == 2)
622+
else if(arg.func_version == 2)
627623
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMACompress2(handle,
628624
arg.sparse_b ? matB : matA,
629625
!arg.sparse_b,

clients/include/spmm/testing_prune.hpp

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,6 @@ template <typename Ti,
440440
hipsparselt_batch_type btype = hipsparselt_batch_type::none>
441441
void testing_prune(const Arguments& arg)
442442
{
443-
int run_version = 1;
444-
445-
if(strstr(arg.name, "prune2") != nullptr)
446-
run_version = 2;
447-
448443
hipsparseLtPruneAlg_t prune_algo = hipsparseLtPruneAlg_t(arg.prune_algo);
449444

450445
constexpr bool do_batched = (btype == hipsparselt_batch_type::batched);
@@ -676,11 +671,11 @@ void testing_prune(const Arguments& arg)
676671

677672
if(arg.unit_check || arg.norm_check)
678673
{
679-
if(run_version == 1)
674+
if(arg.func_version == 1)
680675
EXPECT_HIPSPARSE_STATUS(
681676
hipsparseLtSpMMAPrune(handle, matmul, dT, dT_pruned, prune_algo, stream),
682677
HIPSPARSE_STATUS_SUCCESS);
683-
else if(run_version == 2)
678+
else if(arg.func_version == 2)
684679
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPrune2(handle,
685680
arg.sparse_b ? matB : matA,
686681
!arg.sparse_b,
@@ -699,11 +694,11 @@ void testing_prune(const Arguments& arg)
699694
device_vector<int> d_valid(1, 1, HMM);
700695
int h_valid = 0;
701696
//check the pruned matrix is sparisty 50 or not.
702-
if(run_version == 1)
697+
if(arg.func_version == 1)
703698
EXPECT_HIPSPARSE_STATUS(
704699
hipsparseLtSpMMAPruneCheck(handle, matmul, dT_pruned, d_valid, stream),
705700
HIPSPARSE_STATUS_SUCCESS);
706-
else if(run_version == 2)
701+
else if(arg.func_version == 2)
707702
EXPECT_HIPSPARSE_STATUS(hipsparseLtSpMMAPruneCheck2(handle,
708703
arg.sparse_b ? matB : matA,
709704
!arg.sparse_b,

0 commit comments

Comments
 (0)