Skip to content

Commit 7d1c6ad

Browse files
committed
softmax for ranks from 1 to 4
1 parent 2ff7029 commit 7d1c6ad

File tree

1 file changed

+112
-7
lines changed

1 file changed

+112
-7
lines changed

src/stdlib_math_activations.fypp

Lines changed: 112 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -104,14 +104,20 @@ module stdlib_math_activations
104104

105105
interface Softmax
106106
#:for rk, rt in REAL_KINDS_TYPES
107-
module procedure :: softmax_${rk}$
107+
module procedure :: Softmax_r1_${rk}$
108+
module procedure :: Softmax_r2_${rk}$
109+
module procedure :: Softmax_r3_${rk}$
110+
module procedure :: Softmax_r4_${rk}$
108111
#:endfor
109112
end interface
110113
public :: softmax
111114

112115
interface Softmax_grad
113116
#:for rk, rt in REAL_KINDS_TYPES
114-
module procedure :: Softmax_grad_${rk}$
117+
module procedure :: Softmax_grad_r1_${rk}$
118+
module procedure :: Softmax_grad_r2_${rk}$
119+
module procedure :: Softmax_grad_r3_${rk}$
120+
module procedure :: Softmax_grad_r4_${rk}$
115121
#:endfor
116122
end interface
117123
public :: Softmax_grad
@@ -315,19 +321,118 @@ end function
315321
! Softmax
316322
!==================================================
317323
#:for rk, rt in REAL_KINDS_TYPES
318-
pure function Softmax_${rk}$( x ) result( y )
324+
pure function Softmax_r1_${rk}$( x ) result( y )
319325
${rt}$, intent(in) :: x(:)
320326
${rt}$ :: y(size(x))
321327

322-
y(:) = exp(x(:) - maxval(x(:)) )
323-
y(:) = y(:) / sum(y(:))
328+
y = exp(x - maxval(x))
329+
y = y / sum(y)
324330
end function
325331

326-
pure function Softmax_grad_${rk}$( x ) result( y )
332+
pure function Softmax_r2_${rk}$( x , dim ) result( y )
333+
${rt}$, intent(in) :: x(:,:)
334+
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
335+
336+
integer, intent(in), optional :: dim
337+
integer :: dim_, j
338+
339+
dim_ = 1; if(present(dim)) dim_ = dim
340+
341+
if(dim_==1)then
342+
do j = 1, size(x,dim=2)
343+
y(:,j) = Softmax( x(:,j) )
344+
end do
345+
else
346+
do j = 1, size(x,dim=1)
347+
y(j,:) = Softmax( x(j,:) )
348+
end do
349+
end if
350+
end function
351+
352+
pure function Softmax_r3_${rk}$( x , dim ) result( y )
353+
${rt}$, intent(in) :: x(:,:,:)
354+
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
355+
356+
integer, intent(in), optional :: dim
357+
integer :: dim_, j
358+
359+
dim_ = 1; if(present(dim)) dim_ = dim
360+
361+
if(dim_<=2)then
362+
do j = 1, size(x,dim=3)
363+
y(:,:,j) = Softmax( x(:,:,j) , dim = dim_ )
364+
end do
365+
else
366+
do j = 1, size(x,dim=1)
367+
y(j,:,:) = Softmax( x(j,:,:) , dim = 2 )
368+
end do
369+
end if
370+
end function
371+
372+
pure function Softmax_r4_${rk}$( x , dim ) result( y )
373+
${rt}$, intent(in) :: x(:,:,:,:)
374+
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
375+
376+
integer, intent(in), optional :: dim
377+
integer :: dim_, j
378+
379+
dim_ = 1; if(present(dim)) dim_ = dim
380+
381+
if(dim_<=3)then
382+
do j = 1, size(x,dim=4)
383+
y(:,:,:,j) = Softmax( x(:,:,:,j) , dim = dim_ )
384+
end do
385+
else
386+
do j = 1, size(x,dim=1)
387+
y(j,:,:,:) = Softmax( x(j,:,:,:) , dim = 3 )
388+
end do
389+
end if
390+
end function
391+
392+
pure function Softmax_grad_r1_${rk}$( x ) result( y )
327393
${rt}$, intent(in) :: x(:)
328394
${rt}$ :: y(size(x))
329395

330-
y = softmax_${rk}$(x)
396+
y = Softmax(x)
397+
y = y * (1_${rk}$ - y)
398+
end function
399+
400+
pure function Softmax_grad_r2_${rk}$( x , dim ) result( y )
401+
${rt}$, intent(in) :: x(:,:)
402+
${rt}$ :: y(size(x,dim=1),size(x,dim=2))
403+
404+
integer, intent(in), optional :: dim
405+
integer :: dim_
406+
407+
dim_ = 1; if(present(dim)) dim_ = dim
408+
409+
y = Softmax(x,dim_)
410+
y = y * (1_${rk}$ - y)
411+
end function
412+
413+
pure function Softmax_grad_r3_${rk}$( x , dim ) result( y )
414+
${rt}$, intent(in) :: x(:,:,:)
415+
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3))
416+
417+
integer, intent(in), optional :: dim
418+
integer :: dim_
419+
420+
dim_ = 1; if(present(dim)) dim_ = dim
421+
422+
y = Softmax(x,dim_)
423+
y = y * (1_${rk}$ - y)
424+
end function
425+
426+
pure function Softmax_grad_r4_${rk}$( x , dim ) result( y )
427+
${rt}$, intent(in) :: x(:,:,:)
428+
${rt}$ :: y(size(x,dim=1),size(x,dim=2),size(x,dim=3),size(x,dim=4))
429+
430+
integer, intent(in), optional :: dim
431+
integer :: dim_
432+
433+
dim_ = 1; if(present(dim)) dim_ = dim
434+
435+
y = Softmax(x,dim_)
331436
y = y * (1_${rk}$ - y)
332437
end function
333438

0 commit comments

Comments
 (0)