これからスパイキングニューラルネットワークを始める人へ2 : ニューロンモデルの実装
スパイキングニューラルネットワークについて日本語の資料が少ないと思ったので、あんまり詳しくない人に向けて関連情報をまとめてみようかと思います。(説明の記事はいっぱいあるけど、実践的な内容が少ない)
多分、ソースコードとか調べても見つけにくいと思うので、サンプルとして最後にC++のコードもつけようと思います。
なるべく簡単にまとめようと思いますが難しかったらコメントくだせぇ。
※注意:この記事は公開しながら編集していますので、ところどころ変かもしれないですが、よろしくです。
今回は実装です。基本はC++で書きますが、他のソースコードへのリンクも載せるので多分参考になると思います。(リンクはたぶんMatlabのコードが多い)
この記事だけでも、内容がわかるようにはしてますが、一応前回からの続きです。↓
これからスパイキングニューラルネットワークを始める人へ1 : ニューロンのモデルについて - オープンメモ置き場
ケッコー見てくれてる人がいるようで更新が遅くてすみません。。。頑張ります!
追記2018/8/10
修論終わったら頑張って書きます!!
目次
数値解析
前回の記事でニューロンのモデルの式を解説して、その特性を常微分方程式で表しました。その常微分方程式をコンピューターで解くために、数値解析の知識が必要です。
なんで必要かというと、微分の定義を思い出して欲しいんですが
こんな感じだったと思います。
げげっ、無限に小さいがいます。こいつはコンピューターでは扱えません。コンピューターでは全ての変数(時間も)が一定の幅を持って進めていくしかないですから。このを無限に小さくせずに、1とか0.1とかの一定の数値で近似して計算するしかないのです。こいつを数値解析でなんとかします。
以下では、代表的な方法であるオイラー法とルンゲ=クッタ法を紹介します。
各方法の説明の前に、これらの方法が常微分方程式から何にたどり着ければゴールなのかを説明します。
ゴール
状態変数の初期値を決めます。そしてそこから時間を進めていったら状態がどう変化するかがわかればいいわけです。
ニューロンなら今何mVで、次に何mVになるのかを知りたいわけです。
これはつまり、時間の状態を、時間を進めた時間時の状態をとして、以下の式におけるを定義することができたらゴールです。
オイラー法
一番単純な方法がオイラー法で、こんなの
こいつはからの時間の間の変化率が一定であることを仮定しています。
方法としては単純で、とりあえずシミュレーションをやるにはこれを使えば何の問題もないです。
変化率が一定であると仮定してるので、一定じゃない時に誤差が大きくなります。
オイラー法 - Wikipedia
ルンゲ=クッタ法
正確には4次ルンゲ=クッタ法です。1次も2次もあるんですけど、4次が1番計算コストと正確性のバランスが良いらしく、4次しか使われないので省略されがちです。こんな感じ
証明は難しいため省略しますが、イメージはこんな感じです。
初期値を変化させて4回計算した結果を加重平均します。
式を見てもなんじゃこりゃだと思うので、プログラミングしてみて慣れてください。数値解析でかなりよく使われる常套手段です。
みんな最初は名称に疑問を持って「昨日ルンゲ、食い過ぎたわ」みたいなボケを言うという噂があります。
誤差の比較の図です。
数値解析の実装
なるべくはそのままコンパイルすれば動くプログラムを見せたいので、ちょっと長くなるかもしれないですけどすみません。
出力はcsvにします。
csvはExcelとかで読み込んで可視化してみて下さい。
できる人はRかPythonで可視化して下さい。
あんまり式の数が多いと分かりづらいのでフィッツフュー・南雲モデルを実装します。
FitzHugh-Nagumo model - Scholarpedia
そのまま書く
上で紹介した方法をそのまま書いていきます。
オイラー法
#include<iostream> #include<fstream> int main(){ //initialize const double h = 1.; // 1ステップの時間[ms] const double simulationEnd = 100.; // シミュレーション時間[ms] double v = 0.; // 状態変数 double w = 0.; double I = 4.; //output ofstream output("Euler_Result.csv"); // 出力ファイル //simulation for(double t=0; t<simulationEnd; t=t+h){ output << t << "," << v <<std::endl; v += h * (v - v*v*v/3 - w + I); w += h * (0.08 * (v + 0.7 - 0.8*w)); } output.close(); return 0; }
ルンゲ=クッタ法
#include<iostream> #include<fstream> int main(){ //initialize const double h = 1.; // 1ステップの時間[ms] const double simulationEnd = 100.; // シミュレーション時間[ms] double v = 0.; // 状態変数 double w = 0.; double I = 4.; //output ofstream output("RK4_Result.csv"); // 出力ファイル //simulation for(double t=0; t<simulationEnd; t=t+h){ output << t << "," << v <<std::endl; v_calc = v; // 初期値を変えて計算を行うので w_calc = w; // 計算用の変数を用意します k_1 = v_calc - v_calc*v_calc*v_calc/3 - w_calc + I; l_1 = 0.08 * (v_calc + 0.7 - 0.8*w_calc); v_calc = v_calc + h * k1 / 2; w_calc = w_calc + h * l1 / 2; k_2 = v_calc - v_calc*v_calc*v_calc/3 - w_calc + I; l_2 = 0.08 * (v_calc + 0.7 - 0.8*w_calc); v_calc = v_calc + h * k2 / 2; w_calc = w_calc + h * l2 / 2; k_3 = v_calc - v_calc*v_calc*v_calc/3 - w_calc + I; l_3 = 0.08 * (v_calc + 0.7 - 0.8*w_calc); v_calc = v_calc + h * k3; w_calc = w_calc + h * l3; k_4 = v_calc - v_calc*v_calc*v_calc/3 - w_calc + I; l_4 = 0.08 * (v_calc + 0.7 - 0.8*w_calc); v += h * (k_1 + 2 * k_2 + 2 * k_3 + k_4) / 6; w += h * (l_1 + 2 * l_2 + 2 * l_3 + l_4) / 6; } output.close(); return 0; }