@@ -104,14 +104,20 @@ module stdlib_math_activations
104
104
105
105
interface Softmax
106
106
#:for rk, rt in REAL_KINDS_TYPES
107
- module procedure :: softmax_${rk}$
107
+ module procedure :: Softmax_r1_${rk}$
108
+ module procedure :: Softmax_r2_${rk}$
109
+ module procedure :: Softmax_r3_${rk}$
110
+ module procedure :: Softmax_r4_${rk}$
108
111
#:endfor
109
112
end interface
110
113
public :: softmax
111
114
112
115
interface Softmax_grad
113
116
#:for rk, rt in REAL_KINDS_TYPES
114
- module procedure :: Softmax_grad_${rk}$
117
+ module procedure :: Softmax_grad_r1_${rk}$
118
+ module procedure :: Softmax_grad_r2_${rk}$
119
+ module procedure :: Softmax_grad_r3_${rk}$
120
+ module procedure :: Softmax_grad_r4_${rk}$
115
121
#:endfor
116
122
end interface
117
123
public :: Softmax_grad
@@ -315,19 +321,118 @@ end function
315
321
! Softmax
316
322
!==================================================
317
323
#:for rk, rt in REAL_KINDS_TYPES
318
- pure function Softmax_ ${rk}$( x ) result( y )
324
+ pure function Softmax_r1_ ${rk}$( x ) result( y )
319
325
${rt}$, intent(in) :: x(:)
320
326
${rt}$ :: y(size(x))
321
327
322
- y(:) = exp(x(:) - maxval(x(:)) )
323
- y(:) = y(:) / sum(y(:) )
328
+ y = exp(x - maxval(x) )
329
+ y = y / sum(y)
324
330
end function
325
331
326
- pure function Softmax_grad_${rk}$( x ) result( y )
332
+ pure function Softmax_r2_${rk}$( x , dim ) result( y )
333
+ ${rt}$, intent(in) :: x(:,:)
334
+ ${rt}$ :: y(size(x,dim=1),size(x,dim=2))
335
+
336
+ integer, intent(in), optional :: dim
337
+ integer :: dim_, j
338
+
339
+ dim_ = 1; if(present(dim)) dim_ = dim
340
+
341
+ if(dim_==1)then
342
+ do j = 1, size(x,dim=2)
343
+ y(:,j) = Softmax( x(:,j) )
344
+ end do
345
+ else
346
+ do j = 1, size(x,dim=1)
347
+ y(j,:) = Softmax( x(j,:) )
348
+ end do
349
+ end if
350
+ end function
351
+
352
+ pure function Softmax_r3_${rk}$( x , dim ) result( y )
353
+ ${rt}$, intent(in) :: x(:,:,:)
354
+ ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
355
+
356
+ integer, intent(in), optional :: dim
357
+ integer :: dim_, j
358
+
359
+ dim_ = 1; if(present(dim)) dim_ = dim
360
+
361
+ if(dim_<=2)then
362
+ do j = 1, size(x,dim=3)
363
+ y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ )
364
+ end do
365
+ else
366
+ do j = 1, size(x,dim=1)
367
+ y(j,:,:) = Softmax( x(j,:,:) , dim = 2 )
368
+ end do
369
+ end if
370
+ end function
371
+
372
+ pure function Softmax_r4_${rk}$( x , dim ) result( y )
373
+ ${rt}$, intent(in) :: x(:,:,:,:)
374
+ ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
375
+
376
+ integer, intent(in), optional :: dim
377
+ integer :: dim_, j
378
+
379
+ dim_ = 1; if(present(dim)) dim_ = dim
380
+
381
+ if(dim_<=3)then
382
+ do j = 1, size(x,dim=4)
383
+ y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ )
384
+ end do
385
+ else
386
+ do j = 1, size(x,dim=1)
387
+ y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 )
388
+ end do
389
+ end if
390
+ end function
391
+
392
+ pure function Softmax_grad_r1_${rk}$( x ) result( y )
327
393
${rt}$, intent(in) :: x(:)
328
394
${rt}$ :: y(size(x))
329
395
330
- y = softmax_${rk}$(x)
396
+ y = Softmax(x)
397
+ y = y * (1_${rk}$ - y)
398
+ end function
399
+
400
+ pure function Softmax_grad_r2_${rk}$( x , dim ) result( y )
401
+ ${rt}$, intent(in) :: x(:,:)
402
+ ${rt}$ :: y(size(x,dim=1),size(x,dim=2))
403
+
404
+ integer, intent(in), optional :: dim
405
+ integer :: dim_
406
+
407
+ dim_ = 1; if(present(dim)) dim_ = dim
408
+
409
+ y = Softmax(x,dim_)
410
+ y = y * (1_${rk}$ - y)
411
+ end function
412
+
413
+ pure function Softmax_grad_r3_${rk}$( x , dim ) result( y )
414
+ ${rt}$, intent(in) :: x(:,:,:)
415
+ ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
416
+
417
+ integer, intent(in), optional :: dim
418
+ integer :: dim_
419
+
420
+ dim_ = 1; if(present(dim)) dim_ = dim
421
+
422
+ y = Softmax(x,dim_)
423
+ y = y * (1_${rk}$ - y)
424
+ end function
425
+
426
+ pure function Softmax_grad_r4_${rk}$( x , dim ) result( y )
427
+ ${rt}$, intent(in) :: x(:,:,:)
428
+ ${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
429
+
430
+ integer, intent(in), optional :: dim
431
+ integer :: dim_
432
+
433
+ dim_ = 1; if(present(dim)) dim_ = dim
434
+
435
+ y = Softmax(x,dim_)
331
436
y = y * (1_${rk}$ - y)
332
437
end function
333
438
0 commit comments