Skip to content

Commit 8b69811

Browse files
author
Aryan Chowdhury
committed
Replaced all assert_almost_equals
1 parent 1656639 commit 8b69811

File tree

1 file changed

+32
-24
lines changed

1 file changed

+32
-24
lines changed

tests/test_financial.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,20 @@
88
from numpy.testing import (
99
assert_,
1010
assert_allclose,
11-
assert_almost_equal,
1211
assert_equal,
1312
assert_raises,
1413
)
1514

1615
import numpy_financial as npf
1716

17+
def assert_decimal_close(actual, expected, tol=Decimal("1e-7")):
18+
# Check if both actual and expected are iterable (like arrays)
19+
if hasattr(actual, "__iter__") and hasattr(expected, "__iter__"):
20+
for a, e in zip(actual, expected):
21+
assert abs(a - e) <= tol
22+
else:
23+
# For single value comparisons
24+
assert abs(actual - expected) <= tol
1825

1926
class TestFinancial(object):
2027
def test_when(self):
@@ -91,7 +98,7 @@ def test_decimal_with_when(self):
9198

9299
class TestPV:
93100
def test_pv(self):
94-
assert_almost_equal(npf.pv(0.07, 20, 12000, 0), -127128.17, 2)
101+
assert_allclose(npf.pv(0.07, 20, 12000, 0), -127128.17, rtol=1e-2)
95102

