Skip to content

Commit 2d7b5ed

Browse files
authored
Merge pull request #26 from person142/ipmt-cleanups
MAINT: small cleanups to `ipmt`
2 parents 9b2cb63 + 9ea473c commit 2d7b5ed

File tree

2 files changed

+46
-27
lines changed

2 files changed

+46
-27
lines changed

numpy_financial/_financial.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -417,21 +417,26 @@ def ipmt(rate, per, nper, pv, fv=0, when='end'):
417417
pv, fv, when)
418418

419419
total_pmt = pmt(rate, nper, pv, fv, when)
420-
ipmt = np.array(_rbl(rate, per, total_pmt, pv, when) * rate)
420+
ipmt_array = np.array(_rbl(rate, per, total_pmt, pv, when) * rate)
421421

422422
# Payments start at the first period, so payments before that
423423
# don't make any sense.
424-
ipmt[per < 1] = _value_like(ipmt, np.nan)
424+
ipmt_array[per < 1] = _value_like(ipmt_array, np.nan)
425425
# If payments occur at the beginning of a period and this is the
426426
# first period, then no interest has accrued.
427427
per1_and_begin = (when == 1) & (per == 1)
428-
ipmt[per1_and_begin] = _value_like(ipmt, 0)
428+
ipmt_array[per1_and_begin] = _value_like(ipmt_array, 0)
429429
# If paying at the beginning we need to discount by one period.
430430
per_gt_1_and_begin = (when == 1) & (per > 1)
431-
ipmt[per_gt_1_and_begin] = (
432-
ipmt[per_gt_1_and_begin] / (1 + rate[per_gt_1_and_begin])
431+
ipmt_array[per_gt_1_and_begin] = (
432+
ipmt_array[per_gt_1_and_begin] / (1 + rate[per_gt_1_and_begin])
433433
)
434-
return ipmt
434+
435+
if np.ndim(ipmt_array) == 0:
436+
# Follow the ufunc convention of returning scalars for scalar
437+
# and 0d array inputs.
438+
return ipmt_array.item(0)
439+
return ipmt_array
435440

436441

437442
def _rbl(rate, per, pmt, pv, when):

numpy_financial/tests/test_financial.py

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -265,27 +265,6 @@ def test_decimal_with_when(self):
265265
Decimal('60'), Decimal('55000'), Decimal('0'),
266266
'end'))
267267

268-
# begin
269-
assert_equal(npf.ipmt(Decimal('0.1') / Decimal('12'), Decimal('1'),
270-
Decimal('24'), Decimal('2000'),
271-
Decimal('0'), Decimal('1')).flat[0],
272-
npf.ipmt(Decimal('0.1') / Decimal('12'), Decimal('1'),
273-
Decimal('24'), Decimal('2000'),
274-
Decimal('0'), 'begin').flat[0])
275-
# end
276-
assert_equal(npf.ipmt(Decimal('0.1') / Decimal('12'), Decimal('1'),
277-
Decimal('24'), Decimal('2000'),
278-
Decimal('0')).flat[0],
279-
npf.ipmt(Decimal('0.1') / Decimal('12'), Decimal('1'),
280-
Decimal('24'), Decimal('2000'),
281-
Decimal('0'), 'end').flat[0])
282-
assert_equal(npf.ipmt(Decimal('0.1') / Decimal('12'), Decimal('1'),
283-
Decimal('24'), Decimal('2000'),
284-
Decimal('0'), Decimal('0')).flat[0],
285-
npf.ipmt(Decimal('0.1') / Decimal('12'), Decimal('1'),
286-
Decimal('24'), Decimal('2000'),
287-
Decimal('0'), 'end').flat[0])
288-
289268
def test_broadcast(self):
290269
assert_almost_equal(npf.nper(0.075, -2000, 0, 100000., [0, 1]),
291270
[21.5449442, 20.76156441], 4)
@@ -374,6 +353,33 @@ def test_when_is_end(self, when):
374353
result = npf.ipmt(0.1 / 12, 1, 24, 2000, 0, when)
375354
assert_allclose(result, -16.666667, rtol=1e-6)
376355

356+
357+
@pytest.mark.parametrize('when', [Decimal('1'), 'begin'])
358+
def test_when_is_begin_decimal(self, when):
359+
result = npf.ipmt(
360+
Decimal('0.1') / Decimal('12'),
361+
Decimal('1'),
362+
Decimal('24'),
363+
Decimal('2000'),
364+
Decimal('0'),
365+
when,
366+
)
367+
assert result == 0
368+
369+
@pytest.mark.parametrize('when', [None, Decimal('0'), 'end'])
370+
def test_when_is_end_decimal(self, when):
371+
# Computed using Google Sheet's IPMT
372+
desired = Decimal('-16.666667')
373+
args = (
374+
Decimal('0.1') / Decimal('12'),
375+
Decimal('1'),
376+
Decimal('24'),
377+
Decimal('2000'),
378+
Decimal('0')
379+
)
380+
result = npf.ipmt(*args) if when is None else npf.ipmt(*args, when)
381+
assert_almost_equal(result, desired, decimal=5)
382+
377383
@pytest.mark.parametrize('per, desired', [
378384
(0, numpy.nan),
379385
(1, 0),
@@ -418,6 +424,14 @@ def test_decimal_broadcasting(self):
418424
)
419425
assert_almost_equal(result, desired, decimal=4)
420426

427+
def test_0d_inputs(self):
428+
args = (0.1 / 12, 1, 24, 2000)
429+
# Scalar inputs should return a scalar.
430+
assert numpy.isscalar(npf.ipmt(*args))
431+
args = (numpy.array(args[0]),) + args[1:]
432+
# 0d array inputs should return a scalar.
433+
assert numpy.isscalar(npf.ipmt(*args))
434+
421435

422436
class TestFv:
423437
def test_float(self):

0 commit comments

Comments
 (0)