JavaでAI(ディープラーニング)するなら、DL4Jはメジャーなライブラリの1つですね。
DL4Jの入手
mavenを使って入手
pom.xmlは、こんな感じ。
DL4Jのコア、行列計算のライブラリND4Jを入手しておきましょう。
一番下のUIはお好みで。入れておくと学習状況がブラウザから見られます。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | <!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-core --> <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-core</artifactId> <version>1.0.0-M2.1</version> </dependency> <!-- https://mvnrepository.com/artifact/org.nd4j/nd4j-native-platform --> <dependency> <groupId>org.nd4j</groupId> <artifactId>nd4j-native-platform</artifactId> <version>1.0.0-M2.1</version> </dependency> <!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-ui --> <dependency> <groupId>org.deeplearning4j</groupId> <artifactId>deeplearning4j-ui</artifactId> <version>1.0.0-M2.1</version> </dependency> |
手動で入手
jarが必要であれば、Mavenでダウンロードしたものから拾いましょう。
jarは273個も必要です。
UIを抜いても172個必要ですので、手動での入手は現実的ではありません。
※2024年1月現在
DL4Jを使って、MNIST手書きデータを学習するサンプル
DL4Jで、MNIST手書きデータを学習します。
学習後、学習データを保存します。
※学習データの読み込みは次回。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | import java.io.File; import java.util.Arrays; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; public class MnistTest { public static void main(String[] args) throws Exception { //MNISTデータを準備 DataSetIterator mnistTrain = new MnistDataSetIterator(32, true, 123); DataSetIterator mnistTest = new MnistDataSetIterator(32, false, 123); //ネットワークの定義 MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder() .updater(new Nesterovs(0.05, 0.95)) .list() .layer(new DenseLayer.Builder() .nIn(28 * 28) .nOut(500) .activation(Activation.RELU) .weightInit(WeightInit.NORMAL) .build()) .layer(new DenseLayer.Builder() .nIn(500) .nOut(500) .activation(Activation.RELU) .weightInit(WeightInit.NORMAL) .build()) .layer(new OutputLayer.Builder(LossFunction.SQUARED_LOSS) .nIn(500) .nOut(10) .activation(Activation.SOFTMAX) .weightInit(WeightInit.NORMAL) .build()) .build(); //ネットワーク作成 MultiLayerNetwork network = new MultiLayerNetwork(conf); network.init(); //UIサーバ起動 UIServer uiServer = UIServer.getInstance(); StatsStorage statsStorage = new InMemoryStatsStorage(); uiServer.attach(statsStorage); network.setListeners(Arrays.asList(new ScoreIterationListener(1), new StatsListener(statsStorage))); //学習 network.fit(mnistTrain, 15); //学習データを保存 network.save(new File("C:\\work\\train.dat")); //学習後にテストを実行。テスト結果を出力 Evaluation eval = network.evaluate(mnistTest); System.out.println(eval.stats()); } } |
実行結果
MNISTの手書きデータを学習。
学習後のテスト結果が表示されます。
この例では、98%の正答率でした。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder". SLF4J: Defaulting to no-operation (NOP) logger implementation SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details. ========================Evaluation Metrics======================== # of classes: 10 Accuracy: 0.9828 Precision: 0.9829 Recall: 0.9826 F1 Score: 0.9827 Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes) =========================Confusion Matrix========================= 0 1 2 3 4 5 6 7 8 9 --------------------------------------------------- 970 0 2 1 0 1 1 1 3 1 | 0 = 0 0 1128 2 1 0 0 2 0 2 0 | 1 = 1 0 3 1020 1 1 0 1 3 3 0 | 2 = 2 1 1 3 996 0 0 0 2 4 3 | 3 = 3 0 0 2 1 961 0 3 2 0 13 | 4 = 4 2 0 0 9 1 870 4 1 3 2 | 5 = 5 4 2 1 1 1 4 944 0 1 0 | 6 = 6 1 6 7 3 0 0 0 1004 2 5 | 7 = 7 3 0 2 5 1 0 4 2 955 2 | 8 = 8 4 2 0 6 8 2 1 2 4 980 | 9 = 9 Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times ================================================================== |
先頭でSLF4Jが見つからない旨、エラー出力されます。
気になる方は、SLF4Jも入れてください。
サンプルの解説
先頭のMnistDataSetIteratorクラスのコンストラクタの部分が、MNISTデータの準備部分です。
第1引数は、バッチサイズ。1であれば、バッチ勾配降下法。2以上であれば、ミニバッチ勾配降下法です。
第2引数は、trueの場合、トレーニング(学習)用データを準備。falseの場合は、テストデータを準備。
2つめの部分が、ニューラルネットワークの定義です。
入力層、中間層2つ、出力層の構成です。
ここでは、各層は全結合、活性化関数はReLUにしています。
ここに記載のハイパーパラメータは適当なので、いろいろ変えてみてください。
固定なのは、入力層の28×28(=784)と出力層の10だけです。
入力は28×28サイズの画像。出力は0~9のどれかを判定したいから。ってことです。
この後は、この定義をもとに学習とテストの流れですね。
学習は時間がかかって大変なので、必ず保存しておきましょう。
ちなみに、UIのjarを入手した状態で、UIサーバ起動のくだりのコードが書いてあれば、
このURLで学習状況が確認できます。
http://localhost:9000/