JavaでAI(ディープラーニング)するなら、DL4Jはメジャーなライブラリの1つですね。
DL4Jの入手
mavenを使って入手
pom.xmlは、こんな感じ。
 DL4Jのコア、行列計算のライブラリND4Jを入手しておきましょう。
 一番下のUIはお好みで。入れておくと学習状況がブラウザから見られます。
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 | <!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-core --> <dependency>     <groupId>org.deeplearning4j</groupId>     <artifactId>deeplearning4j-core</artifactId>     <version>1.0.0-M2.1</version> </dependency> <!-- https://mvnrepository.com/artifact/org.nd4j/nd4j-native-platform --> <dependency>     <groupId>org.nd4j</groupId>     <artifactId>nd4j-native-platform</artifactId>     <version>1.0.0-M2.1</version> </dependency> <!-- https://mvnrepository.com/artifact/org.deeplearning4j/deeplearning4j-ui --> <dependency>     <groupId>org.deeplearning4j</groupId>     <artifactId>deeplearning4j-ui</artifactId>     <version>1.0.0-M2.1</version> </dependency> | 
手動で入手
jarが必要であれば、Mavenでダウンロードしたものから拾いましょう。
 jarは273個も必要です。
 UIを抜いても172個必要ですので、手動での入手は現実的ではありません。
 ※2024年1月現在
DL4Jを使って、MNIST手書きデータを学習するサンプル
DL4Jで、MNIST手書きデータを学習します。
 学習後、学習データを保存します。
 ※学習データの読み込みは次回。
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | import java.io.File; import java.util.Arrays; import org.deeplearning4j.core.storage.StatsStorage; import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; import org.deeplearning4j.nn.conf.MultiLayerConfiguration; import org.deeplearning4j.nn.conf.NeuralNetConfiguration; import org.deeplearning4j.nn.conf.layers.DenseLayer; import org.deeplearning4j.nn.conf.layers.OutputLayer; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; import org.deeplearning4j.nn.weights.WeightInit; import org.deeplearning4j.optimize.listeners.ScoreIterationListener; import org.deeplearning4j.ui.api.UIServer; import org.deeplearning4j.ui.model.stats.StatsListener; import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage; import org.nd4j.evaluation.classification.Evaluation; import org.nd4j.linalg.activations.Activation; import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; import org.nd4j.linalg.learning.config.Nesterovs; import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction; public class MnistTest {     public static void main(String[] args) throws Exception {         //MNISTデータを準備         DataSetIterator mnistTrain = new MnistDataSetIterator(32, true, 123);         DataSetIterator mnistTest = new MnistDataSetIterator(32, false, 123);         //ネットワークの定義         MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()                 .updater(new Nesterovs(0.05, 0.95))                 .list()                 .layer(new DenseLayer.Builder()                         .nIn(28 * 28)                         .nOut(500)                         .activation(Activation.RELU)                         .weightInit(WeightInit.NORMAL)                         .build())                 .layer(new DenseLayer.Builder()                         .nIn(500)                         .nOut(500)                         .activation(Activation.RELU)                         .weightInit(WeightInit.NORMAL)                         .build())                 .layer(new OutputLayer.Builder(LossFunction.SQUARED_LOSS)                         .nIn(500)                         .nOut(10)                         .activation(Activation.SOFTMAX)                         .weightInit(WeightInit.NORMAL)                         .build())                 .build();         //ネットワーク作成         MultiLayerNetwork network = new MultiLayerNetwork(conf);         network.init();         //UIサーバ起動         UIServer uiServer = UIServer.getInstance();         StatsStorage statsStorage = new InMemoryStatsStorage();         uiServer.attach(statsStorage);         network.setListeners(Arrays.asList(new ScoreIterationListener(1), new StatsListener(statsStorage)));         //学習         network.fit(mnistTrain, 15);         //学習データを保存         network.save(new File("C:\\work\\train.dat"));         //学習後にテストを実行。テスト結果を出力         Evaluation eval = network.evaluate(mnistTest);         System.out.println(eval.stats());     } } | 
実行結果
MNISTの手書きデータを学習。
 学習後のテスト結果が表示されます。
 この例では、98%の正答率でした。
| 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 | SLF4J: Failed to load class "org.slf4j.impl.StaticLoggerBinder". SLF4J: Defaulting to no-operation (NOP) logger implementation SLF4J: See http://www.slf4j.org/codes.html#StaticLoggerBinder for further details. ========================Evaluation Metrics========================  # of classes:    10  Accuracy:        0.9828  Precision:       0.9829  Recall:          0.9826  F1 Score:        0.9827 Precision, recall & F1: macro-averaged (equally weighted avg. of 10 classes) =========================Confusion Matrix=========================     0    1    2    3    4    5    6    7    8    9 ---------------------------------------------------   970    0    2    1    0    1    1    1    3    1 | 0 = 0     0 1128    2    1    0    0    2    0    2    0 | 1 = 1     0    3 1020    1    1    0    1    3    3    0 | 2 = 2     1    1    3  996    0    0    0    2    4    3 | 3 = 3     0    0    2    1  961    0    3    2    0   13 | 4 = 4     2    0    0    9    1  870    4    1    3    2 | 5 = 5     4    2    1    1    1    4  944    0    1    0 | 6 = 6     1    6    7    3    0    0    0 1004    2    5 | 7 = 7     3    0    2    5    1    0    4    2  955    2 | 8 = 8     4    2    0    6    8    2    1    2    4  980 | 9 = 9 Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times ================================================================== | 
 先頭でSLF4Jが見つからない旨、エラー出力されます。
 気になる方は、SLF4Jも入れてください。
サンプルの解説
先頭のMnistDataSetIteratorクラスのコンストラクタの部分が、MNISTデータの準備部分です。
 第1引数は、バッチサイズ。1であれば、バッチ勾配降下法。2以上であれば、ミニバッチ勾配降下法です。
 第2引数は、trueの場合、トレーニング(学習)用データを準備。falseの場合は、テストデータを準備。
2つめの部分が、ニューラルネットワークの定義です。
 入力層、中間層2つ、出力層の構成です。
 ここでは、各層は全結合、活性化関数はReLUにしています。
ここに記載のハイパーパラメータは適当なので、いろいろ変えてみてください。
 固定なのは、入力層の28×28(=784)と出力層の10だけです。
 入力は28×28サイズの画像。出力は0~9のどれかを判定したいから。ってことです。
この後は、この定義をもとに学習とテストの流れですね。
 学習は時間がかかって大変なので、必ず保存しておきましょう。
ちなみに、UIのjarを入手した状態で、UIサーバ起動のくだりのコードが書いてあれば、
 このURLで学習状況が確認できます。
 http://localhost:9000/







