Java DL4Jを使って、MNIST手書きデータを学習する




JavaでAI(ディープラーニング)するなら、DL4Jはメジャーなライブラリの1つですね。

DL4Jの入手

mavenを使って入手

pom.xmlは、こんな感じ。
DL4Jのコア、行列計算のライブラリND4Jを入手しておきましょう。
一番下のUIはお好みで。入れておくと学習状況がブラウザから見られます。

手動で入手

jarが必要であれば、Mavenでダウンロードしたものから拾いましょう。
jarは273個も必要です。
UIを抜いても172個必要ですので、手動での入手は現実的ではありません。
※2024年1月現在

DL4Jを使って、MNIST手書きデータを学習するサンプル

DL4Jで、MNIST手書きデータを学習します。
学習後、学習データを保存します。
※学習データの読み込みは次回。

実行結果

MNISTの手書きデータを学習。
学習後のテスト結果が表示されます。
この例では、98%の正答率でした。

先頭でSLF4Jが見つからない旨、エラー出力されます。
気になる方は、SLF4Jも入れてください。

サンプルの解説

先頭のMnistDataSetIteratorクラスのコンストラクタの部分が、MNISTデータの準備部分です。
第1引数は、バッチサイズ。1であれば、バッチ勾配降下法。2以上であれば、ミニバッチ勾配降下法です。
第2引数は、trueの場合、トレーニング(学習)用データを準備。falseの場合は、テストデータを準備。

2つめの部分が、ニューラルネットワークの定義です。
入力層、中間層2つ、出力層の構成です。
ここでは、各層は全結合、活性化関数はReLUにしています。

ここに記載のハイパーパラメータは適当なので、いろいろ変えてみてください。
固定なのは、入力層の28×28(=784)と出力層の10だけです。
入力は28×28サイズの画像。出力は0~9のどれかを判定したいから。ってことです。

この後は、この定義をもとに学習とテストの流れですね。
学習は時間がかかって大変なので、必ず保存しておきましょう。

ちなみに、UIのjarを入手した状態で、UIサーバ起動のくだりのコードが書いてあれば、
このURLで学習状況が確認できます。
http://localhost:9000/