Skip to content

Commit de7cf17

Browse files
authored
【Hackathon 6th No.31】paddle.distribution.Normal/paddle.nn.initializer.Normal/paddle.randn/paddle.standard_normal support complex normal d (#6735)
1 parent dce5c33 commit de7cf17

File tree

4 files changed

+31
-7
lines changed

4 files changed

+31
-7
lines changed

docs/api/paddle/distribution/Normal_cn.rst

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,22 @@ Normal
88
99
正态分布
1010

11-
数学公式
11+
若 `loc` 是实数,概率密度函数为
1212

1313
.. math::
1414
1515
pdf(x; \mu, \sigma) = \frac{1}{Z}e^{\frac {-0.5 (x - \mu)^2} {\sigma^2} }
1616
1717
Z = (2 \pi \sigma^2)^{0.5}
1818
19+
若 `loc` 是复数,概率密度函数为:
20+
21+
.. math::
22+
23+
pdf(x; \mu, \sigma) = \frac{1}{Z}e^{\frac {-(x - \mu)^2} {\sigma^2} }
24+
25+
Z = \pi \sigma^2
26+
1927
上面的数学公式中:
2028

2129
- :math:`loc = \mu`:平均值;
@@ -25,7 +33,7 @@ Normal
2533
参数
2634
::::::::::::
2735

28-
- **loc** (int|float|list|tuple|numpy.ndarray|Tensor) - 正态分布平均值。数据类型为 float32 或 float64
36+
- **loc** (int|float|complex|list|tuple|numpy.ndarray|Tensor) - 正态分布平均值。数据类型为 float32、float64、complex64complex128
2937
- **scale** (int|float|list|tuple|numpy.ndarray|Tensor) - 正态分布标准差。数据类型为 float32 或 float64。
3038
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
3139

@@ -85,12 +93,18 @@ entropy()
8593

8694
信息熵
8795

88-
数学公式
96+
实高斯分布信息熵的数学公式
8997

9098
.. math::
9199
92100
entropy(\sigma) = 0.5 \log (2 \pi e \sigma^2)
93101
102+
复高斯分布信息熵的数学公式:
103+
104+
.. math::
105+
106+
entropy(\sigma) = \log (\pi e \sigma^2) + 1
107+
94108
上面的数学公式中:
95109

96110
:math:`scale = \sigma`:标准差。
@@ -130,7 +144,7 @@ kl_divergence(other)
130144

131145
两个正态分布之间的 KL 散度。
132146

133-
数学公式
147+
实高斯分布 KL 散度的数学公式
134148

135149
.. math::
136150
@@ -140,6 +154,16 @@ kl_divergence(other)
140154
141155
diff = \mu_1 - \mu_0
142156
157+
复高斯分布 KL 散度的数学公式:
158+
159+
.. math::
160+
161+
KL\_divergence(\mu_0, \sigma_0; \mu_1, \sigma_1) = ratio^2 + (\frac{diff}{\sigma_1})^2 - 1 - 2 \ln {ratio}
162+
163+
ratio = \frac{\sigma_0}{\sigma_1}
164+
165+
diff = \mu_1 - \mu_0
166+
143167
上面的数学公式中:
144168

145169
- :math:`loc = \mu_0`:当前正态分布的平均值;

docs/api/paddle/nn/initializer/Normal_cn.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Normal
1111
参数
1212
::::::::::::
1313

14-
- **mean** (float,可选) - 正态分布的平均值。默认值为 0。
14+
- **mean** (float|complex,可选) - 正态分布的平均值。默认值为 0。
1515
- **std** (float,可选) - 正态分布的标准差。默认值为 1.0。
1616
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
1717

docs/api/paddle/randn_cn.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ randn
1010
参数
1111
::::::::::
1212
- **shape** (list|tuple|Tensor) - 生成的随机 Tensor 的形状。如果 ``shape`` 是 list、tuple,则其中的元素可以是 int,或者是形状为[]且数据类型为 int32、int64 的 0-D Tensor。如果 ``shape`` 是 Tensor,则是数据类型为 int32、int64 的 1-D Tensor。
13-
- **dtype** (str|np.dtype,可选) - 输出 Tensor 的数据类型,支持 float32、float64。当该参数值为 None 时,输出 Tensor 的数据类型为 float32。默认值为 None。
13+
- **dtype** (str|np.dtype,可选) - 输出 Tensor 的数据类型,支持 float32、float64、complex64、complex128。当该参数值为 None 时,输出 Tensor 的数据类型为 float32。默认值为 None。
1414
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
1515

1616
返回

docs/api/paddle/standard_normal_cn.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ standard_normal
1010
参数
1111
::::::::::
1212
- **shape** (list|tuple|Tensor) - 生成的随机 Tensor 的形状。如果 ``shape`` 是 list、tuple,则其中的元素可以是 int,或者是形状为[]且数据类型为 int32、int64 的 0-D Tensor。如果 ``shape`` 是 Tensor,则是数据类型为 int32、int64 的 1-D Tensor。
13-
- **dtype** (str|np.dtype,可选) - 输出 Tensor 的数据类型,支持 float32、float64。当该参数值为 None 时,输出 Tensor 的数据类型为 float32。默认值为 None。
13+
- **dtype** (str|np.dtype,可选) - 输出 Tensor 的数据类型,支持 float32、float64、complex64、complex128。当该参数值为 None 时,输出 Tensor 的数据类型为 float32。默认值为 None。
1414
- **name** (str,可选) - 具体用法请参见 :ref:`api_guide_Name`,一般无需设置,默认值为 None。
1515

1616
返回

0 commit comments

Comments
 (0)