@RHDZMOTA
Created April 27, 2020 06:13
Internal rate of return using JAX.
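The function reduces IRR to polynomial root-finding. Assume `cashflows[0]` is the amount at period 0 and `cashflows[t]` is the amount at the end of period t. The gist does not state this convention; it is inferred from the code. Under that assumption, the net present value at rate r is a polynomial in the discount factor x:

```latex
\mathrm{NPV}(r) = \sum_{t=0}^{n} \frac{c_t}{(1+r)^t} = \sum_{t=0}^{n} c_t\,x^t,
\qquad x = \frac{1}{1+r},
\qquad \mathrm{NPV}(r) = 0 \iff c_n x^n + \dots + c_1 x + c_0 = 0 .
```

`roots` expects the highest-degree coefficient first, which is why the cash flows are reversed. A real root x > 0 maps back to a rate r = 1/x - 1 > -1.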
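import jax.numpy as jnp


def irr(cashflows):
    """Internal rate of return of per-period cash flows c_0, c_1, ..., c_n."""
    # jnp.roots rejects Python lists, so convert the input first.
    cashflows = jnp.asarray(cashflows)
    # NPV = sum_t c_t * x**t with x = 1 / (1 + r); jnp.roots expects the
    # highest-degree coefficient first, hence the reversal.
    res = jnp.roots(cashflows[::-1])
    # Keep only real, positive roots x (equivalently, rates r > -1).
    # Boolean masking and Python control flow mean this runs eagerly, not under jax.jit.
    mask = (res.imag == 0) & (res.real > 0)
    if not mask.any():
        return jnp.nan
    # Drop the (zero) imaginary component.
    res = res[mask].real
    # Map each root back to a rate and return the solution closest to zero.
    rate = 1 / res - 1
    return rate[jnp.argmin(jnp.abs(rate))].item()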
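You can cross-check `irr` by solving NPV(r) = 0 directly with Newton's method, using `jax.grad` to supply the derivative. This is a hypothetical alternative, not part of the original gist. The names `npv` and `irr_newton` are illustrative, and the defaults `guess=0.1` and `steps=50` are arbitrary assumptions. It assumes the same cash-flow convention: entry t is the amount at period t. For conventional cash flows (a single sign change), Descartes' rule of signs guarantees exactly one positive root x. In that case, when Newton converges, both approaches should agree.

```python
import jax
import jax.numpy as jnp


def npv(rate, cashflows):
    """Net present value of cash flows c_0..c_n, discounted at `rate` per period."""
    cashflows = jnp.asarray(cashflows, dtype=jnp.float32)
    t = jnp.arange(cashflows.shape[0])
    return jnp.sum(cashflows / (1.0 + rate) ** t)


# d(NPV)/d(rate), obtained by automatic differentiation w.r.t. the first argument.
npv_grad = jax.grad(npv)


def irr_newton(cashflows, guess=0.1, steps=50):
    """Hypothetical alternative to the root-based irr: Newton's method on NPV(rate) = 0."""
    cashflows = jnp.asarray(cashflows, dtype=jnp.float32)
    rate = jnp.asarray(guess, dtype=jnp.float32)
    for _ in range(steps):
        # Newton update r <- r - NPV(r) / NPV'(r); assumes NPV'(r) != 0 along the way.
        rate = rate - npv(rate, cashflows) / npv_grad(rate, cashflows)
    return float(rate)


print(irr_newton([-100.0, 60.0, 60.0]))  # roughly 0.1307
```

Newton needs a starting guess and can diverge or land on a different root for unconventional cash flows. The eigenvalue-based `irr` instead enumerates every root and picks the rate closest to zero.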