2008/11/04からのアクセス回数 15556
集合知は、Amazonで有名になったレコメンドシステムやGoogle SearchのPageRank アルゴリズムなどをPythonを使って、簡潔にかつ分かりやすく説明した素晴らしい本です。
ここでは、11章で紹介されている「遺伝的プログラミング」をjavaに移植しながら、その アルゴリズムをトレースしてみます。
遺伝的プログラミングでは、プログラムの集合が互いに競り合い、進化します。
この進化には、
があります。
です。
プログラムを進化するために、プログラムをツリー構造で表現します。
集合知の図11-2から引用
このツリーに対応するjavaの関数は、
int func(int x, int y) { if (x > 3) return y + 5; else return y - 2; }
となります。
Gp.javaというファイルに、
javaでは、関数を引数に渡すことができないので、IFunctionインタフェースを持つオブジェクトを渡すことにしました。
interface IFunction { Object eval(List l); } class Node { IFunction function; String name; List children; Node() { this(null, null); } Node(Fwrapper fw, List children) { if (fw != null) { this.function = fw.function; this.name = fw.name; } this.children = children; } protected Object evaluate(List inp) { List results = new ArrayList(); for (int i = 0; i < children.size(); i++) { Node node = (Node)children.get(i); results.add(node.evaluate(inp)); } return this.function.eval(results); } } class Fwrapper extends Node { int childCount; Fwrapper(IFunction function, int childCount, String name) { this.function = function; this.childCount = childCount; this.name = name; } } class ParamNode extends Node { int idx; ParamNode(int idx) { this.idx = idx; } protected Object evaluate(List inp) { return (inp.get(idx)); } } class ConstNode extends Node { Double v; ConstNode(double v) { this.v = v; } protected Object evaluate(List inp) { return (v); } }
Fwapperに渡す関数として、
を用意します。
class AddW implements IFunction { public Object eval(List l) { return (Double)l.get(0) + (Double)l.get(1); } } class SubW implements IFunction { public Object eval(List l) { return (Double)l.get(0) - (Double)l.get(1); } } class MulW implements IFunction { public Object eval(List l) { return (Double)l.get(0) * (Double)l.get(1); } } class IfFunc implements IFunction { public Object eval(List l) { if ((Double)l.get(0) > 0) return l.get(1); else return l.get(2); } } class IsGreater implements IFunction { public Object eval(List l) { if ((Double)l.get(0) > (Double)l.get(1)) return new Double(1); else return new Double(0); } }
AddW, SubW, MulW, Ifw, Gtwの関数を保持するラッパノードを持つリストflistを作成、サンプルツリーを作成する exampletreeメソッドを追加します。
public class Gp { static List flist; static Fwrapper addw; static Fwrapper subw; static Fwrapper mulw; static Fwrapper ifw; static Fwrapper gtw; public Gp() { addw = new Fwrapper(new AddW(), 2, "add"); subw = new Fwrapper(new SubW(), 2, "subtract"); mulw = new Fwrapper(new MulW(), 2, "multiply"); ifw = new Fwrapper(new IfFunc(), 3, "if"); gtw = new Fwrapper(new IsGreater(), 2, "isgreater"); flist = new ArrayList(); flist.add(addw); flist.add(mulw); flist.add(ifw); flist.add(gtw); flist.add(subw); } Node exampletree() { List gtwParList = new ArrayList(); gtwParList.add(new ParamNode(0)); gtwParList.add(new ConstNode(3)); List addwParList = new ArrayList(); addwParList.add(new ParamNode(1)); addwParList.add(new ConstNode(5)); List subwParList = new ArrayList(); subwParList.add(new ParamNode(1)); subwParList.add(new ConstNode(2)); List ifwParList = new ArrayList(); ifwParList.add(new Node(gtw, gtwParList)); ifwParList.add(new Node(addw, addwParList)); ifwParList.add(new Node(subw, subwParList)); return (new Node(ifw, ifwParList)); } }
main関数に記述する代わりに、test1メソッドに最初のテストメソッドをまとめました。
void test1() { Node exampleTree = exampletree(); List para1 = new ArrayList(); para1.add(new Double(2)); para1.add(new Double(3)); System.out.println(exampleTree.evaluate(para1)); List para2 = new ArrayList(); para2.add(new Double(5)); para2.add(new Double(3)); System.out.println(exampleTree.evaluate(para2)); } public static void main(String[] args) { Gp gp = new Gp(); gp.test1(); }
実行すると
1.0 8.0
のように表示されます。
ツリー構造がどのようになっているかを表示するために、Nodeクラスにdisplayメソッドを追加します。
Nodeクラスには、
protected void display(int indent) { for (int i = 0; i < indent; i++) System.out.print(" "); System.out.println(name); for (int i = 0; i < children.size(); i++) { children.get(i).display(indent + 1); } }
ParamNodeクラスには、
protected void display(int indent) { for (int i = 0; i < indent; i++) System.out.print(" "); System.out.format("p%d\n", idx); }
ConstNodeクラスには、
protected void display(int indent) { for (int i = 0; i < indent; i++) System.out.print(" "); System.out.format("%.1f\n", v); }
のようにdisplayメソッドを追加します。
表示の方法は、トップのノードにdisplayメソッドを渡すだけです。
void test2() { Node exampleTree = exampletree(); exampleTree.display(0); }
では、実際にexamletreeで生成されたツリーを表示してみます。 mainメソッドで先ほどのtest1メソッドをtest2に変えて実行すると
if isgreater p0 3.0 add p1 5.0 subtract p1 2.0
ランダムな集合を作成するために、makerandomtreeメソッドを追加します。
Node makeRandamTree(int pc) { return (makeRandamTree(pc, 4, 0.5, 0.6)); } Node makeRandamTree(int pc, int maxdepth, double fpr, double ppr) { if (Math.random() < fpr && maxdepth >0) { Fwrapper f = (Fwrapper)Randam.choid(flist); List children = new ArrayList<Node>(); for (int i = 0; i < f.childCount; i++ ) { Node node = makeRandamTree(pc, maxdepth-1, fpr, ppr); children.add(node); } return (new Node(f, children)); } else if (Math.random() < ppr) { return (new ParamNode(Randam.randint(0, pc-1))); } else return (new ConstNode(Randam.randint(0, 10))); }
makeRandamTreeメソッドのテストとして、test3を作成しました。
void test3() { Node random1 = makeRandamTree(2); System.out.println("random1"); random1.display(0); System.out.println("evaluate"); List para1 = new ArrayList(); para1.add(new Double(7)); para1.add(new Double(1)); System.out.println(random1.evaluate(para1)); List para2 = new ArrayList(); para2.add(new Double(2)); para2.add(new Double(4)); System.out.println(random1.evaluate(para2)); Node random2 = makeRandamTree(2); System.out.println("random2"); random2.display(0); System.out.println("evaluate"); List para3 = new ArrayList(); para3.add(new Double(5)); para3.add(new Double(3)); System.out.println(random2.evaluate(para3)); List para4 = new ArrayList(); para4.add(new Double(5)); para4.add(new Double(20)); System.out.println(random2.evaluate(para4)); }
実行するたびに、異なる結果が得られます。 例として、
random1 subtract add if 2.0 multiply 4.0 3.0 subtract 10.0 p1 9.0 p1 evaluate 20.0 17.0 random2 p1 evaluate 3.0 20.0
のような出力になります。
遺伝的プログラミングのテストとして、単純な数学的関数を推定するテストをします。
ここで、未知の関数として、
Double hiddenFunction(int x, int y) { double val = x*x + 2*y + 3*x + 5; return (new Double(val)); }
を定義します。
この関数を推定するために、200個のデータセットを用意します。 各行は、[Xの値、Yの値、関数値]の3つ組です。
List buildHiddenSet() { List rows = new ArrayList(); for (int i = 0; i < 200; i++) { int x = Randam.randint(0, 40); int y = Randam.randint(0, 40); List cols = new ArrayList(); cols.add(new Double(x)); cols.add(new Double(y)); cols.add(hiddenFunction(x, y)); rows.add(cols); } return (rows); }
遺伝的プログラミングがどの程度正しく推定したかを算出するために、scoreFunctionを以下の様に定義します。
int scoreFunction(Node tree, List s) { double dif = 0; for (int i = 0; i < s.size(); i++) { List cols = (List)s.get(i); double v = Double.parseDouble(tree.evaluate(cols).toString()); dif += Math.abs(v - ((Double)cols.get(2)).doubleValue()); } return ((int)dif); }
スコアが計算を確認するtest4メソッドを以下のように追加します。
void test4() { List hiddenset = buildHiddenSet(); Node random1 = makeRandamTree(2); System.out.println("random1"); random1.display(0); System.out.println("score"); System.out.println(scoreFunction(random1, hiddenset)); Node random2 = makeRandamTree(2); System.out.println("random2"); random2.display(0); System.out.println("score"); System.out.println(scoreFunction(random2, hiddenset)); }
実行結果は、
random1 add 5.0 if add p0 multiply 9.0 8.0 6.0 add p0 8.0 score 133800 random2 add if p0 5.0 subtract p1 add p1 p1 6.0 score 134010
とかなり大きな値になります。
これから、遺伝的プログラミングの突然変異をmutateメソッドに実装します。
mutateでは、
のようにツリーの一部に突然変異を適応します(集合知の図11-4から引用)。
mutateメソッドは以下のようになります。
Node mutate(Node t, int pc) { return (mutate(t, pc, 0.1)); } Node mutate(Node t, int pc, double probchange) { if (Math.random() < probchange) { return (makeRandamTree(pc)); } else { Node result = (Node)t.clone(); if (t.children != null) { List children = new ArrayList(); for (int i = 0; i < t.children.size(); i++) { Node c = t.children.get(i); children.add(mutate(c, pc, probchange)); } result.children = children; } return (result); } }
では、mutateの動作を確認してみましょう。
以下のtest5を追加します。
void test5() { List hiddenset = buildHiddenSet(); Node random2 = makeRandamTree(2); System.out.println("random2"); random2.display(0); System.out.println("random2 socre=" + scoreFunction(random2, hiddenset)); System.out.println("mutate"); Node muttree = mutate(random2, 2); muttree.display(0); System.out.println("muttree socre=" + scoreFunction(muttree, hiddenset)); }
実行結果は、毎回異なります。ほとんど変わらないときと以下のように変化する場合が あります。
random2 subtract p0 8.0 random2 socre=133814 mutate add isgreater 0.0 if add 4.0 p1 p1 8.0 add p1 subtract multiply p1 p0 subtract 9.0 p1 muttree socre=67170
いよいよ最後の交叉を実装します。
crossoverメソッドは以下のようになります(ほとんどpythonのまま)。
Node crossover(Node t1, Node t2) { return (crossover(t1, t2, 0.7, 1)); } Node crossover(Node t1, Node t2, double probswap, int top) { if (Math.random() < probswap && top == 0) { return ((Node)t2.clone()); } else { Node result = (Node)t1.clone(); if (t1.children != null && t2.children != null) { List children = new ArrayList<Node>(); for (int i = 0; i < t1.children.size(); i++) { Node c = t1.children.get(i); children.add(crossover(c, (Node)Randam.choid(t2.children), probswap, 0)); } result.children = children; } return (result); } }
交叉では、ツリーの完全コピーを必要としますので、各ノードクラスにcloneメソッドを追加し、NodeクラスにはdeepCopyメソッドを追加しました。
Nodeクラスの追加
protected void deepCopy(Node dst, Node src) { dst.function = src.function; dst.name = src.name; if (src.children != null) dst.children = new ArrayList<Node>(src.children); } public Object clone() { Node dst = new Node(); dst.deepCopy(dst, this); return (dst); }
ParamNodeクラスの追加
public Object clone() { ParamNode dst = new ParamNode(idx); dst.deepCopy(dst, this); return (dst); }
ConstNodeクラスの追加
public Object clone() { ConstNode dst = new ConstNode(v); dst.deepCopy(dst, this); return (dst); }
Fwrapperクラスの追加
public Object clone() { Fwrapper dst = new Fwrapper(function, childCount, name); dst.deepCopy(dst, this); return (dst); }
また、リストから1個をランダムに抽出するchoiceメソッドをRandam.javaに追加しました。
static Object choid(List list) { if (list != null) { int idx = (int)(list.size()*Math.random()); return (list.get(idx)); } else return (null); }
crossoverの動作確認をするために、test6メソッドを追加します。
void test6() { List hiddenset = buildHiddenSet(); Node randam1 = makeRandamTree(2); System.out.println("random1"); randam1.display(0); Node randam2 = makeRandamTree(2); System.out.println("randam2"); randam2.display(0); Node cross = crossover(randam1, randam2); System.out.println("cross"); cross.display(0); }
何度か動作すると結構おもしろい結果がでます。 以下は一例です。
random1 subtract p0 p1 randam2 add 10.0 add subtract p1 8.0 p1 cross subtract p0 10.0
必要なメソッドがすべて揃ったので、最後に進化するための環境を整えます。
evolveメソッドは、以下の通りです。 終了条件は、
となっています。
void evolve(int pc, int popsize, IRankingFunction raunkingFunction, int maxgen, double mutationRate, double breedingRate, double pexp, double pnew) { List population = new ArrayList(); for (int i = 0; i < popsize; i++) { population.add(makeRandamTree(pc)); } List scores = null; List first; for (int i = 0; i < maxgen; i++) { scores = raunkingFunction.ranking(population); first = (List)scores.get(0); System.out.println(first.get(0).toString()); int score = ((Double)first.get(0)).intValue(); if (score == 0) break; // add top 2 nodes List newop = new ArrayList(); newop.add(first.get(1)); List second = (List)scores.get(1); newop.add(second.get(1)); // generate nest genration. while (newop.size() < popsize) { if (Math.random() > pnew) { List l1 = (List)scores.get(selectindex(pexp)); List l2 = (List)scores.get(selectindex(pexp)); newop.add(mutate( crossover((Node)l1.get(1), (Node)l2.get(1), breedingRate, 1), pc, mutationRate)); } else newop.add(makeRandamTree(pc)); } population = newop; } first = (List)scores.get(0); Node best = (Node)first.get(1); best.display(0); }
pythonではランキング関数をメソッド渡ししているので、IFunctionと同様にインタフェースに変更しました。
interface IRankingFunction { List ranking(List population); } IRankingFunction getRankFunction(List dataSet) { class RankingFunction implements IRankingFunction { List dataSet; RankingFunction(List dataSet) { this.dataSet = dataSet; } public List ranking(List population) { List scores = new ArrayList(); for (int i = 0; i < population.size(); i++) { Node t = (Node)population.get(i); List taple = new ArrayList(); taple.add(new Double(scoreFunction(t, dataSet))); taple.add(t); scores.add(taple); } Collections.sort(scores, new ScoreComparator()); return (scores); } } return (new RankingFunction(dataSet)); }
いよいよ、遺伝的プログラミングの実力を見るときがきました。
void test7() { List hiddenset = buildHiddenSet(); IRankingFunction rk = getRankFunction(hiddenset); evolve(2, 500, rk, 500, 0.2, 0.1, 0.7, 0.1); }
解は、1つではありませんが、例を以下に示します。
13229.0 5388.0 2490.0 2242.0 778.0 396.0 196.0 186.0 186.0 186.0 186.0 29.0 29.0 8.0 8.0 8.0 0.0 add add add add p1 p1 multiply p0 p0 add p0 5.0 add p0 p0
ここで、p0がX、p1がYです。
わずか17回で収束し、解として
Y+Y+X*X+X+5+X+X
となり、 hiddenFunctionの
X**2+2*Y+3*X+5
を見つけることができました。
以下に完全なjavaのソースを添付します。
この記事は、
皆様のご意見、ご希望をお待ちしております。