# Restricted Mean Survival Time
# (GitHub gist abikoushi/ce0cfaeb3156fda5b7126df03a334bdc, last active February 10, 2022)
using Distributions
using SpecialFunctions
using QuadGK
using Plots
# Antiderivative behind the closed form in lbcdf(::LogNormal), with a = mu, b = sigma*sqrt(2):
#   ∫ erfc((log(x) - a)/b) dx = x erfc((log(x) - a)/b) - e^(a + b^2/4) erf((2a + b^2 - 2 log(x))/(2b)) + C
"""
    lbcdf(d::LogNormal, x::Real)

Return `rmst(d, x) / mean(d)`, i.e. `(1/E[T]) ∫₀ˣ S(t) dt` for `T ~ d`, in closed form.

Uses `∫₀ˣ S(t) dt = x S(x) + E[T; T ≤ x]`, where
`S(x) = erfc((log x - mu)/(sigma√2))/2`, `E[T] = exp(mu + sigma^2/2)` and
`E[T; T ≤ x]/E[T] = erfc((mu + sigma^2 - log x)/(sigma√2))/2`.

Negative `x` is clamped to 0, matching the Weibull method (instead of failing in `log`).
"""
function lbcdf(d::LogNormal, x::Real)
    mu, sigma = params(d)
    t = max(x, zero(x))
    s2 = sigma * sqrt(2)
    # Partial-mean term (times 2): E[T; T ≤ t] / E[T]. Written as erfc(w) rather than
    # 1 - erf(w): for small t the argument w is large and 1 - erf(w) cancels to noise.
    partial = erfc((mu + sigma^2 - log(t)) / s2)
    # Censored term (times 2): t * S(t) / E[T], with E[T] = exp(mu + sigma^2/2).
    censored = t * erfc((log(t) - mu) / s2) * exp(-(mu + sigma^2 / 2))
    return 0.5 * (partial + censored)
end
"""
    lbcdf(d::Weibull, x::Real)

Return `(1/E[T]) ∫₀ˣ S(t) dt` for `T ~ d`. For shape `k` and scale `λ`, this equals the
regularized lower incomplete gamma function `P(1/k, (x/λ)^k)`. Negative `x` is clamped
to 0, which gives 0.
"""
function lbcdf(d::Weibull, x::Real)
    k, λ = params(d)
    z = (max(x, 0) / λ)^k
    p, _ = gamma_inc(inv(k), z, 0)
    return p
end
"""
    lbcdf(d::Gamma, x::Real)

Return `(1/E[T]) ∫₀ˣ S(t) dt` for `T ~ d` with shape `α` and scale `θ` (so `E[T] = αθ`).

With `z = x/θ`, integration by parts gives `P(α + 1, z) + z Q(α, z) / α`, where `P` and
`Q` are the regularized lower and upper incomplete gamma functions. Negative `x` is
clamped to 0, which gives 0, consistent with the Weibull method. This avoids passing a
negative argument to `gamma_inc`.
"""
function lbcdf(d::Gamma, x::Real)
    shp, scl = params(d)
    z = max(x, zero(x)) / scl
    # E[T; T ≤ x] / E[T] = P(α + 1, z)
    p_shifted, _ = gamma_inc(shp + 1, z, 0)
    # x S(x) / E[T] = z Q(α, z) / α
    _, q = gamma_inc(shp, z, 0)
    return p_shifted + z * q / shp
end
"""
    logmean(d::LogNormal)

Return `log(mean(d))` in closed form, `mu + sigma^2 / 2`, without exponentiating.
"""
function logmean(d::LogNormal)
    location, spread = params(d)
    half_var = spread^2 / 2
    return location + half_var
end
"""
    logmean(d::Weibull)

Return `log(mean(d))` in closed form. For shape `k` and scale `λ`,
`mean(d) = λ Γ(1 + 1/k)`, so the log-mean is `log(λ) + loggamma(1 + 1/k)`.
"""
function logmean(d::Weibull)
    shp, scl = params(d)
    # The Gamma-function factor multiplies the scale, so its log is added, not subtracted.
    return log(scl) + loggamma(1.0 + inv(shp))
