Commit e7cb7bc
add montecarlo integrate api (#932)
* add montecarlo integrate api
* Update ppsci/experimental/math_module.py
* Update ppsci/experimental/math_module.py
* Update ppsci/experimental/math_module.py
* Update ppsci/experimental/math_module.py
* Update ppsci/experimental/math_module.py
* fix code format error

---------

Co-authored-by: HydrogenSulfate <490868991@qq.com>
1 parent 53c4ede commit e7cb7bc

4 files changed: +240, -0 lines changed

docs/zh/api/experimental.md

Lines changed: 1 addition & 0 deletions

@@ -15,5 +15,6 @@
 - fractional_diff
 - gaussian_integrate
 - trapezoid_integrate
+- montecarlo_integrate
 show_root_heading: true
 heading_level: 3

ppsci/experimental/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -22,6 +22,7 @@
 from ppsci.experimental.math_module import bessel_i1e
 from ppsci.experimental.math_module import fractional_diff
 from ppsci.experimental.math_module import gaussian_integrate
+from ppsci.experimental.math_module import montecarlo_integrate
 from ppsci.experimental.math_module import trapezoid_integrate

 __all__ = [
@@ -32,4 +33,5 @@
     "fractional_diff",
     "gaussian_integrate",
     "trapezoid_integrate",
+    "montecarlo_integrate",
 ]
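With the two lines added above, the new function is re-exported at the ppsci.experimental package level next to gaussian_integrate and trapezoid_integrate. A minimal sanity check of the re-export (illustrative only, not part of the commit, assuming a ppsci build that contains this change):

import ppsci.experimental
from ppsci.experimental.math_module import montecarlo_integrate

# The package-level attribute points at the same object as the module-level definition.
assert ppsci.experimental.montecarlo_integrate is montecarlo_integrate
# __all__ now advertises the new API alongside the other integrators.
print("montecarlo_integrate" in ppsci.experimental.__all__)  # True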

ppsci/experimental/math_module.py

Lines changed: 182 additions & 0 deletions

@@ -18,7 +18,9 @@
 from typing import Any
 from typing import Callable
 from typing import List
+from typing import Optional
 from typing import Tuple
+from typing import Union

 import numpy as np
 import paddle
@@ -462,3 +464,183 @@ def trapezoid_integrate(
         return paddle.cumulative_trapezoid(y, x, dx, axis)
     else:
         raise ValueError(f'mode should be "sum" or "cumsum", but got {mode}')
+
+
+def montecarlo_integrate(
+    fn: Callable,
+    dim: int,
+    N: int = 1000,
+    integration_domain: Union[List[List[float]], paddle.Tensor] = None,
+    seed: int = None,
+):
+    """Integrates the passed function on the passed domain using vanilla Monte
+    Carlo Integration.
+
+    Args:
+        fn (Callable): The function to integrate over.
+        dim (int): Dimensionality of the function's domain over which to
+            integrate.
+        N (Optional[int]): Number of sample points to use for the integration.
+            Defaults to 1000.
+        integration_domain (Union[List[List[float]], paddle.Tensor]): Integration
+            domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim.
+        seed (Optional[int]): Random number generation seed for the sampling
+            point creation, only set if provided. Defaults to None.
+
+    Raises:
+        ValueError: If len(integration_domain) != dim.
+
+    Returns:
+        Integral value.
+
+    Examples:
+        >>> import paddle
+        >>> import ppsci
+
+        >>> _ = paddle.seed(1024)
+        >>> # The function we want to integrate, in this example
+        >>> # f(x0,x1) = sin(x0) + e^x1 for x0=[0,1] and x1=[-1,1]
+        >>> # Note that the function needs to support multiple evaluations at once (first
+        >>> # dimension of x here)
+        >>> # Expected result here is ~3.2698
+        >>> def some_function(x):
+        ...     return paddle.sin(x[:, 0]) + paddle.exp(x[:, 1])
+
+        >>> # Compute the function integral by sampling 10000 points over domain
+        >>> integral_value = ppsci.experimental.montecarlo_integrate(
+        ...     some_function,
+        ...     dim=2,
+        ...     N=10000,
+        ...     integration_domain=[[0, 1], [-1, 1]],
+        ... )
+
+        >>> print(integral_value)
+        Tensor(shape=[], dtype=float32, place=Place(gpu:0), stop_gradient=True,
+               3.25152588)
+    """
+
+    @expand_func_values_and_squeeze_integral
+    def calculate_result(function_values, integration_domain):
+        """Calculate an integral result from the function evaluations.
+
+        Args:
+            function_values (paddle.Tensor): Output of the integrand.
+            integration_domain (paddle.Tensor): Integration domain.
+
+        Returns:
+            Quadrature result.
+        """
+        scales = integration_domain[:, 1] - integration_domain[:, 0]
+        volume = paddle.prod(scales)
+
+        # Integral = V / N * sum(func values)
+        N = function_values.shape[0]
+        integral = volume * paddle.sum(function_values, axis=0) / N
+        return integral
+
+    def calculate_sample_points(
+        N: int, integration_domain: paddle.Tensor, seed: Optional[int] = None
+    ):
+        """Calculate random points for the integrand evaluation.
+
+        Args:
+            N (int): Number of points.
+            integration_domain (paddle.Tensor): Integration domain.
+            seed (int, optional): Random number generation seed for the sampling point creation, only set if provided. Defaults to None.
+        Returns:
+            Sample points.
+        """
+        dim = integration_domain.shape[0]
+        domain_starts = integration_domain[:, 0]
+        domain_sizes = integration_domain[:, 1] - domain_starts
+        # Scale and translate random numbers via broadcasting
+        return (
+            paddle.uniform(
+                shape=[N, dim],
+                dtype=domain_sizes.dtype,
+                min=0.0,
+                max=1.0,
+                seed=seed or 0,
+            )
+            * domain_sizes
+            + domain_starts
+        )
+
+    if dim is not None:
+        if dim < 1:
+            raise ValueError("Dimension needs to be 1 or larger.")
+    if N is not None:
+        if N < 1 or type(N) is not int:
+            raise ValueError("N has to be a positive integer.")
+
+    integration_domain = _setup_integration_domain(dim, integration_domain)
+    sample_points = calculate_sample_points(N, integration_domain, seed)
+    function_values, _ = _evaluate_integrand(fn, sample_points)
+    return calculate_result(function_values, integration_domain)
+
+
+def _setup_integration_domain(
+    dim: int, integration_domain: Union[List[List[float]], paddle.Tensor]
+) -> paddle.Tensor:
+    """Sets up the integration domain if unspecified by the user.
+    Args:
+        dim (int): Dimensionality of the integration domain.
+        integration_domain (List or Tensor): Integration domain, e.g. [[-1,1],[0,1]]. Defaults to [-1,1]^dim.
+
+    Returns:
+        Integration domain.
+    """
+    # If no integration_domain is specified, create [-1,1]^d bounds
+    if integration_domain is None:
+        integration_domain = [[-1.0, 1.0]] * dim
+
+    integration_domain = [[float(b) for b in bounds] for bounds in integration_domain]
+
+    integration_domain = paddle.to_tensor(integration_domain)
+
+    if tuple(integration_domain.shape) != (dim, 2):
+        raise ValueError(
+            "The integration domain has an unexpected shape. "
+            f"Expected {(dim, 2)}, got {integration_domain.shape}"
+        )
+    return integration_domain
+
+
+def _evaluate_integrand(fn, points, weights=None, args=None):
+    """Evaluate the integrand function at the passed points.
+
+    Args:
+        fn (Callable): Integrand function.
+        points (paddle.Tensor): Integration points.
+        weights (Optional[paddle.Tensor]): Integration weights. Defaults to None.
+        args (Optional[List, Tuple]): Any arguments required by the function. Defaults to None.
+
+    Returns:
+        paddle.Tensor: Integrand function output.
+        int: Number of evaluated points.
+    """
+    num_points = points.shape[0]
+
+    if args is None:
+        args = ()
+
+    result = fn(points, *args)
+    num_results = result.shape[0]
+    if num_results != num_points:
+        raise ValueError(
+            f"The passed function was given {num_points} points but only returned {num_results} value(s). "
+            f"Please ensure that your function is vectorized, i.e. can be called with multiple evaluation points at once. It should return a tensor "
+            f"where first dimension matches length of passed elements. "
+        )
+
+    if weights is not None:
+        if (
+            len(result.shape) > 1
+        ):  # if the integrand is multi-dimensional, we need to reshape/repeat weights so they can be broadcast in the *=
+            integrand_shape = paddle.to_tensor(result.shape[1:])
+            weights = paddle.repeat(
+                paddle.unsqueeze(weights, axis=1), paddle.prod(integrand_shape)
+            ).reshape((weights.shape[0], *(integrand_shape)))
+        result *= weights
+
+    return result, num_points
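For reference, the quadrature rule implemented by calculate_result above is the standard vanilla Monte Carlo estimator. Writing out the in-code comment "Integral = V / N * sum(func values)" as a formula (explanatory only, not part of the diff), with [a_k, b_k] the per-dimension bounds from integration_domain and x_i the N uniform samples drawn by calculate_sample_points:

\int_{\Omega} f(x)\,\mathrm{d}x \;\approx\; \frac{V(\Omega)}{N}\sum_{i=1}^{N} f(x_i),
\qquad
V(\Omega) = \prod_{k=1}^{d}\left(b_k - a_k\right),
\qquad
x_i \sim \mathcal{U}(\Omega).

The statistical error of this estimator shrinks like O(N^{-1/2}) regardless of the dimension d, which is the usual reason to prefer Monte Carlo over tensor-product rules such as gaussian_integrate when d is large.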
Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable
+from typing import List
+
+import numpy as np
+import paddle
+import pytest
+
+import ppsci
+
+paddle.seed(1024)
+
+
+@pytest.mark.parametrize(
+    "fn, dim, N, integration_domains, expected",
+    [
+        (
+            lambda x: paddle.sin(x[:, 0]) + paddle.exp(x[:, 1]),
+            2,
+            10000,
+            [[0, 1], [-1, 1]],
+            3.25152588,
+        )
+    ],
+)
+def test_montecarlo_integrate(
+    fn: Callable,
+    dim: int,
+    N: int,
+    integration_domains: List[List[float]],
+    expected: float,
+):
+    assert np.allclose(
+        ppsci.experimental.montecarlo_integrate(
+            fn, dim, N, integration_domains
+        ).numpy(),
+        expected,
+    )
+
+
+if __name__ == "__main__":
+    pytest.main()
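The expected value 3.25152588 in the test above is tied to the fixed global seed set by paddle.seed(1024); the analytic value of the integral is 2 * (1 - cos(1)) + (e - 1/e) ≈ 3.2699. A rough convergence check in the same spirit (a sketch, not part of the commit; the printed estimates will vary with seed, device, and paddle version):

import numpy as np
import paddle

import ppsci

paddle.seed(1024)


def f(x):
    # Same integrand as the test: f(x0, x1) = sin(x0) + exp(x1) on [0, 1] x [-1, 1].
    return paddle.sin(x[:, 0]) + paddle.exp(x[:, 1])


# Analytic reference: 2 * integral of sin over [0, 1] plus integral of exp over [-1, 1].
exact = 2.0 * (1.0 - np.cos(1.0)) + (np.e - 1.0 / np.e)

# The Monte Carlo error should shrink roughly like 1 / sqrt(N).
for n in (100, 1_000, 10_000, 100_000):
    estimate = ppsci.experimental.montecarlo_integrate(
        f, dim=2, N=n, integration_domain=[[0, 1], [-1, 1]]
    )
    print(n, float(estimate), abs(float(estimate) - exact))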
