Skip to content

Commit a9b73c0

Browse files
authored
Merge pull request #24 from person142/fv
MAINT: clean up the implementation of `fv`
2 parents b4e7a75 + c26860f commit a9b73c0

File tree

2 files changed

+92
-37
lines changed

2 files changed

+92
-37
lines changed

numpy_financial/_financial.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,11 +124,27 @@ def fv(rate, nper, pmt, pv, when='end'):
124124
125125
"""
126126
when = _convert_when(when)
127-
(rate, nper, pmt, pv, when) = map(np.asarray, [rate, nper, pmt, pv, when])
128-
temp = (1+rate)**nper
129-
fact = np.where(rate == 0, nper,
130-
(1 + rate*when)*(temp - 1)/rate)
131-
return -(pv*temp + pmt*fact)
127+
rate, nper, pmt, pv, when = np.broadcast_arrays(rate, nper, pmt, pv, when)
128+
129+
fv_array = np.empty_like(rate)
130+
zero = rate == 0
131+
nonzero = ~zero
132+
133+
fv_array[zero] = -(pv[zero] + pmt[zero] * nper[zero])
134+
135+
rate_nonzero = rate[nonzero]
136+
temp = (1 + rate_nonzero)**nper[nonzero]
137+
fv_array[nonzero] = (
138+
- pv[nonzero] * temp
139+
- pmt[nonzero] * (1 + rate_nonzero * when[nonzero]) / rate_nonzero
140+
* (temp - 1)
141+
)
142+
143+
if np.ndim(fv_array) == 0:
144+
# Follow the ufunc convention of returning scalars for scalar
145+
# and 0d array inputs.
146+
return fv_array.item(0)
147+
return fv_array
132148

133149

134150
def pmt(rate, nper, pv, fv=0, when='end'):

numpy_financial/tests/test_financial.py

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,6 @@ def test_pv_decimal(self):
5555
Decimal('0')),
5656
Decimal('-127128.1709461939327295222005'))
5757

58-
def test_fv(self):
59-
assert_equal(npf.fv(0.075, 20, -2000, 0, 0), 86609.362673042924)
60-
61-
def test_fv_decimal(self):
62-
assert_equal(npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
63-
0, 0),
64-
Decimal('86609.36267304300040536731624'))
65-
6658
def test_pmt(self):
6759
res = npf.pmt(0.08 / 12, 5 * 12, 15000)
6860
tgt = -304.145914
@@ -197,15 +189,6 @@ def test_when(self):
197189
assert_equal(npf.pv(0.07, 20, 12000, 0, 0),
198190
npf.pv(0.07, 20, 12000, 0, 'end'))
199191

200-
# begin
201-
assert_equal(npf.fv(0.075, 20, -2000, 0, 1),
202-
npf.fv(0.075, 20, -2000, 0, 'begin'))
203-
# end
204-
assert_equal(npf.fv(0.075, 20, -2000, 0),
205-
npf.fv(0.075, 20, -2000, 0, 'end'))
206-
assert_equal(npf.fv(0.075, 20, -2000, 0, 0),
207-
npf.fv(0.075, 20, -2000, 0, 'end'))
208-
209192
# begin
210193
assert_equal(npf.pmt(0.08 / 12, 5 * 12, 15000., 0, 1),
211194
npf.pmt(0.08 / 12, 5 * 12, 15000., 0, 'begin'))
@@ -267,21 +250,6 @@ def test_decimal_with_when(self):
267250
npf.pv(Decimal('0.07'), Decimal('20'), Decimal('12000'),
268251
Decimal('0'), 'end'))
269252

270-
# begin
271-
assert_equal(npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
272-
Decimal('0'), Decimal('1')),
273-
npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
274-
Decimal('0'), 'begin'))
275-
# end
276-
assert_equal(npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
277-
Decimal('0')),
278-
npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
279-
Decimal('0'), 'end'))
280-
assert_equal(npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
281-
Decimal('0'), Decimal('0')),
282-
npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'),
283-
Decimal('0'), 'end'))
284-
285253
# begin
286254
assert_equal(npf.pmt(Decimal('0.08') / Decimal('12'),
287255
Decimal('5') * Decimal('12'), Decimal('15000.'),
@@ -475,3 +443,74 @@ def test_decimal_broadcasting(self):
475443
Decimal('2000')
476444
)
477445
assert_almost_equal(result, desired, decimal=4)
446+
447+
448+
class TestFv:
449+
def test_float(self):
450+
assert_allclose(
451+
npf.fv(0.075, 20, -2000, 0, 0),
452+
86609.362673042924,
453+
rtol=1e-10,
454+
)
455+
456+
def test_decimal(self):
457+
assert_almost_equal(
458+
npf.fv(Decimal('0.075'), Decimal('20'), Decimal('-2000'), 0, 0),
459+
Decimal('86609.36267304300040536731624'),
460+
decimal=10,
461+
)
462+
463+
@pytest.mark.parametrize('when', [1, 'begin'])
464+
def test_when_is_begin_float(self, when):
465+
assert_allclose(
466+
npf.fv(0.075, 20, -2000, 0, when),
467+
93105.064874, # Computed using Google Sheet's FV
468+
rtol=1e-10,
469+
)
470+
471+
@pytest.mark.parametrize('when', [Decimal('1'), 'begin'])
472+
def test_when_is_begin_decimal(self, when):
473+
result = npf.fv(
474+
Decimal('0.075'),
475+
Decimal('20'),
476+
Decimal('-2000'),
477+
Decimal('0'),
478+
when,
479+
)
480+
assert_almost_equal(result, Decimal('93105.064874'), decimal=5)
481+
482+
@pytest.mark.parametrize('when', [None, 0, 'end'])
483+
def test_when_is_end_float(self, when):
484+
args = (0.075, 20, -2000, 0)
485+
result = npf.fv(*args) if when is None else npf.fv(*args, when)
486+
assert_allclose(
487+
result,
488+
86609.362673, # Computed using Google Sheet's FV
489+
rtol=1e-10,
490+
)
491+
492+
@pytest.mark.parametrize('when', [None, Decimal('0'), 'end'])
493+
def test_when_is_end_decimal(self, when):
494+
args = (
495+
Decimal('0.075'),
496+
Decimal('20'),
497+
Decimal('-2000'),
498+
Decimal('0'),
499+
)
500+
result = npf.fv(*args) if when is None else npf.fv(*args, when)
501+
assert_almost_equal(result, Decimal('86609.362673'), decimal=5)
502+
503+
def test_broadcast(self):
504+
result = npf.fv([[0.1], [0.2]], 5, 100, 0, [0, 1])
505+
# All values computed using Google Sheet's FV
506+
desired = [[-610.510000, -671.561000],
507+
[-744.160000, -892.992000]]
508+
assert_allclose(result, desired, rtol=1e-10)
509+
510+
def test_some_rates_zero(self):
511+
# Check that the logical indexing is working correctly.
512+
assert_allclose(
513+
npf.fv([0, 0.1], 5, 100, 0),
514+
[-500, -610.51], # Computed using Google Sheet's FV
515+
rtol=1e-10,
516+
)

0 commit comments

Comments
 (0)