Created
December 30, 2020 07:23
-
-
Save iwiwi/10fb477eceaff0d36cdacf9a268db780 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use crate::*; | |
/// Inverse CDF (percent-point function) of the standard normal distribution.
///
/// This is a port of the Cephes `ndtri` routine as vendored by SciPy 1.5.4;
/// the coefficient tables below are copied from that source.
mod normal_distribution {
    /// sqrt(2 * pi).
    const S2PI: f64 = 2.50662827463100050242E0;

    /// exp(-2): probabilities closer than this to 0 or 1 use the tail expansions.
    const EXP_M2: f64 = 0.13533528323661269189;

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c

    /// Numerator coefficients for the central branch of `ppf`.
    const P0: [f64; 5] = [
        -5.99633501014107895267E1,
        9.80010754185999661536E1,
        -5.66762857469070293439E1,
        1.39312609387279679503E1,
        -1.23916583867381258016E0,
    ];

    /// Denominator coefficients for the central branch (implicit leading 1.0).
    const Q0: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        1.95448858338141759834E0,
        4.67627912898881538453E0,
        8.63602421390890590575E1,
        -2.25462687854119370527E2,
        2.00260212380060660359E2,
        -8.20372256168333339912E1,
        1.59056225126211695515E1,
        -1.18331621121330003142E0,
    ];

    /// Numerator coefficients for the tail branch with `sqrt(-2 ln y) < 8`.
    const P1: [f64; 9] = [
        4.05544892305962419923E0,
        3.15251094599893866154E1,
        5.71628192246421288162E1,
        4.40805073893200834700E1,
        1.46849561928858024014E1,
        2.18663306850790267539E0,
        -1.40256079171354495875E-1,
        -3.50424626827848203418E-2,
        -8.57456785154685413611E-4,
    ];

    /// Denominator coefficients for the tail branch with `sqrt(-2 ln y) < 8`
    /// (implicit leading 1.0).
    const Q1: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        1.57799883256466749731E1,
        4.53907635128879210584E1,
        4.13172038254672030440E1,
        1.50425385692907503408E1,
        2.50464946208309415979E0,
        -1.42182922854787788574E-1,
        -3.80806407691578277194E-2,
        -9.33259480895457427372E-4,
    ];

    /// Numerator coefficients for the tail branch with `sqrt(-2 ln y) >= 8`.
    const P2: [f64; 9] = [
        3.23774891776946035970E0,
        6.91522889068984211695E0,
        3.93881025292474443415E0,
        1.33303460815807542389E0,
        2.01485389549179081538E-1,
        1.23716634817820021358E-2,
        3.01581553508235416007E-4,
        2.65806974686737550832E-6,
        6.23974539184983293730E-9,
    ];

    /// Denominator coefficients for the tail branch with `sqrt(-2 ln y) >= 8`
    /// (implicit leading 1.0).
    const Q2: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        6.02427039364742014255E0,
        3.67983563856160859403E0,
        1.37702099489081330271E0,
        2.16236993594496635890E-1,
        1.34204006088543189037E-2,
        3.28014464682127739104E-4,
        2.89247864745380683936E-6,
        6.79019408009981274425E-9,
    ];

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L67
    /// Evaluates a polynomial by Horner's scheme; `coef` is ordered from the
    /// highest-degree coefficient down to the constant term.
    fn polevl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(0.0, |acc, &c| acc * x + c)
    }

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L90
    /// Same as `polevl`, but with an implicit leading coefficient of 1.0 that is
    /// not stored in `coef`.
    fn p1evl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(1.0, |acc, &c| acc * x + c)
    }

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c#L134
    /// Returns the `x` for which the standard normal CDF equals `y0`.
    ///
    /// The endpoints map to the infinities: `ppf(0.0)` is `-inf` and `ppf(1.0)`
    /// is `+inf`.
    ///
    /// # Panics
    ///
    /// Panics if `y0` is NaN or lies outside `[0, 1]`.
    pub fn ppf(y0: f64) -> f64 {
        assert!(
            (0.0..=1.0).contains(&y0),
            "ppf: probability must be within [0, 1], got {}",
            y0
        );

        // Without these guards the tail expansion below evaluates `inf - inf / inf`
        // and yields NaN for the two endpoints.
        if y0 == 0.0 {
            return f64::NEG_INFINITY;
        }
        if y0 == 1.0 {
            return f64::INFINITY;
        }

        // Fold the upper tail onto the lower one. The tail expansion produces a
        // positive magnitude, which has to be negated for the lower tail only.
        let (y, negate) = if y0 > 1.0 - EXP_M2 {
            (1.0 - y0, false)
        } else {
            (y0, true)
        };

        // Central region: rational approximation in (y - 1/2)^2.
        if y > EXP_M2 {
            let y = y - 0.5;
            let y2 = y * y;
            let x = y + y * (y2 * polevl(y2, &P0) / p1evl(y2, &Q0));
            return x * S2PI;
        }

        // Tails: expansion in z = 1 / sqrt(-2 ln y).
        let x = (-2.0 * y.ln()).sqrt();
        let x0 = x - x.ln() / x;
        let z = 1.0 / x;
        let x1 = if x < 8.0 {
            z * polevl(z, &P1) / p1evl(z, &Q1)
        } else {
            z * polevl(z, &P2) / p1evl(z, &Q2)
        };

        let x = x0 - x1;
        if negate {
            -x
        } else {
            x
        }
    }
}
/// Margin kept between an interpolated probability and the endpoints 0 and 1
/// before it is handed to the normal inverse CDF in `transform_col`, so the
/// transformed value stays finite.
const BOUNDS_THRESHOLD: f64 = 1e-7;
/// Quantile-based transformer that maps each feature onto a standard normal
/// scale, built from the JSON dump of an already fitted transformer.
#[derive(Debug, Clone)]
pub struct QuantileTransformer {
    /// Probability levels shared by all features; `references[j]` is the level
    /// associated with the `j`-th quantile of every feature.
    references: Vec<f64>,
    /// Quantile values stored per feature: `quantiles[i][j]` is the `j`-th
    /// quantile of feature `i` (the transpose of the layout in the dump).
    quantiles: Vec<Vec<f64>>,
}
fn transform_col(x: f64, quantiles: &Vec<f64>, references: &Vec<f64>) -> f64 { | |
let y; | |
let xlb = quantiles[0]; | |
let xub = *quantiles.last().unwrap(); | |
if x <= xlb { | |
y = 0.0; | |
} else if x >= xub { | |
y = 1.0; | |
} else { | |
// xの左右を二分探索で探す | |
let mut ilb = 0; | |
let mut iub = quantiles.len() - 1; | |
while iub - ilb > 1 { | |
let imd = (ilb + iub) / 2; | |
let qmd = quantiles[imd]; | |
if qmd < x { | |
ilb = imd; | |
} else { | |
iub = imd; | |
} | |
} | |
assert!(quantiles[ilb] <= x); | |
assert!(quantiles[iub] >= x); | |
// 線形補間する | |
let xlb = quantiles[ilb]; | |
let xub = quantiles[iub]; | |
let dlb = x - xlb; | |
let dub = xub - x; | |
let wlb = dub / (dlb + dub); | |
let wub = dlb / (dlb + dub); | |
dbg!(wlb, wub); | |
y = references[ilb] * wlb + references[iub] * wub; | |
} | |
let y = y.clamp( | |
BOUNDS_THRESHOLD - f64::EPSILON, | |
1.0 - (BOUNDS_THRESHOLD - f64::EPSILON), | |
); | |
dbg!(y); | |
normal_distribution::ppf(y) | |
} | |
#[derive(serde::Deserialize)] | |
struct Dump { | |
output_distribution: String, | |
references_: Vec<f64>, | |
quantiles_: Vec<Vec<f64>>, | |
} | |
impl QuantileTransformer { | |
pub fn from_dump(dump: serde_json::Value) -> R<QuantileTransformer> { | |
let dump: Dump = serde_json::from_value(dump)?; | |
// normalしかサポートしない | |
assert_eq!(dump.output_distribution, "normal"); | |
// 転置しといたほうが便利、ってか元のsklearnの実装も転置しといたほうが便利に見えて仕方ないのに何で転置してないんだろ | |
let n_features = dump.quantiles_[0].len(); | |
let n_references = dump.references_.len(); | |
let mut quantiles = vec![vec![0.0; n_references]; n_features]; | |
for i in 0..n_features { | |
for j in 0..n_references { | |
quantiles[i][j] = dump.quantiles_[j][i]; | |
} | |
} | |
// quantilesがユニークじゃない場合は結構変な処理しないといけないが、冷静に俺はそういうの使う予定ないから落とす | |
for qs in quantiles.iter_mut() { | |
qs.dedup(); | |
assert_eq!(qs.len(), n_references); | |
} | |
Ok(QuantileTransformer { | |
references: dump.references_, | |
quantiles, | |
}) | |
} | |
pub fn transform(&self, x: &[f64]) -> Vec<f64> { | |
assert_eq!(x.len(), self.quantiles.len()); | |
x.iter() | |
.zip(self.quantiles.iter()) | |
.map(|(x, quantiles)| transform_col(*x, quantiles, &self.references)) | |
.collect() | |
} | |
} | |
/// Regression tests: the fixture in `create` is a dumped transformer with ten
/// quantiles over three features, and each case pairs an input sample with the
/// expected transformed sample.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the transformer under test from an inline JSON dump.
    fn create() -> QuantileTransformer {
        let j = serde_json::json!({
            "n_quantiles": 10,
            "output_distribution": "normal",
            "ignore_implicit_zeros": false,
            "subsample": 100000,
            "random_state": null,
            "copy": true,
            "n_features_in_": 3,
            "n_quantiles_": 10,
            "references_": [
                0.0,
                0.1111111111111111,
                0.2222222222222222,
                0.3333333333333333,
                0.4444444444444444,
                0.5555555555555556,
                0.6666666666666666,
                0.7777777777777777,
                0.8888888888888888,
                1.0,
            ],
            "quantiles_": [
                [-2.613328303719778, -9.321205290168217, 0.13674275875021852],
                [-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
                [-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
                [-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
                [-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
                [0.1722755897532091, 1.991170360919146, 1.2760902770308562],
                [0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
                [0.67348557453346, 5.3665644352138, 1.959333216365857],
                [0.9437791586194102, 7.205771950040178, 2.8721836987270755],
                [2.4620269142769113, 12.26632790171583, 17.47969633690788],
            ],
        });
        QuantileTransformer::from_dump(j).unwrap()
    }

    /// Transforms `x` and compares the result element-wise against `y`.
    fn check(qt: &QuantileTransformer, x: &[f64], y: &[f64]) {
        let z = qt.transform(x);
        assert_eq!(y.len(), z.len());
        for (a, b) in y.iter().zip(z.iter()) {
            assert_approx_eq!(a, b);
        }
    }

    /// Inputs that coincide exactly with the stored quantiles.
    #[test]
    fn test_references() {
        let qt = create();
        let cases = &[
            (
                [-2.613328303719778, -9.321205290168217, 0.13674275875021852],
                [-5.199337582605575, -5.199337582605575, -5.199337582605575],
            ),
            (
                [-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
                [-1.22064034884735, -1.22064034884735, -1.22064034884735],
            ),
            (
                [-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
                [-0.764709673786387, -0.764709673786387, -0.764709673786387],
            ),
            (
                [-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
                [-0.430727299295457, -0.430727299295457, -0.430727299295457],
            ),
            (
                [-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
                [-0.139710298881862, -0.139710298881862, -0.139710298881862],
            ),
            (
                [0.1722755897532091, 1.991170360919146, 1.2760902770308562],
                [0.1397102988818621, 0.1397102988818621, 0.1397102988818621],
            ),
            (
                [0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
                [0.4307272992954574, 0.4307272992954574, 0.4307272992954574],
            ),
            (
                [0.67348557453346, 5.3665644352138, 1.959333216365857],
                [0.7647096737863867, 0.7647096737863867, 0.7647096737863867],
            ),
            (
                [0.9437791586194102, 7.205771950040178, 2.8721836987270755],
                [1.2206403488473496, 1.2206403488473496, 1.2206403488473496],
            ),
            (
                [2.4620269142769113, 12.26632790171583, 17.47969633690788],
                [5.19933758270342, 5.19933758270342, 5.19933758270342],
            ),
        ];
        for case in cases {
            check(&qt, &case.0, &case.1);
        }
    }

    /// Arbitrary inputs, including values outside the range of the quantiles.
    #[test]
    fn test_random() {
        let qt = create();
        let cases = &[
            (
                [9.543169032696227, 13.004265458495848, 1.2817188786196336],
                [5.19933758270342, 5.19933758270342, 0.14413444750289997],
            ),
            (
                [6.148524766915234, 12.277375912423363, -2.0032741621514276],
                [5.19933758270342, 5.19933758270342, -5.199337582605575],
            ),
            (
                [3.2479687586644044, 6.333366933638192, -0.18854883910825748],
                [5.19933758270342, 0.9788976648571445, -5.199337582605575],
            ),
            (
                [12.026790240804463, -8.17807192753417, 7.139275629961631],
                [5.19933758270342, -2.0320120702704356, 1.414185316101896],
            ),
            (
                [11.297395996011659, 12.389898901401988, 4.3528110754416325],
                [5.19933758270342, 5.19933758270342, 1.2824135105950227],
            ),
            (
                [-5.548440019781655, 3.9141903012544503, 8.462175829374196],
                [-5.199337582605575, 0.44045804894271895, 1.4863658158773896],
            ),
            (
                [0.48259293083427224, 1.8454399739866432, -8.433798779937739],
                [0.5270731855886761, 0.10273894814593508, -5.199337582605575],
            ),
            (
                [2.9464268070547117, 12.73414827971353, 5.139387737583421],
                [5.19933758270342, 5.19933758270342, 1.3173195432982814],
            ),
            (
                [-3.4636253585345047, 13.086601949163505, 2.851517222764288],
                [-5.199337582605575, 5.19933758270342, 1.207464726166755],
            ),
            (
                [-6.081831945801352, -6.336175884405989, -7.420982125992383],
                [-5.199337582605575, -1.5978724353638714, -5.199337582605575],
            ),
            (
                [-6.275065630029899, 7.586341240789082, 1.7097115253806585],
                [-5.199337582605575, 1.266007404810665, 0.5030071199147996],
            ),
            (
                [-0.13038122929708784, -4.161259056492385, 5.808837900045351],
                [-0.1960917782971291, -1.3097800036623548, 1.3483455214298523],
            ),
            (
                [7.618117371577743, 5.909956441138961, 10.508511337022796],
                [5.19933758270342, 0.8801295864809635, 1.6161969123723285],
            ),
            (
                [8.693866471169883, -0.7330487877831544, 9.163553125519751],
                [5.19933758270342, -0.4330680381712861, 1.5280004611048847],
            ),
            (
                [0.4282314588607754, -1.1967112585064736, -6.889187183469036],
                [0.46502685937014465, -0.5551855155241495, -5.199337582605575],
            ),
            (
                [7.685794675525781, 4.500672464040191, 14.250996390962012],
                [5.19933758270342, 0.5640480936891384, 1.967567697630419],
            ),
            (
                [1.5798977254115556, -9.611474340719736, -6.600078504292563],
                [1.5176006671813367, -5.199337582605575, -5.199337582605575],
            ),
            (
                [0.17534066429800532, -3.5008980831264935, 12.394201107918992],
                [0.1435390280777139, -1.2409594504784927, 1.7661839743843688],
            ),
            (
                [3.5825769280006945, -1.551988325695696, -6.6376732992024206],
                [5.19933758270342, -0.6546087494691977, -5.199337582605575],
            ),
            (
                [-4.10776763251768, -9.18756794697149, 9.369452587399543],
                [-5.199337582605575, -2.811715563246142, 1.5407399020190642],
            ),
        ];
        for case in cases {
            check(&qt, &case.0, &case.1);
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment