Skip to content

Instantly share code, notes, and snippets.

@nobuta05
Created June 27, 2020 15:23
Show Gist options
  • Save nobuta05/85c6e5087ad25db644f93da6d3f7be68 to your computer and use it in GitHub Desktop.
Save nobuta05/85c6e5087ad25db644f93da6d3f7be68 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Scaled Conjugate Gradient(SCG)の実装\n",
"\n",
"本資料は、 Meiller, \"A Scaled Conjugate Gradient Algorithm for Fast Supervised Learning\" (1991) の実装メモです。\n",
"カーネル法等のハイパーパラメタを最適化するために、私はこれまで共役勾配法やBFGS法などを試していました。ただlinesearch手法の選択等が気になっていたため、\n",
"シンプルでそれなりに良い性能の最適化手法を探していました。\n",
"\n",
"SCGはステップ幅の選択にlinesearchを行わず、目的関数を二次近似した際の最適幅を用います。その際、近似精度が悪い時はステップ幅を小さくするという\n",
"信頼領域法の考えも取り入れています。アイディアがシンプルで実装が簡単な割に良い性能を確認しています。論文を読む時に自分がひっかかった点をメモする意味で本資料を作成しています。\n",
"\n",
"なお共役勾配法の概要を学びたい時は『機械学習のための連続最適化』をお勧めします。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SCGの概要\n",
"\n",
"実装してみた感想として、SCGは **信頼領域法** と **2次関数の共役勾配法** の合わせ技と思いました。論文内で著者自身は以下のように記述しています。\n",
"> The idea is to combine the model-trust region approach, known from the Levenberg-Marquardt algorithm, 6 with the conjugate gradient approach.\n",
"\n",
"SCGの特徴は、調整すべきハイパーパラメタがほとんどない点です。つまりステップ幅を得るために必要なline searchの実行が不要となる手法です。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 前準備. 凸2次関数における共役勾配法\n",
"\n",
"SCGは目的関数を局所的に2次近似して導出されるため、簡単に凸2次関数の共役勾配法を見直します。\n",
"\n",
"目的関数$f: \\mathbb{R}^d \\to \\mathbb{R}$を凸2次関数、つまり以下を満たすとします。\n",
"$$\n",
"\\begin{align}\n",
" f(x) &\\coloneqq \\frac{1}{2}x^{\\top} Ax + b^{\\top}x + c,\n",
"\\\\\n",
" &\\text{where} \\quad \\left\\{\n",
" \\begin{aligned}\n",
" &A \\in \\mathbb{R}^{d\\times d} \\quad \\text{正定値行列},\n",
" \\\\\n",
" &b \\in \\mathbb{R}^d,\n",
" \\\\\n",
" &c \\in \\mathbb{R}\n",
" \\end{aligned}\n",
" \\right.\n",
"\\end{align}\n",
"$$\n",
"この目的関数を最小とする$x^{*}\\in\\mathbb{R}^d$を求める問題を考えます。あるベクトル集合$\\left\\{ v_i \\right\\}_{i=1}^{l} \\subset \\mathbb{R}^d,\\; l\\leq d$が行列$A$に関して共役である、等の用語の解説は『機械学習のための連続最適化』を参照していただくこととします。ここではステップ幅の選択について記述します。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"$k$ステップ目の暫定解を$x_{k}$、共役勾配方向を$g_k\\in\\mathbb{R}^d$で表し、この時の最適なステップ幅を考えます。つまり、\n",
"$$\n",
"\\min_{\\alpha\\in\\mathbb{R}} f(x_k + \\alpha g_k) = F(\\alpha)\n",
"$$\n",
"を求めます。途中式は省略しますが$\\frac{dF}{d\\alpha} = 0$を満たす$\\alpha$を求めると、\n",
"$$\n",
"\\alpha_k = \\frac{g_k^{\\top} \\left( -\\nabla f(x_k) \\right)}{g_k^{\\top}A g_k}\n",
"$$\n",
"が最適解であることがわかります。SCGのステップ幅はこの形をしていることから、SCGでは目的関数を二次近似してステップ幅を得ていると考えられます。ただし行列$A$に相当する一般的な関数のヘッセ行列を求めることは面倒な場合が多いので、SCGでは$A g_k$のベクトル値を近似しています。$A$を近似していないことを留意する必要があります。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SCGのポイント1. ステップ幅の選択\n",
"\n",
"冒頭に述べたように、SCGは信頼領域法と2次関数の共役勾配法の合わせ技であると私は認識しています。この節ではその後者のポイントを確認します。\n",
"\n",
"ステップ幅を計算するために、SCGでは目的関数$f(x)$を二次近似します。\n",
"$$\n",
"f(x+y)\\approx f(x) + \\nabla f(x)^{\\top} y + \\frac{1}{2} y^{\\top} \\nabla^2 f(x) y\n",
"$$\n",
"\n",
"SCGのステップ幅導出でポイントとなるのは$\\nabla^2 f(x) y$の近似方法です。\n",
"$\\epsilon > 0,\\; \\frac{1}{\\epsilon} \\left( \\nabla f(x+\\epsilon y) - \\nabla f(x) \\right)$において$\\epsilon \\to 0$の極限は$\\nabla^2 f(x) y$となります。著者曰く、単にこの近似を用いても安定したアルゴリズムを得ることはできなかったとのことです。アルゴリズムが不安定になるのは、ヘッセ行列が半負定値となることが場合がある目的関数であったり、暫定解が最適解から遠く離れている時が挙げられると著者は述べています。\n",
"\n",
"前節の凸2次関数において$\\frac{dF}{d\\alpha} = 0$を満たす$\\alpha$でステップ幅を選択できた理由は、行列$A$が正定値だからです。仮に負定値だった場合は目的関数値が増加する方向へ進んでしまいます。つまり$g_k^{\\top} \\nabla^2 f(x) g_k$が正の値となるよう制約を加えて近似することが、SCGのアプローチです。SCGでは$\\nabla^2 f(x_k) g_k$を以下のように近似します。\n",
"$$\n",
"\\nabla^2 f(x_k) g_k \\approx \\tilde{s}_k \\coloneqq \\frac{\\nabla f(x_k + \\sigma_k g_k) - \\nabla f(x_k)}{\\sigma_k} + \\lambda_k g_k,\n",
"$$\n",
"ここで$\\lambda_k > 0$は補正項を表し、論文において初期値は$\\lambda_0 = 0$となっています。各ステップにおいて、$g_k^{\\top} \\tilde{s}_k$が正の値となるように更新します。論文上の$\\bar{\\lambda}_k$が更新値となり、論文上では導出された下界の2倍の値で更新しています。\n",
"\n",
"このようにして、凸2次関数の場合におけるステップ幅の分母の値を近似して、各ステップのステップ幅としています。 **補正項が大きくなるほど、ステップ幅は小さくなります。**"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## SCGのポイント2. 近似精度指標\n",
"\n",
"前節ではSCGのステップ幅選択では、目的関数を二次近似して導出していることを確認しました。ここでは信頼領域法のポイントを確認します。具体的には、二次近似の精度が悪い時はどうするのかを確認します。\n",
"\n",
"前節によってステップ幅は決定しました。しかし更新する前に、近似精度の確認を行います。論文の式(26)一行目の値は、元々の目的関数と近似した2次関数での目的関数における、更新前後の差の比を表します。ただし、近似した2次関数は実際にヘッセ行列を導出して計算するわけではないことを留意してください。したがって式(26)の二行目の等号は不適切で、実際には$\\approx$が適当かと思います。\n",
"\n",
"近似2次関数が更新後も十分な近似精度であるならば、$\\Delta_k$は$1$に近い値をとり、不十分であれば$0$または負の値となる状況も考えられます。\n",
"\n",
"論文のアルゴリズムを見ると、$\\Delta_k$が$\\frac{3}{4}$以上だったり、$\\frac{1}{4}$未満だったりで補正項を調整しています。この値は信頼領域法の枠組みから得られる値のようですが、私はまだここを追えていません。ともあれ、現ステップにおいて近似精度が十分であれば、次ステップでは補正項を小さくしてより大きなステップ幅を得やすくし、不十分な近似精度であれば補正項を大きくすることでステップ幅を小さくする、という信頼領域法の考えが取り入れられていると認識していれば良いかなと思っています。"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 総論\n",
"\n",
"以上を認識すれば論文のアルゴリズムはすんなり把握できると思います。\n",
"最後に論文のアルゴリズムをJuliaで書いただけの実装メモを記述します。"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Juliaでの実装例\n",
"\n",
"# d\\in\\mathcal{N} ... 最適化する変数の個数(次元数)\n",
"# f: \\mathbb{R}^d -> \\mathbb{R} ... 最小化する目的関数\n",
"# ∇f: \\mathbb{R}^d -> \\mathbb{R}^d ... 目的関数の勾配\n",
"function SCG(init, f::Function, ∇f::Function; σ=1e-4, MaxLoop=50)\n",
" D = length(init)\n",
" xₖ = zeros(D)\n",
" xₖ .= init\n",
" λₖ = 1e-6\n",
" λ̄ₖ = 0.0\n",
" rₖ = -∇f(xₖ)\n",
" pₖ = rₖ\n",
" success = true\n",
"\n",
" for loop in 1:MaxLoop\n",
" if success\n",
" σₖ = σ/norm(pₖ)\n",
" sₖ = (∇f(xₖ+σₖ.*pₖ) - ∇f(xₖ)) / σₖ\n",
" δₖ = dot(pₖ, sₖ)\n",
" end\n",
"\n",
" δₖ = δₖ + (λₖ - λ̄ₖ)*norm(pₖ)^2\n",
" if δₖ <= 0.0\n",
" λ̄ₖ = 2*(λₖ - δₖ/norm(pₖ)^2)\n",
" δₖ = -δₖ + λₖ*norm(pₖ)^2\n",
" λₖ = λ̄ₖ\n",
" end\n",
"\n",
" μₖ = dot(pₖ, rₖ)\n",
" αₖ = μₖ/δₖ\n",
" Δₖ = 2*δₖ*( f(xₖ) - f(xₖ+αₖ.*pₖ) )/μₖ^2\n",
" if Δₖ >= 0.0\n",
" xₛ = xₖ+αₖ.*pₖ\n",
" rₛ = - ∇f(xₛ)\n",
" λ̄ₛ = 0.0\n",
" success = true\n",
" if loop % D == 0\n",
" pₛ = rₛ\n",
" else\n",
" βₖ = (norm(rₛ)^2 - dot(rₛ, rₖ)) / μₖ\n",
" pₛ = rₛ + βₖ.*pₖ\n",
" end\n",
"\n",
" if Δₖ >= 0.75\n",
" λₖ = λₖ/4.0\n",
" end\n",
" xₖ = xₛ\n",
" rₖ = rₛ\n",
" pₖ = pₛ\n",
" λ̄ₖ = λ̄ₛ\n",
" else\n",
" λ̄ₖ = λₖ\n",
" end\n",
"\n",
" if Δₖ < 0.25\n",
" λₖ = λₖ + δₖ*(1-Δₖ)/norm(pₖ)^2\n",
" end\n",
" end\n",
"\n",
" return(xₖ)\n",
"end"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.4.2",
"language": "julia",
"name": "julia-1.4"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.4.2"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment