2011/06/19からのアクセス回数 5140
ここで紹介したSageワークシートは、以下のURLからダウンロードできます。
http://www15191ue.sakura.ne.jp:8000/home/pub/8/
また、Sageのサーバを公開しているサイト(http://www.sagenb.org/, http://www15191ue.sakura.ne.jp:8000/)にユーザIDを作成することで、ダウンロードしたワークシートを アップロードし、実行したり、変更していろいろ動きを試すことができます。
author>Hiroshi TAKEMOTO</author> (<email>take@pwv.co.jp</email>)
与えられたデータに最もよくフィットする関数を求めるのが、「データフィッティング」です。 今回は、$ sin(2 \pi x) $曲線に正規分布のノイズを加えたテストデータを3次多項式で 近似しながら、Sageでのデータフィッティングの方法と手法のもつ問題点について説明します。
テストデータは、区間[0, 1]でランダムに抽出したxに対して、以下の式で与えられるyを計算して生成します。 $$ y = sin(2 \pi x) + \mathcal{N}(0,0.3) $$
変数X, Yにそれぞれ10個のx座標、y座標のリストをセットします。 このままで計算するたびにX, Yの値が異なってしまうので、一度作成した値をsave関数でワークシートに保存し、 load関数で読み込んでいます。
sageへの入力:
# テストデータ生成 #X = [random() for i in range (10)] #Y = [sin(2*pi*x) + gauss(0, 0.3) for x in X] # データの保存 #save(X, DATA+'X') #save(Y, DATA+'Y') # データのリストア X = load(DATA+'X') Y = load(DATA+'Y')
生成されたデータをもとの\(sin(2 \pi x) \)曲線と一緒にプロットしてみます。 後で、少ないデータでのフィッティングに使うため、最初の3点の座標は赤で、 残りを青でプロットしています。
sageへの入力:
# データのプロット x = var('x') sin_plt = plot(sin(2*pi*x),[x, 0, 1], rgbcolor='green') blue_plt = list_plot(zip(X[3:], Y[3:])) red_plt = list_plot(zip(X[:3], Y[:3]), rgbcolor='red') data_plt = list_plot(zip(X, Y)); (blue_plt + red_plt + sin_plt).show(xmin=0, xmax= 1, ymin=-1.5, ymax=1.5)
#pre{{
}}
Sageが提供している曲線のフィッティング関数がfind_fit関数です。find_fit関数の使い方を以下に示します。
find_fit(データ, モデル, オプション)
データには、モデルで使用する変数の値(ここではx)リストと観測地(ここではy)の組(タプルまたはリスト)の リストを渡し、モデルには変数を与えると予測値を計算するモデルの関数を渡します。
今回の例では、modelとして、以下の3次多項式を使用します。 $$ model(x) = w_0 + w_1 x + w_2 x^2 + w_3 x^3 $$
find_fitで求められた\(w_0, w_1, w_2, w_3\)を使って関数f_fitを定義するには、 辞書型で結果をもらうと便利です。solution_dict=Trueオプションを指定すると計算結果が 辞書型で返されます。
sageへの入力:
(w0, w1, w2, w3) = var('w0 w1 w2 w3') model(x) = w0 + w1*x + w2*x^2 + w3*x^3 data = zip(X, Y) fit = find_fit(data, model, solution_dict=True); view(fit)
#pre{{
}}
モデルmodelにfind_fitの結果fitを代入して、フィッティング曲線の関数f_fitを定義します。
赤で表された曲線がデータとうまく合っていることが見て取れます。
sageへの入力:
f_fit(x) = model.subs(fit) fit_plt = plot(f_fit, [x, 0, 1], rgbcolor='red') (blue_plt + red_plt + sin_plt + fit_plt).show(xmin=0, xmax= 1, ymin=-1.5, ymax=1.5)
モデルmodelでは、4個の未知数$w_0, w_1, w_2, w_3$を求めていますが、 データの個数が4個よりも少ない場合には、どのようになるのか試してみましょう。
10個のデータのうち最初の3個(赤でプロットされた点)を使ってfind_fitを実行すると、 以下のようにエラーとなってしまいます。
sageへの入力:
# データ数が足りない場合には、エラーとなる data = zip(X[:3], Y[:3]; print len(data) fit = find_fit(data, model, solution_dict=True); print fit f_fit(x) = model.subs(fit) fit_plt = plot(f_fit, [x, 0, 1]) (blue_plt + red_plt + sin_plt + fit_plt).show()
Traceback (most recent call last): ... SyntaxError: invalid syntax
多項式によるフィッティング関数yを以下のように定義し、 $$ y(x, w) = \sum^{M}_{j=0} w_j x^j $$ 各項の\(x^n\)を以下のように表すと、 $$ \phi_j(x) = x^j $$ 観測値YとXの関係を行列で表すと以下の式でもっと誤差の少ない重みwを求めることが最小二乗法の目的です。 $$ Y \approx \Phi w $$ ここで、\(\Phi\)の各要素は、以下の式で表されます。 $$ \Phi_{nj} = \phi_j(x_n) $$
これなら分かる最適化数学 によると、 ムーア・ベンローズの一般逆行列\(\Phi^{\dagger}\)を使うと、重みwは次のように表すことができます。 $$ w = \Phi^{\dagger} Y $$ $\Phi^{\dagger}$は、データの個数n、未知数wの個数mとすると以下のようになります。 $$ \Phi^{\dagger} = \left\{\begin{eqnarray} ( \Phi^T \Phi )^{-1} \Phi^T & m > n \\ \Phi^T (\Phi \Phi^T)^{-1} & m < n \end{eqnarray}\right. $$
ムーア・ベンローズの一般逆行列\(\Phi^{\dagger}\)を求めることで、データ数が足りない場合でも フィッティング関数を求めることができます。
sageへの入力:
# 次数のセット M = 3 # Φ関数定義 def _phi(x, j): return x^j
最初にデータ数が10個の場合(n > m)の場合の、ムーア・ベンローズの一般逆行列\(\Phi^{\dagger}\)をSageを使って求めてみましょう。
\(\Phi^{\dagger}\)をほとんど定義と同じような形でSageで書き表すことができることに注目してください。
sageへの入力:
# 計画行列Φ Phi = matrix([[ _phi(x,j) for j in range(0, (M+1))] for x in X]); Phi_t = Phi.transpose() # ムーア・ベンローズの一般逆行列 Phi_dag = (Phi_t * Phi).inverse() * Phi_t; # 平均の重み Wml = Phi_dag * vector(Y)
次に、求まった重みWmlを使って多項式y(x)を以下のように定義します。
sageへの入力:
# 出力関数yの定義 y = lambda x : sum(Wml[i]*x^i for i in (0..M));
データ数N=10, 未知数M=4の時の、多項式回帰の結果(赤)をサンプリング(青)とオリジナルの\(sin(2 \pi x)\)を合わせて プロットします。
sageへの入力:
y_plt = plot(y, [x, 0, 1], rgbcolor='red'); (y_plt + data_plt + sin_plt).show(xmin=0, xmax= 1, ymin=-1.5, ymax=1.5);
それでは、データ数N=3, 未知数M=4の時のムーア・ベンローズの一般逆行列\(\Phi^{\dagger}\)を計算し、 その結果をプロットしてみましょう。
sageへの入力:
# 計画行列Φ Phi = matrix([[ _phi(x,j) for j in range(0, (M+1))] for x in X[:3]]) Phi_t = Phi.transpose() # ムーア・ベンローズの一般逆行列 PhiPhit = Phi*Phi_t Phi_dag = Phi_t*PhiPhit.inverse(); view(Phi_dag)
sageへの入力:
# 平均の重み Wml = Phi_dag * vector(Y[:3]); Wml.apply_map(lambda x : n(x))
(0.100463412348611, 7.01017348792056, -5.25466361108248, -7.16736911572035)
与えられた3点を通る3次曲線が求まっていることがわかります。
sageへの入力:
x = var('x') y_plt = plot(y, [x, -2, 1.5], rgbcolor='blue') sin_plt = plot(sin(2*pi*x),[x, -2, 1.5], rgbcolor='green') (y_plt + red_plt + sin_plt).show()
皆様のご意見、ご希望をお待ちしております。