8
8
from numpy .testing import (
9
9
assert_ ,
10
10
assert_allclose ,
11
- assert_almost_equal ,
12
11
assert_equal ,
13
12
assert_raises ,
14
13
)
15
14
16
15
import numpy_financial as npf
17
16
17
+ def assert_decimal_close (actual , expected , tol = Decimal ("1e-7" )):
18
+ # Check if both actual and expected are iterable (like arrays)
19
+ if hasattr (actual , "__iter__" ) and hasattr (expected , "__iter__" ):
20
+ for a , e in zip (actual , expected ):
21
+ assert abs (a - e ) <= tol
22
+ else :
23
+ # For single value comparisons
24
+ assert abs (actual - expected ) <= tol
18
25
19
26
class TestFinancial (object ):
20
27
def test_when (self ):
@@ -91,7 +98,7 @@ def test_decimal_with_when(self):
91
98
92
99
class TestPV :
93
100
def test_pv (self ):
94
- assert_almost_equal (npf .pv (0.07 , 20 , 12000 , 0 ), - 127128.17 , 2 )
101
+ assert_allclose (npf .pv (0.07 , 20 , 12000 , 0 ), - 127128.17 , rtol = 1e- 2 )
95
102
96
103
def test_pv_decimal (self ):
97
104
assert_equal (npf .pv (Decimal ('0.07' ), Decimal ('20' ), Decimal ('12000' ),
@@ -101,7 +108,7 @@ def test_pv_decimal(self):
101
108
102
109
class TestRate :
103
110
def test_rate (self ):
104
- assert_almost_equal (npf .rate (10 , 0 , - 3500 , 10000 ), 0.1107 , 4 )
111
+ assert_allclose (npf .rate (10 , 0 , - 3500 , 10000 ), 0.1107 , rtol = 1e- 4 )
105
112
106
113
@pytest .mark .parametrize ('number_type' , [Decimal , float ])
107
114
@pytest .mark .parametrize ('when' , [0 , 1 , 'end' , 'begin' ])
@@ -163,9 +170,9 @@ def test_rate_maximum_iterations_exception_array(self):
163
170
164
171
class TestNpv :
165
172
def test_npv (self ):
166
- assert_almost_equal (
173
+ assert_allclose (
167
174
npf .npv (0.05 , [- 15000.0 , 1500.0 , 2500.0 , 3500.0 , 4500.0 , 6000.0 ]),
168
- 122.89 , 2 )
175
+ 122.89 , rtol = 1e- 2 )
169
176
170
177
def test_npv_decimal (self ):
171
178
assert_equal (
@@ -282,7 +289,8 @@ def test_mirr(self, values, finance_rate, reinvest_rate, expected):
282
289
283
290
if expected :
284
291
decimal_part_len = len (str (expected ).split ('.' )[1 ])
285
- assert_almost_equal (result , expected , decimal_part_len )
292
+ difference = 10 ** - decimal_part_len
293
+ assert_allclose (result , expected , atol = difference )
286
294
else :
287
295
assert_ (numpy .isnan (result ))
288
296
@@ -319,7 +327,7 @@ def test_mirr_decimal(self, number_type, args, expected):
319
327
)
320
328
321
329
if expected is not numpy .nan :
322
- assert_almost_equal (result , number_type (expected ), 15 )
330
+ assert_decimal_close (result , number_type (expected ), tol = 1e- 15 )
323
331
else :
324
332
assert numpy .isnan (result )
325
333
@@ -357,7 +365,7 @@ def test_no_interest(self):
357
365
assert_ (npf .nper (0 , - 100 , 1000 ) == 10 )
358
366
359
367
def test_broadcast (self ):
360
- assert_almost_equal (npf .nper (0.075 , - 2000 , 0 , 100000. , [0 , 1 ]),
368
+ assert_allclose (npf .nper (0.075 , - 2000 , 0 , 100000. , [0 , 1 ]),
361
369
[21.5449442 , 20.76156441 ], 4 )
362
370
363
371
@@ -409,10 +417,10 @@ def test_when_is_begin_decimal(self, when):
409
417
Decimal ('0' ),
410
418
when
411
419
)
412
- assert_almost_equal (
420
+ assert_decimal_close (
413
421
result ,
414
422
Decimal ('-302.131703' ), # Computed using Google Sheet's PPMT
415
- decimal = 5 ,
423
+ tol = 1e- 5 ,
416
424
)
417
425
418
426
@pytest .mark .parametrize ('when' , [None , Decimal ('0' ), 'end' ])
@@ -425,10 +433,10 @@ def test_when_is_end_decimal(self, when):
425
433
Decimal ('0' )
426
434
)
427
435
result = npf .ppmt (* args ) if when is None else npf .ppmt (* args , when )
428
- assert_almost_equal (
436
+ assert_decimal_close (
429
437
result ,
430
438
Decimal ('-204.145914' ), # Computed using Google Sheet's PPMT
431
- decimal = 5 ,
439
+ tol = 1e- 5 ,
432
440
)
433
441
434
442
@pytest .mark .parametrize ('args' , [
@@ -481,7 +489,7 @@ def test_broadcast_decimal(self, when, desired):
481
489
Decimal ('0' )
482
490
)
483
491
result = npf .ppmt (* args ) if when is None else npf .ppmt (* args , when )
484
- assert_almost_equal (result , desired , decimal = 8 )
492
+ assert_decimal_close (result , desired , tol = 1e- 8 )
485
493
486
494
487
495
class TestIpmt :
@@ -532,7 +540,7 @@ def test_when_is_end_decimal(self, when):
532
540
Decimal ('0' )
533
541
)
534
542
result = npf .ipmt (* args ) if when is None else npf .ipmt (* args , when )
535
- assert_almost_equal (result , desired , decimal = 5 )
543
+ assert_decimal_close (result , desired , tol = 1e- 5 )
536
544
537
545
@pytest .mark .parametrize ('per, desired' , [
538
546
(0 , numpy .nan ),
@@ -576,7 +584,7 @@ def test_decimal_broadcasting(self):
576
584
Decimal ('24' ),
577
585
Decimal ('2000' )
578
586
)
579
- assert_almost_equal (result , desired , decimal = 4 )
587
+ assert_decimal_close (result , desired , tol = 1e- 4 )
580
588
581
589
def test_0d_inputs (self ):
582
590
args = (0.1 / 12 , 1 , 24 , 2000 )
@@ -596,10 +604,10 @@ def test_float(self):
596
604
)
597
605
598
606
def test_decimal (self ):
599
- assert_almost_equal (
607
+ assert_decimal_close (
600
608
npf .fv (Decimal ('0.075' ), Decimal ('20' ), Decimal ('-2000' ), 0 , 0 ),
601
609
Decimal ('86609.36267304300040536731624' ),
602
- decimal = 10 ,
610
+ tol = 1e- 10 ,
603
611
)
604
612
605
613
@pytest .mark .parametrize ('when' , [1 , 'begin' ])
@@ -619,7 +627,7 @@ def test_when_is_begin_decimal(self, when):
619
627
Decimal ('0' ),
620
628
when ,
621
629
)
622
- assert_almost_equal (result , Decimal ('93105.064874' ), decimal = 5 )
630
+ assert_decimal_close (result , Decimal ('93105.064874' ), tol = 5 )
623
631
624
632
@pytest .mark .parametrize ('when' , [None , 0 , 'end' ])
625
633
def test_when_is_end_float (self , when ):
@@ -640,7 +648,7 @@ def test_when_is_end_decimal(self, when):
640
648
Decimal ('0' ),
641
649
)
642
650
result = npf .fv (* args ) if when is None else npf .fv (* args , when )
643
- assert_almost_equal (result , Decimal ('86609.362673' ), decimal = 5 )
651
+ assert_decimal_close (result , Decimal ('86609.362673' ), tol = 5 )
644
652
645
653
def test_broadcast (self ):
646
654
result = npf .fv ([[0.1 ], [0.2 ]], 5 , 100 , 0 , [0 , 1 ])
@@ -682,13 +690,13 @@ def test_npv_irr_congruence(self):
682
690
([- 5 , 10.5 , 1 , - 8 , 1 ], 0.0886 ),
683
691
])
684
692
def test_basic_values (self , v , desired ):
685
- assert_almost_equal (npf .irr (v ), desired , decimal = 2 )
693
+ assert_allclose (npf .irr (v ), desired , rtol = 1e- 2 )
686
694
687
695
def test_trailing_zeros (self ):
688
- assert_almost_equal (
696
+ assert_allclose (
689
697
npf .irr ([- 5 , 10.5 , 1 , - 8 , 1 , 0 , 0 , 0 ]),
690
698
0.0886 ,
691
- decimal = 2 ,
699
+ rtol = 1e- 2 ,
692
700
)
693
701
694
702
@pytest .mark .parametrize ('v' , [
@@ -750,12 +758,12 @@ def test_gh_39(self):
750
758
- 16259.479306324123 , - 23596.31953754941 , - 30933.159768774713 ,
751
759
- 38270.0 , - 45606.8402312253 , - 52943.680462450604 ,
752
760
- 60280.520693675906 , - 67617.36092490121 ])
753
- assert_almost_equal (npf .irr (cashflows ), 0.12 )
761
+ assert_allclose (npf .irr (cashflows ), 0.12 )
754
762
755
763
def test_gh_44 (self ):
756
764
# "true" value as calculated by Google sheets
757
765
cf = [- 1678.87 , 771.96 , 1814.05 , 3520.30 , 3552.95 , 3584.99 , 4789.91 , - 1 ]
758
- assert_almost_equal (npf .irr (cf ), 1.00426 , 4 )
766
+ assert_allclose (npf .irr (cf ), 1.00426 , rtol = 1e- 4 )
759
767
760
768
def test_irr_no_real_solution_exception (self ):
761
769
# Test that if there is no solution because all the cashflows
0 commit comments