end
"""
    logmean(d::Gamma)

Return `log(mean(d))` in closed form. For shape `α` and scale `θ`, `mean(d) = αθ`,
so the log-mean is `log(α) + log(θ)`.
"""
function logmean(d::Gamma)
    shp, scl = params(d)
    return log(shp) + log(scl)
end
"""
    rmst(d::UnivariateDistribution, x::Real)

Restricted mean survival time up to `x`: the closed-form `∫₀ˣ S(t) dt`, obtained by
scaling the normalized integral `lbcdf(d, x)` by `mean(d)`. Requires an `lbcdf` method
for `typeof(d)`.
"""
rmst(d::UnivariateDistribution, x::Real) = mean(d) * lbcdf(d, x)
"""
    logrmst(d::UnivariateDistribution, x::Real)

Logarithm of `rmst(d, x)`, computed as `logmean(d) + log(lbcdf(d, x))` so that
`mean(d)` is never formed on the linear scale. Gives `-Inf` when `lbcdf(d, x)` is 0.
Requires `logmean` and `lbcdf` methods for `typeof(d)`.
"""
function logrmst(d::UnivariateDistribution, x::Real)
    log_mu = logmean(d)
    return log_mu + log(lbcdf(d, x))
end
"""
    report_rmst_check(d, tau)

Sanity-check the closed forms for distribution `d` at truncation time `tau`.

Integrates the survival function `ccdf(d, t)` numerically over `[0, tau]` with `quadgk`
and compares the result with `rmst(d, tau)`. It also compares `logrmst(d, tau)` with
`log(rmst(d, tau))`, which exercises `logmean`. Both comparisons are printed, so they are
visible when the file is run as a script rather than line by line in the REPL.

Returns `(numeric, closed_form)`.
"""
function report_rmst_check(d::UnivariateDistribution, tau::Real)
    numeric, _ = quadgk(t -> ccdf(d, t), zero(tau), tau)
    closed_form = rmst(d, tau)
    println(d, ", tau = ", tau,
            ": rmst ≈ quadgk: ", isapprox(numeric, closed_form),
            ", logrmst ≈ log(rmst): ", isapprox(logrmst(d, tau), log(closed_form)))
    return numeric, closed_form
end

# The globals v (numerical integral) and v2 (closed form) hold the results of the last check.
v, v2 = report_rmst_check(Gamma(2.0, 0.5), 3.0)
v, v2 = report_rmst_check(Weibull(2.0, 2.0), 3.0)
v, v2 = report_rmst_check(LogNormal(2.0, 2.0), 3.0)
# Visual checks, meant for interactive use. Each `plot(...)` starts a new figure, and
# `Plots.abline!` draws onto the current (most recent) figure, so statement order matters.
# NOTE(review): when run non-interactively, figures are presumably only shown if wrapped
# in `display(...)`; confirm the intended usage.
# `Plots.abline!(0, m)` adds a line with slope 0 and intercept m, i.e. the horizontal
# level m = mean(d) that rmst(d, x) approaches as x grows.
# Weibull(2.0, 2.0): lbcdf rises from 0 towards 1; rmst levels off at the mean.
plot(x-> lbcdf(Weibull(2.0,2.0),x),0,4,legend=false)
plot(x-> rmst(Weibull(2.0,2.0),x),0,4,legend=false)
Plots.abline!(0,mean(Weibull(2.0,2.0)))
# Gamma(2.0, 0.5): the same pair of curves.
plot(x-> lbcdf(Gamma(2.0,0.5),x),0,4,legend=false)
plot(x-> rmst(Gamma(2.0,0.5),x),0,4,legend=false)
Plots.abline!(0,mean(Gamma(2.0,0.5)))
# LogNormal(2.0, 1.2): a much wider x-range (0 to 1000) is used, presumably so that the
# rmst curve visibly approaches its asymptote.
plot(x-> lbcdf(LogNormal(2.0,1.2),x),0,1000,legend=false)
plot(x-> rmst(LogNormal(2.0,1.2),x),0,1000,legend=false)
Plots.abline!(0,mean(LogNormal(2.0,1.2)))
# (end of gist)