Skip to content

Commit b4e7a75

Browse files
authored
Merge pull request #22 from person142/ipmt
BUG: make `ipmt` return `nan` for `per < 1`
2 parents 60cc842 + 91ba6b8 commit b4e7a75

File tree

2 files changed

+112
-46
lines changed

2 files changed

+112
-46
lines changed

numpy_financial/_financial.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,14 @@ def nper(rate, pmt, pv, fv=0, when='end'):
302302
return nper_array
303303

304304

305+
def _value_like(arr, value):
306+
entry = arr.item(0)
307+
if isinstance(entry, Decimal):
308+
return Decimal(value)
309+
else:
310+
return np.array(value, dtype=arr.dtype).item(0)
311+
312+
305313
def ipmt(rate, per, nper, pv, fv=0, when='end'):
306314
"""
307315
Compute the interest portion of a payment.
@@ -391,13 +399,22 @@ def ipmt(rate, per, nper, pv, fv=0, when='end'):
391399
when = _convert_when(when)
392400
rate, per, nper, pv, fv, when = np.broadcast_arrays(rate, per, nper,
393401
pv, fv, when)
402+
394403
total_pmt = pmt(rate, nper, pv, fv, when)
395-
ipmt = _rbl(rate, per, total_pmt, pv, when)*rate
396-
try:
397-
ipmt = np.where(when == 1, ipmt/(1 + rate), ipmt)
398-
ipmt = np.where(np.logical_and(when == 1, per == 1), 0, ipmt)
399-
except IndexError:
400-
pass
404+
ipmt = np.array(_rbl(rate, per, total_pmt, pv, when) * rate)
405+
406+
# Payments start at the first period, so payments before that
407+
# don't make any sense.
408+
ipmt[per < 1] = _value_like(ipmt, np.nan)
409+
# If payments occur at the beginning of a period and this is the
410+
# first period, then no interest has accrued.
411+
per1_and_begin = (when == 1) & (per == 1)
412+
ipmt[per1_and_begin] = _value_like(ipmt, 0)
413+
# If paying at the beginning we need to discount by one period.
414+
per_gt_1_and_begin = (when == 1) & (per > 1)
415+
ipmt[per_gt_1_and_begin] = (
416+
ipmt[per_gt_1_and_begin] / (1 + rate[per_gt_1_and_begin])
417+
)
401418
return ipmt
402419

403420

numpy_financial/tests/test_financial.py

Lines changed: 89 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from numpy.testing import (
88
assert_, assert_almost_equal, assert_allclose, assert_equal, assert_raises
99
)
10+
import pytest
1011

1112
import numpy_financial as npf
1213

@@ -133,14 +134,6 @@ def raise_error_because_not_equal():
133134
Decimal('10000000000')),
134135
Decimal('-90238044.2322778884413969909'))
135136

136-
def test_ipmt(self):
137-
assert_almost_equal(numpy.round(npf.ipmt(0.1 / 12, 1, 24, 2000), 2),
138-
-16.67)
139-
140-
def test_ipmt_decimal(self):
141-
result = npf.ipmt(Decimal('0.1') / Decimal('12'), 1, 24, 2000)
142-
assert_equal(result.flat[0], Decimal('-16.66666666666666666666666667'))
143-
144137
def test_npv(self):
145138
assert_almost_equal(
146139
npf.npv(0.05, [-15000, 1500, 2500, 3500, 4500, 6000]),
@@ -231,15 +224,6 @@ def test_when(self):
231224
assert_equal(npf.ppmt(0.1 / 12, 1, 60, 55000, 0, 0),
232225
npf.ppmt(0.1 / 12, 1, 60, 55000, 0, 'end'))
233226

234-
# begin
235-
assert_equal(npf.ipmt(0.1 / 12, 1, 24, 2000, 0, 1),
236-
npf.ipmt(0.1 / 12, 1, 24, 2000, 0, 'begin'))
237-
# end
238-
assert_equal(npf.ipmt(0.1 / 12, 1, 24, 2000, 0),
239-
npf.ipmt(0.1 / 12, 1, 24, 2000, 0, 'end'))
240-
assert_equal(npf.ipmt(0.1 / 12, 1, 24, 2000, 0, 0),
241-
npf.ipmt(0.1 / 12, 1, 24, 2000, 0, 'end'))
242-
243227
# begin
244228
assert_equal(npf.nper(0.075, -2000, 0, 100000., 1),
245229
npf.nper(0.075, -2000, 0, 100000., 'begin'))
@@ -364,44 +348,40 @@ def test_broadcast(self):
364348
assert_almost_equal(npf.nper(0.075, -2000, 0, 100000., [0, 1]),
365349
[21.5449442, 20.76156441], 4)
366350

367-
assert_almost_equal(npf.ipmt(0.1 / 12, list(range(5)), 24, 2000),
368-
[-17.29165168, -16.66666667, -16.03647345,
369-
-15.40102862, -14.76028842], 4)
370-
371351
assert_almost_equal(npf.ppmt(0.1 / 12, list(range(5)), 24, 2000),
372-
[-74.998201, -75.62318601, -76.25337923,
352+
[numpy.nan, -75.62318601, -76.25337923,
373353
-76.88882405, -77.52956425], 4)
374354

375355
assert_almost_equal(npf.ppmt(0.1 / 12, list(range(5)), 24, 2000, 0,
376356
[0, 0, 1, 'end', 'begin']),
377-
[-74.998201, -75.62318601, -75.62318601,
357+
[numpy.nan, -75.62318601, -75.62318601,
378358
-76.88882405, -76.88882405], 4)
379359

380360
def test_broadcast_decimal(self):
381361
# Use almost equal because precision is tested in the explicit tests,
382362
# this test is to ensure broadcast with Decimal is not broken.
383-
assert_almost_equal(npf.ipmt(Decimal('0.1') / Decimal('12'),
384-
list(range(5)), Decimal('24'),
385-
Decimal('2000')),
386-
[Decimal('-17.29165168'), Decimal('-16.66666667'),
387-
Decimal('-16.03647345'), Decimal('-15.40102862'),
388-
Decimal('-14.76028842')], 4)
389-
390363
assert_almost_equal(npf.ppmt(Decimal('0.1') / Decimal('12'),
391-
list(range(5)), Decimal('24'),
364+
list(range(1, 5)), Decimal('24'),
392365
Decimal('2000')),
393-
[Decimal('-74.998201'), Decimal('-75.62318601'),
366+
[Decimal('-75.62318601'),
394367
Decimal('-76.25337923'), Decimal('-76.88882405'),
395368
Decimal('-77.52956425')], 4)
396369

397-
assert_almost_equal(npf.ppmt(Decimal('0.1') / Decimal('12'),
398-
list(range(5)), Decimal('24'),
399-
Decimal('2000'), Decimal('0'),
400-
[Decimal('0'), Decimal('0'), Decimal('1'),
401-
'end', 'begin']),
402-
[Decimal('-74.998201'), Decimal('-75.62318601'),
403-
Decimal('-75.62318601'), Decimal('-76.88882405'),
404-
Decimal('-76.88882405')], 4)
370+
result = npf.ppmt(
371+
Decimal('0.1') / Decimal('12'),
372+
list(range(1, 5)),
373+
Decimal('24'),
374+
Decimal('2000'),
375+
Decimal('0'),
376+
[Decimal('0'), Decimal('1'), 'end', 'begin']
377+
)
378+
desired = [
379+
Decimal('-75.62318601'),
380+
Decimal('-75.62318601'),
381+
Decimal('-76.88882405'),
382+
Decimal('-76.88882405')
383+
]
384+
assert_almost_equal(result, desired, decimal=4)
405385

406386

407387
class TestNper:
@@ -426,3 +406,72 @@ def test_infinite_payments(self):
426406

427407
def test_no_interest(self):
428408
assert_(npf.nper(0, -100, 1000) == 10)
409+
410+
411+
class TestIpmt:
412+
def test_float(self):
413+
assert_allclose(
414+
npf.ipmt(0.1 / 12, 1, 24, 2000),
415+
-16.666667, # Computed using Google Sheet's IPMT
416+
rtol=1e-6,
417+
)
418+
419+
def test_decimal(self):
420+
result = npf.ipmt(Decimal('0.1') / Decimal('12'), 1, 24, 2000)
421+
assert result == Decimal('-16.66666666666666666666666667')
422+
423+
@pytest.mark.parametrize('when', [1, 'begin'])
424+
def test_when_is_begin(self, when):
425+
assert npf.ipmt(0.1 / 12, 1, 24, 2000, 0, when) == 0
426+
427+
@pytest.mark.parametrize('when', [None, 0, 'end'])
428+
def test_when_is_end(self, when):
429+
if when is None:
430+
result = npf.ipmt(0.1 / 12, 1, 24, 2000)
431+
else:
432+
result = npf.ipmt(0.1 / 12, 1, 24, 2000, 0, when)
433+
assert_allclose(result, -16.666667, rtol=1e-6)
434+
435+
@pytest.mark.parametrize('per, desired', [
436+
(0, numpy.nan),
437+
(1, 0),
438+
(2, -594.107158),
439+
(3, -592.971592),
440+
])
441+
def test_gh_17(self, per, desired):
442+
# All desired results computed using Google Sheet's IPMT
443+
rate = 0.001988079518355057
444+
result = npf.ipmt(rate, per, 360, 300000, when="begin")
445+
if numpy.isnan(desired):
446+
assert numpy.isnan(result)
447+
else:
448+
assert_allclose(result, desired, rtol=1e-6)
449+
450+
def test_broadcasting(self):
451+
desired = [
452+
numpy.nan,
453+
-16.66666667,
454+
-16.03647345,
455+
-15.40102862,
456+
-14.76028842
457+
]
458+
assert_allclose(
459+
npf.ipmt(0.1 / 12, numpy.arange(5), 24, 2000),
460+
desired,
461+
rtol=1e-6,
462+
)
463+
464+
def test_decimal_broadcasting(self):
465+
desired = [
466+
Decimal('-16.66666667'),
467+
Decimal('-16.03647345'),
468+
Decimal('-15.40102862'),
469+
Decimal('-14.76028842')
470+
]
471+
result = npf.ipmt(
472+
Decimal('0.1') / Decimal('12'),
473+
list(range(1, 5)),
474+
Decimal('24'),
475+
Decimal('2000')
476+
)
477+
assert_almost_equal(result, desired, decimal=4)

0 commit comments

Comments
 (0)