Created
December 8, 2012 07:22
-
-
Save zaki50/4239101 to your computer and use it in GitHub Desktop.
与えられた座標を3次スプライン補間します。
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* | |
* Copyright (C) 2012 Makoto Yamazaki <[email protected]> | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
package org.zakky.interpolator; | |
import org.apache.commons.math3.linear.Array2DRowRealMatrix; | |
import org.apache.commons.math3.linear.ArrayRealVector; | |
import org.apache.commons.math3.linear.DecompositionSolver; | |
import org.apache.commons.math3.linear.LUDecomposition; | |
import org.apache.commons.math3.linear.RealMatrix; | |
import org.apache.commons.math3.linear.RealVector; | |
/** | |
* http://www.akita-nct.ac.jp/yamamoto/lecture/2004/5E/interpolation/text/html/node3.html 読んで書いた | |
* http://commons.apache.org/math から commons-math3-3.0 を持ってきてクラスパスを通してください。 | |
* | |
* 3次スプライン補間なので、N + 1 個の (x, y) 座標の組から N 個の3次関数を作成して補完します。 | |
* また、与えられた座標において、隣接する3次関数の1次導関数と2次導関数が等しくなります。 | |
*/ | |
public class SplineInterpolator { | |
/** | |
* 座標列によって区切られる区間(3次関数)の数 | |
*/ | |
private final int N; | |
/** | |
* X 座標列 | |
*/ | |
private final double[] mXCoordinates; | |
/** | |
* Y 座標列 | |
*/ | |
private final double[] mYCoordinates; | |
/* | |
* 補間のための3次関数群の係数情報 | |
* | |
* 使用する関数のインデックスを j とすると、 0 <= j < N で、 | |
* y = mA[j](x-x[j])^3 + mB[j](x-x[j])^2 + mC[j](x-x[j])^1 + mD[j] | |
*/ | |
/** | |
* 3次の項に対する係数列。長さ N。 | |
*/ | |
private final double[] mA; | |
/** | |
* 2次の項に対する係数列。長さ N。 | |
*/ | |
private final double[] mB; | |
/** | |
* 1次の項に対する係数列。長さ N。 | |
*/ | |
private final double[] mC; | |
/** | |
* 0次の項に対する係数列。長さ N。 | |
*/ | |
private final double[] mD; | |
public SplineInterpolator(double[] xCoordinates, double[] yCoordinates) { | |
this(xCoordinates, yCoordinates, xCoordinates.length); | |
} | |
public SplineInterpolator(double[] xCoordinates, double[] yCoordinates, int length) { | |
super(); | |
ensureInputsAreValid(xCoordinates, yCoordinates, length); | |
N = length - 1; | |
mXCoordinates = new double[N + 1]; | |
System.arraycopy(xCoordinates, 0, mXCoordinates, 0, length); | |
mYCoordinates = new double[N + 1]; | |
System.arraycopy(yCoordinates, 0, mYCoordinates, 0, length); | |
mA = new double[N]; | |
mB = new double[N]; | |
mC = new double[N]; | |
mD = new double[N]; | |
// fill mA, mB, mC, mD | |
calculate(); | |
} | |
/** | |
* 与えられた {@code x} に対する y の値を返します。 | |
* | |
* @param x x値。 | |
* @return (補間によって計算された) y の値。 | |
*/ | |
public double get(double x) { | |
final double[] xCoordinates = mXCoordinates; | |
// FIXME 区間がたくさんある場合は NavigableMap とか使った方がいい | |
int targetFuntionIndex = N - 1; // 見つからない場合は最後のものを使う | |
for (int j = 0; j < N; j++) { | |
if (x <= xCoordinates[j + 1]) { | |
targetFuntionIndex = j; | |
break; | |
} | |
} | |
// (x - x_j)^1 | |
final double x1 = x - xCoordinates[targetFuntionIndex]; | |
// (x - x_j)^2 | |
final double x2 = x1 * x1; | |
// (x - x_j)^3 | |
final double x3 = x2 * x1; | |
final double y = mA[targetFuntionIndex] * x3 | |
+ mB[targetFuntionIndex] * x2 | |
+ mC[targetFuntionIndex] * x1 | |
+ mD[targetFuntionIndex]; | |
return y; | |
} | |
private static void ensureInputsAreValid(double[] xCoordinates, double[] yCoordinates, | |
int length) { | |
if (length < 2) { | |
throw new IllegalArgumentException("'length' must be 2 or more."); | |
} | |
if (xCoordinates.length < length) { | |
throw new IllegalArgumentException( | |
"length of 'xCoordinates' must not be less than 'length'"); | |
} | |
if (yCoordinates.length < length) { | |
throw new IllegalArgumentException( | |
"length of 'yCoordinates' must not be less than 'length'"); | |
} | |
// x が単調増加であることと、座標にNaN や無限大が含まれないことを確認 | |
double prevX = Double.NEGATIVE_INFINITY; | |
for (int i = 0; i < length; i++) { | |
final double x = xCoordinates[i]; | |
final double y = yCoordinates[i]; | |
if (Double.isNaN(x) || Double.isInfinite(x)) { | |
throw new IllegalArgumentException( | |
"'xCoodinates' must not contain NaN nor INFINITY"); | |
} | |
if (Double.isNaN(y) || Double.isInfinite(y)) { | |
throw new IllegalArgumentException( | |
"'yCoodinates' must not contain NaN nor INFINITY"); | |
} | |
if (x <= prevX) { | |
throw new IllegalArgumentException( | |
"elements in 'xCoodinates' must monotonically increase"); | |
} | |
} | |
} | |
/** | |
* 補間に必要な値を計算します。 | |
*/ | |
private void calculate() { | |
final double[] xCoordinates = mXCoordinates; | |
final double[] yCoordinates = mYCoordinates; | |
final double[] h = new double[N]; | |
for (int j = 0; j < N; j++) { | |
h[j] = xCoordinates[j + 1] - xCoordinates[j]; | |
} | |
final RealMatrix coefficients = buildCoefficients(h); | |
final RealVector constants = buildConstants(h); | |
// 行列式を解いて u[1] から u[N - 1] までを求める。 | |
final DecompositionSolver solver = new LUDecomposition(coefficients).getSolver(); | |
// u の index はずれているので、値を取り出す時は getUAt(int) を使うこと | |
final RealVector u = solver.solve(constants); | |
// mA, mB, mC, mD の値を求める | |
for (int j = 0, length = N; j < length; j++) { | |
final double u_j = getUAt(u, j); | |
final double u_j1 = getUAt(u, j + 1); | |
final double y_j = yCoordinates[j]; | |
final double y_j1 = yCoordinates[j + 1]; | |
mA[j] = (u_j1 - u_j) / (6d * (h[j])); | |
mB[j] = u_j / 2d; | |
mC[j] = ((y_j1 - y_j) / h[j]) - ((h[j] * (2d * u_j + u_j1)) / 6d); | |
mD[j] = y_j; | |
} | |
} | |
/** | |
* 補間関数の係数を求める際に使用する u[] を計算する際の行列式の係数行列(AU = B の A) | |
* を構築します。 | |
* | |
* @param h {@code x[j + 1] - x[j]}の配列 | |
* @return 係数行列。 | |
*/ | |
private RealMatrix buildCoefficients(final double[] h) { | |
final RealMatrix coefficients = new Array2DRowRealMatrix(N - 1, N - 1); | |
for (int rowIndex = 0; rowIndex < N - 1; rowIndex++) { | |
final double targetH = h[rowIndex]; | |
final double nextH = h[rowIndex + 1]; | |
if (rowIndex != 0) { | |
coefficients.setEntry(rowIndex, rowIndex - 1, targetH); | |
} | |
coefficients.setEntry(rowIndex, rowIndex, 2.0 * (targetH + nextH)); | |
if (rowIndex != N - 2) { | |
coefficients.setEntry(rowIndex, rowIndex + 1, nextH); | |
} | |
} | |
return coefficients; | |
} | |
/** | |
* 補間関数の係数を求める際に使用する u[] を計算する際の行列式の定数列(AU = B の B) | |
* を構築します。 | |
* | |
* @param h {@code x[j + 1] - x[j]}の配列 | |
* @return 定数列。 | |
*/ | |
private RealVector buildConstants(final double[] h) { | |
final double[] yCoordinates = mYCoordinates; | |
final double[] v = new double[N]; // v[0] は使わない | |
double pv = (yCoordinates[1] - yCoordinates[0]) / h[0]; | |
for (int j = 1; j < N; j++) { | |
double temp = (yCoordinates[j + 1] - yCoordinates[j]) / h[j]; | |
v[j] = 6.0 * (temp - pv); | |
pv = temp; | |
} | |
final RealVector constants = new ArrayRealVector(N - 1); | |
for (int j = 1; j < N; j++) { | |
constants.setEntry(j - 1, v[j]); | |
} | |
return constants; | |
} | |
/** | |
* {@code} u が保持する値のインデックスは、他の計算部分に使用するものとズレが | |
* あるので、ズレを吸収してアクセスするためのユーティリティメソッドです。 | |
* u に含まれていない値については、 natural spline になるように補います。 | |
* | |
* @param u u[1] から u[N - 1] の N - 1 個分の u値を保持する {@link RealVector}。 | |
* @param j 計算式上での u のインデックス。 | |
* @return 計算式上での u[j] の値。 | |
*/ | |
private static double getUAt(RealVector u, int j) { | |
if (j == 0 || u.getDimension() < j) { | |
// natural spline なので 0。 | |
return 0d; | |
} | |
return u.getEntry(j - 1); | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment