1+ #include < cstdio>
2+ #include < cuda_runtime.h>
3+ #include < random>
4+
5+ __global__ void saxpy (int n, float a, float *x, float *y){
6+ // threadIdx.x: thread index within the block
7+ // blockIdx.x: block index within the grid
8+ // blockDim.x: number of threads per block
9+ // gridDim.x: number of blocks in the grid
10+
11+ // global_id: unique index for each thread in the entire grid
12+ int global_id = threadIdx .x + blockDim .x * blockIdx .x ;
13+
14+ // Example: gridDim.x = 2, blockDim.x = 4
15+
16+ // Block 0: threadIdx.x = [0,1,2,3] → global_id = [0,1,2,3]
17+ // Block 1: threadIdx.x = [0,1,2,3] → global_id = [4,5,6,7]
18+
19+ // stride: total number of threads in the grid
20+ int stride = blockDim .x * gridDim .x ;
21+
22+ // Each thread processes multiple elements, striding by the total number of threads
23+ // Striding ensures all elements are processed even if n > total threads
24+ for (int i=global_id; i < n; i += stride)
25+ {
26+ y[i] = a * x[i] + y[i];
27+ }
28+ }
29+
30+ int main () {
31+ // Set up data
32+ const int N = 100 ;
33+ float alpha = 3 .14f ;
34+ float *h_x, *h_y;
35+ float *d_x, *d_y;
36+ size_t size = N * sizeof (float );
37+
38+ // Allocate device memory
39+ cudaMalloc (&d_x, size);
40+ cudaMalloc (&d_y, size);
41+
42+ // Initialize host data
43+ h_x = (float *)malloc (size);
44+ h_y = (float *)malloc (size);
45+
46+ for (int i = 0 ; i < N; i++) {
47+ h_x[i] = rand () / (float )RAND_MAX;
48+ h_y[i] = rand () / (float )RAND_MAX;
49+ }
50+
51+ // Copy data to device
52+ cudaMemcpy (d_x, h_x, size, cudaMemcpyHostToDevice);
53+ cudaMemcpy (d_y, h_y, size, cudaMemcpyHostToDevice);
54+
55+ // Define block size (number of threads per block)
56+ int blockSize = 4 ;
57+
58+ // Calculate number of blocks needed
59+ int numBlocks = (N + blockSize - 1 ) / blockSize;
60+
61+ // Launch kernel
62+ saxpy<<<numBlocks, blockSize>>> (N, alpha, d_x, d_y);
63+ cudaDeviceSynchronize ();
64+
65+ // Copy result back to host
66+ cudaMemcpy (h_y, d_y, size, cudaMemcpyDeviceToHost);
67+ for (int i = 0 ; i < N; i++) {
68+ printf (" h_y[%d] = %f\n " , i, h_y[i]);
69+ }
70+
71+ // Clean up
72+ free (h_x);
73+ free (h_y);
74+ cudaFree (d_x);
75+ cudaFree (d_y);
76+
77+ return 0 ;
78+ }
0 commit comments