96103
def test_pv_decimal(self):
97104
assert_equal(npf.pv(Decimal('0.07'), Decimal('20'), Decimal('12000'),
@@ -101,7 +108,7 @@ def test_pv_decimal(self):
101108

102109
class TestRate:
103110
def test_rate(self):
104-
assert_almost_equal(npf.rate(10, 0, -3500, 10000), 0.1107, 4)
111+
assert_allclose(npf.rate(10, 0, -3500, 10000), 0.1107, rtol=1e-4)
105112

106113
@pytest.mark.parametrize('number_type', [Decimal, float])
107114
@pytest.mark.parametrize('when', [0, 1, 'end', 'begin'])
@@ -163,9 +170,9 @@ def test_rate_maximum_iterations_exception_array(self):
163170

164171
class TestNpv:
165172
def test_npv(self):
166-
assert_almost_equal(
173+
assert_allclose(
167174
npf.npv(0.05, [-15000.0, 1500.0, 2500.0, 3500.0, 4500.0, 6000.0]),
168-
122.89, 2)
175+
122.89, rtol=1e-2)
169176

170177
def test_npv_decimal(self):
171178
assert_equal(
@@ -282,7 +289,8 @@ def test_mirr(self, values, finance_rate, reinvest_rate, expected):
282289

283290
if expected:
284291
decimal_part_len = len(str(expected).split('.')[1])
285-
assert_almost_equal(result, expected, decimal_part_len)
292+
difference = 10 ** -decimal_part_len
293+
assert_allclose(result, expected, atol=difference)
286294
else:
287295
assert_(numpy.isnan(result))
288296

@@ -319,7 +327,7 @@ def test_mirr_decimal(self, number_type, args, expected):
319327
)
320328

321329
if expected is not numpy.nan:
322-
assert_almost_equal(result, number_type(expected), 15)
330+
assert_decimal_close(result, number_type(expected), tol=1e-15)
323331
else:
324332
assert numpy.isnan(result)
325333

@@ -357,7 +365,7 @@ def test_no_interest(self):
357365
assert_(npf.nper(0, -100, 1000) == 10)
358366

359367
def test_broadcast(self):
360-
assert_almost_equal(npf.nper(0.075, -2000, 0, 100000., [0, 1]),
368+
assert_allclose(npf.nper(0.075, -2000, 0, 100000., [0, 1]),
361369
[21.5449442, 20.76156441], 4)
362370

363371

@@ -409,10 +417,10 @@ def test_when_is_begin_decimal(self, when):
409417
Decimal('0'),
410418
when
411419
)
412-
assert_almost_equal(
420+
assert_decimal_close(
413421
result,
414422
Decimal('-302.131703'), # Computed using Google Sheet's PPMT
415-
decimal=5,
423+
tol = 1e-5,
416424
)
417425

418426
@pytest.mark.parametrize('when', [None, Decimal('0'), 'end'])
@@ -425,10 +433,10 @@ def test_when_is_end_decimal(self, when):
425433
Decimal('0')
426434
)
427435
result = npf.ppmt(*args) if when is None else npf.ppmt(*args, when)
428-
assert_almost_equal(
436+
assert_decimal_close(
429437
result,
430438
Decimal('-204.145914'), # Computed using Google Sheet's PPMT
431-
decimal=5,
439+
tol=1e-5,
432440
)
433441

434442
@pytest.mark.parametrize('args', [
@@ -481,7 +489,7 @@ def test_broadcast_decimal(self, when, desired):
481489
Decimal('0')
482490
)
483491
result = npf.ppmt(*args) if when is None else npf.ppmt(*args, when)
484-
assert_almost_equal(result, desired, decimal=8)
492+
assert_decimal_close(result, desired, tol=1e-8)
485493

486494

487495
class TestIpmt:
@@ -532,7 +540,7 @@ def test_when_is_end_decimal(self, when):
532540
Decimal('0')
533541
)
534542
result = npf.ipmt(*args) if when is None else npf.ipmt(*args, when)
535-
assert_almost_equal(result, desired, decimal=5)
543+
assert_decimal_close(result, desired, tol=1e-5)
536544

537545
@pytest.mark.parametrize('per, desired', [
538546
(0, numpy.nan),
@@ -576,7 +584,7 @@ def test_decimal_broadcasting(self):
576584
Decimal('24'),
577585
Decimal('2000')
578586
)
579-
assert_almost_equal(result, desired, decimal=4)
587+
assert_decimal_close(result, desired, tol=1e-4)
580588

581589
def test_0d_inputs(self):
582590
args = (0.1 / 12, 1, 24, 2000)
@@ -596,10 +604,10 @@ def test_float(self):
596604
)
597605

598606
def test_decimal(self):
599-
assert_almost_equal(
607+
assert_decimal_close(
600608
npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'), 0, 0),
601609
Decimal('86609.36267304300040536731624'),
602-
decimal=10,
610+
tol=1e-10,
603611
)
604612

605613
@pytest.mark.parametrize('when', [1, 'begin'])
@@ -619,7 +627,7 @@ def test_when_is_begin_decimal(self, when):
619627
Decimal('0'),
620628
when,
621629
)
622-
assert_almost_equal(result, Decimal('93105.064874'), decimal=5)
630+
assert_decimal_close(result, Decimal('93105.064874'), tol=5)
623631

624632
@pytest.mark.parametrize('when', [None, 0, 'end'])
625633
def test_when_is_end_float(self, when):
@@ -640,7 +648,7 @@ def test_when_is_end_decimal(self, when):
640648
Decimal('0'),
641649
)
642650
result = npf.fv(*args) if when is None else npf.fv(*args, when)
643-
assert_almost_equal(result, Decimal('86609.362673'), decimal=5)
651+
assert_decimal_close(result, Decimal('86609.362673'), tol=5)
644652

645653
def test_broadcast(self):
646654
result = npf.fv([[0.1], [0.2]], 5, 100, 0, [0, 1])
@@ -682,13 +690,13 @@ def test_npv_irr_congruence(self):
682690
([-5, 10.5, 1, -8, 1], 0.0886),
683691
])
684692
def test_basic_values(self, v, desired):
685-
assert_almost_equal(npf.irr(v), desired, decimal=2)
693+
assert_allclose(npf.irr(v), desired, rtol=1e-2)
686694

687695
def test_trailing_zeros(self):
688-
assert_almost_equal(
696+
assert_allclose(
689697
npf.irr([-5, 10.5, 1, -8, 1, 0, 0, 0]),
690698
0.0886,
691-
decimal=2,
699+
rtol=1e-2,
692700
)
693701

694702
@pytest.mark.parametrize('v', [
@@ -750,12 +758,12 @@ def test_gh_39(self):
750758
-16259.479306324123, -23596.31953754941, -30933.159768774713,
751759
-38270.0, -45606.8402312253, -52943.680462450604,
752760
-60280.520693675906, -67617.36092490121])
753-
assert_almost_equal(npf.irr(cashflows), 0.12)
761+
assert_allclose(npf.irr(cashflows), 0.12)
754762

755763
def test_gh_44(self):
756764
# "true" value as calculated by Google sheets
757765
cf = [-1678.87, 771.96, 1814.05, 3520.30, 3552.95, 3584.99, 4789.91, -1]
758-
assert_almost_equal(npf.irr(cf), 1.00426, 4)
766+
assert_allclose(npf.irr(cf), 1.00426, rtol=1e-4)
759767

760768
def test_irr_no_real_solution_exception(self):
761769
# Test that if there is no solution because all the cashflows

0 commit comments

Comments
 (0)