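In the past few years Deep Learning has become very popular, and many beginners who only half understand it, myself included, have started using it to build applications. As a beginner I couldn't write everything myself, so I turned to an open-source library. There are many open-source Deep Learning libraries. Grouped by language, there are Python libraries such as TensorFlow, Theano and Keras, C/C++ libraries such as Caffe, Lua libraries such as Torch, and so on. Our company, however, works mainly in Java, and most projects end up as web services, so I chose Deeplearning4j.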
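Deeplearning4j is developed by Skymind, an overseas startup. The latest version is currently 0.7.2. All of its source code is public and hosted on GitHub (https://github.com/deeplearning4j/deeplearning4j). As the name suggests, it is a Deep Learning library written specifically for Java programmers. What makes it attractive, though, is not just Java support but, more importantly, Spark support. Training Deep Learning models needs a lot of memory, and the raw data sometimes needs a lot of disk space as well, so being able to do the processing on a cluster is ideal. Besides Deeplearning4j, a few other Deep Learning libraries also support Spark, such as yahoo/CaffeOnSpark, AMPLab/SparkNet, and BigDL, which Intel open-sourced recently. I have barely used them, so I won't say more about them; this post focuses on how to use Deeplearning4j.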
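When you start using someone else's library, you usually run a demo first, a Hello World example, just as the first program you write in a new language prints Hello World. Deep Learning usually has two Hello World examples: classifying the MNIST dataset, and finding similar words with Word2Vec. Since Word2Vec is not a deep neural network in the strict sense, here I use a LeNet network on MNIST as the Deep Learning Hello World. MNIST is an open dataset of 28x28 grayscale images of handwritten digits (http://yann.lecun.com/exdb/mnist/), containing 60,000 training images and 10,000 test images. For a description of the LeNet architecture, see http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf. Below I walk through how to build, train and evaluate the model with Deeplearning4j.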
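`MnistDataSetIterator`, used in the training code, loads MNIST for you, so none of this is needed to run the example. To show what the raw data looks like, here is a minimal plain-Java sketch of a reader for the IDX format in which the MNIST image files on the page above are distributed: a big-endian header (magic number 2051, image count, row count, column count) followed by one unsigned byte per pixel. The class name `IdxImageReader` is hypothetical, and scaling pixels to [0, 1] is an assumption of this sketch, not necessarily what DL4J's iterator does.

```java
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;

/**
 * Hypothetical minimal reader for MNIST's IDX image files
 * (e.g. t10k-images-idx3-ubyte, after gunzipping).
 * Header: four big-endian 32-bit ints (magic, count, rows, cols).
 */
final class IdxImageReader {
    static final int IMAGE_MAGIC = 0x00000803; // 2051: unsigned bytes, 3 dimensions

    /** Reads all images, flattening each to rows*cols values scaled to [0, 1] (assumed scaling). */
    static float[][] readImages(InputStream in) throws IOException {
        DataInputStream data = new DataInputStream(in); // readInt() is big-endian, as IDX requires
        int magic = data.readInt();
        if (magic != IMAGE_MAGIC) {
            throw new IOException("not an IDX image file, magic = " + magic);
        }
        int count = data.readInt();
        int pixels = data.readInt() * data.readInt(); // rows * cols, 28 * 28 = 784 for MNIST
        float[][] images = new float[count][pixels];
        for (int i = 0; i < count; i++) {
            for (int p = 0; p < pixels; p++) {
                images[i][p] = data.readUnsignedByte() / 255.0f; // 0 = background, 255 = ink
            }
        }
        return images;
    }

    public static void main(String[] args) throws IOException {
        if (args.length > 0) {
            // Path to an unzipped MNIST image file; buffering avoids one syscall per byte
            try (InputStream in = new BufferedInputStream(new FileInputStream(args[0]))) {
                float[][] images = readImages(in);
                System.out.println(images.length + " images of " + images[0].length + " pixels");
            }
        } else {
            // Synthetic 1-image, 2x2 file, only to exercise the parser
            byte[] tiny = {0, 0, 8, 3, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 2, 0, (byte) 255, 0, (byte) 128};
            float[][] images = readImages(new ByteArrayInputStream(tiny));
            System.out.println(images.length + " image(s) of " + images[0].length + " pixels");
        }
    }
}
```

Flattening each image to a 784-element row is also how the images are fed to the network as a "flat" input before being reshaped back to 28x28x1.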
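First, create a Maven project and add the Deeplearning4j dependencies to the pom file. The three most important ones are deeplearning4j-core, datavec and nd4j. deeplearning4j-core implements the neural network structures; nd4j is the tensor computation library, which uses JavaCPP to call precompiled C++ libraries (one of ATLAS, MKL or OpenBLAS); datavec mainly handles data ETL. The pom configuration is: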
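```xml
<properties>
    <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    <nd4j.version>0.7.1</nd4j.version>
    <dl4j.version>0.7.1</dl4j.version>
    <datavec.version>0.7.1</datavec.version>
    <scala.binary.version>2.10</scala.binary.version>
</properties>

<dependencies>
    <!-- ND4J CPU backend (native BLAS via JavaCPP) -->
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native</artifactId>
        <version>${nd4j.version}</version>
    </dependency>
    <!-- Note: the _2.11 suffix here differs from scala.binary.version (2.10);
         Scala-suffixed artifacts should normally all use the same Scala version -->
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>dl4j-spark_2.11</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
    <dependency>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark_${scala.binary.version}</artifactId>
        <version>${datavec.version}</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
</dependencies>
```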
- Some of these dependencies are Spark-related and are only needed for running on Spark. That's fine; just include them for now.
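To make the "tensor operations" that nd4j hands to ATLAS, MKL or OpenBLAS more concrete: the forward pass of a fully connected layer over a minibatch is essentially a matrix product (the BLAS `gemm` routine), followed by adding a bias and applying the activation. The sketch below is a naive plain-Java version for illustration only; it is not how nd4j implements it, and the class name `NaiveGemm` is hypothetical.

```java
import java.util.Arrays;

/** Hypothetical naive matrix multiply, C = A x B, with row-major float arrays. */
final class NaiveGemm {
    /** A is m x k, B is k x n, the result is m x n. */
    static float[] multiply(float[] a, float[] b, int m, int k, int n) {
        if (a.length != m * k || b.length != k * n) {
            throw new IllegalArgumentException("shape mismatch");
        }
        float[] c = new float[m * n];
        for (int i = 0; i < m; i++) {
            for (int p = 0; p < k; p++) {
                float aip = a[i * k + p];
                for (int j = 0; j < n; j++) {
                    c[i * n + j] += aip * b[p * n + j]; // i-p-j order walks B and C contiguously
                }
            }
        }
        return c;
    }

    public static void main(String[] args) {
        // A minibatch of 2 examples with 3 features, times a 3x2 weight matrix
        float[] x = {1, 2, 3,
                     4, 5, 6};
        float[] w = {1, 0,
                     0, 1,
                     1, 1};
        System.out.println(Arrays.toString(multiply(x, w, 2, 3, 2))); // prints [4.0, 5.0, 10.0, 11.0]
    }
}
```

Optimized BLAS libraries compute the same result but block the loops for cache and use SIMD and multiple threads, which is why nd4j delegates to them instead of looping in Java.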
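```java
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;

// Body of main(...) in class cv.LenetMnistExample; log is an org.slf4j.Logger for that class
int nChannels = 1;    // grayscale images; 3 for color images
int outputNum = 10;   // number of classes (digits 0-9)
int batchSize = 64;   // minibatch size for SGD
int nEpochs = 10;     // number of passes over the training set
int iterations = 1;   // number of parameter updates per minibatch
int seed = 123;       // random seed for weight initialization

log.info("Load data....");
// Arguments: batch size, training (true) or test (false) split, random seed for shuffling
DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, 12345);
```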
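With these settings, one epoch is one pass over the 60,000 training images in minibatches of 64. The short calculation below (plain Java; the class name `TrainingSchedule` is hypothetical) shows how many minibatches and parameter updates that implies, assuming the iterator emits a final, smaller batch rather than dropping the remainder, and that `iterations = 1` means one update per minibatch. With `ScoreIterationListener(1)`, that means roughly one score line per minibatch in the log.

```java
/** Hypothetical back-of-the-envelope count of minibatches and updates. */
final class TrainingSchedule {
    /** Ceiling division: a final partial batch still counts as a batch. */
    static int batchesPerEpoch(int numExamples, int batchSize) {
        return (numExamples + batchSize - 1) / batchSize;
    }

    public static void main(String[] args) {
        int numTrain = 60_000; // MNIST training set size
        int batchSize = 64;
        int nEpochs = 10;
        int iterations = 1;    // updates per minibatch, as configured above
        int perEpoch = batchesPerEpoch(numTrain, batchSize);
        int lastBatch = numTrain - (perEpoch - 1) * batchSize;
        // prints "938 minibatches per epoch, last one has 32 examples"
        System.out.println(perEpoch + " minibatches per epoch, last one has " + lastBatch + " examples");
        // prints "9380 parameter updates in total"
        System.out.println((long) perEpoch * iterations * nEpochs + " parameter updates in total");
    }
}
```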
```java
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.Updater;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
        .seed(seed)
        .iterations(iterations)
        .regularization(true).l2(0.0005)
        .learningRate(0.01)
        //.biasLearningRate(0.02)
        //.learningRateDecayPolicy(LearningRatePolicy.Inverse).lrPolicyDecayRate(0.001).lrPolicyPower(0.75)
        .weightInit(WeightInit.XAVIER)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .updater(Updater.NESTEROVS).momentum(0.9)
        .list()
        .layer(0, new ConvolutionLayer.Builder(5, 5)
                // nIn and nOut specify depth: nIn is nChannels, nOut is the number of filters
                .nIn(nChannels)
                .stride(1, 1)
                .nOut(20)
                .activation("identity")
                .build())
        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
        .layer(2, new ConvolutionLayer.Builder(5, 5)
                // nIn need not be specified in later layers
                .stride(1, 1)
                .nOut(50)
                .activation("identity")
                .build())
        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2)
                .stride(2, 2)
                .build())
        .layer(4, new DenseLayer.Builder().activation("relu")
                .nOut(500).build())
        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                .nOut(outputNum)
                .activation("softmax")
                .build())
        .backprop(true).pretrain(false)
        // The builder needs the image dimensions and the number of channels: 28x28 images, 1 channel
        .cnnInputSize(28, 28, 1);
//new ConvolutionLayerSetup(builder, 28, 28, 1);

MultiLayerConfiguration conf = builder.build();
MultiLayerNetwork model = new MultiLayerNetwork(conf);
model.init();
// A listener that prints the loss function score after each iteration
model.setListeners(new ScoreIterationListener(1));
```
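It helps to check the layer sizes this configuration produces. Assuming no padding (none is configured), a k x k window with stride s turns an n x n input into (n - k) / s + 1 outputs per side. The plain-Java sketch below (the class name `LenetShapes` is hypothetical) walks the six layers from a 28x28x1 input and counts weights plus biases; the second pooling layer leaves 4x4x50 = 800 values feeding the 500-unit dense layer.

```java
/** Hypothetical shape and parameter count for the LeNet configuration above. */
final class LenetShapes {
    /** Output size per side of a convolution or pooling window without padding. */
    static int outSize(int in, int kernel, int stride) {
        return (in - kernel) / stride + 1;
    }

    /** Weights (k * k * depthIn * depthOut) plus one bias per filter. */
    static long convParams(int k, int depthIn, int depthOut) {
        return (long) k * k * depthIn * depthOut + depthOut;
    }

    /** Weights (nIn * nOut) plus one bias per unit. */
    static long denseParams(int nIn, int nOut) {
        return (long) nIn * nOut + nOut;
    }

    /** Prints each layer's output shape and returns the total parameter count. */
    static long describe() {
        int size = 28, depth = 1;          // 28x28 grayscale input (nChannels = 1)
        long total = 0;

        total += convParams(5, depth, 20); // layer 0: 5x5 conv, stride 1, 20 filters
        size = outSize(size, 5, 1); depth = 20;
        System.out.printf("conv1: %dx%dx%d%n", size, size, depth);

        size = outSize(size, 2, 2);        // layer 1: 2x2 max pooling, stride 2, no parameters
        System.out.printf("pool1: %dx%dx%d%n", size, size, depth);

        total += convParams(5, depth, 50); // layer 2: 5x5 conv, stride 1, 50 filters
        size = outSize(size, 5, 1); depth = 50;
        System.out.printf("conv2: %dx%dx%d%n", size, size, depth);

        size = outSize(size, 2, 2);        // layer 3: 2x2 max pooling, stride 2
        System.out.printf("pool2: %dx%dx%d%n", size, size, depth);

        int flat = size * size * depth;    // flattened input to the dense layer
        total += denseParams(flat, 500);   // layer 4: dense, 500 relu units
        total += denseParams(500, 10);     // layer 5: softmax output, 10 classes
        // prints "dense input: 800, total parameters: 431080"
        System.out.printf("dense input: %d, total parameters: %d%n", flat, total);
        return total;
    }

    public static void main(String[] args) {
        describe();
    }
}
```

Almost all of the roughly 431k parameters sit in the 800 x 500 dense layer; the two convolution layers together have only about 25.6k.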
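```java
import org.deeplearning4j.eval.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

for (int i = 0; i < nEpochs; ++i) {
    model.fit(mnistTrain);
    log.info("*** Completed epoch " + i + "***");

    log.info("Evaluate model....");
    Evaluation eval = new Evaluation(outputNum);
    while (mnistTest.hasNext()) {
        DataSet ds = mnistTest.next();
        // false: run the network in inference (not training) mode
        INDArray output = model.output(ds.getFeatureMatrix(), false);
        eval.eval(ds.getLabels(), output);
    }
    log.info(eval.stats());
    mnistTest.reset(); // rewind the test iterator for the next epoch
}
```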
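Conceptually, `eval.eval(labels, output)` compares each softmax output row with its one-hot label row: the predicted class is the index of the largest probability, and each (actual, predicted) pair is counted in a confusion matrix, which is what the "Examples labeled as X classified by model as Y" lines report. The plain-Java sketch below mimics that bookkeeping with arrays instead of `INDArray`; `SimpleEvaluation` is a hypothetical name, and this is not DL4J's actual `Evaluation` implementation.

```java
/** Hypothetical array-based version of confusion-matrix evaluation. */
final class SimpleEvaluation {
    final int[][] confusion; // confusion[actual][predicted]

    SimpleEvaluation(int numClasses) {
        confusion = new int[numClasses][numClasses];
    }

    /** Index of the largest value; ties go to the lowest index. */
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) best = i;
        }
        return best;
    }

    /** Accumulates one minibatch: labels are one-hot rows, output rows are class probabilities. */
    void eval(double[][] labels, double[][] output) {
        for (int r = 0; r < labels.length; r++) {
            confusion[argMax(labels[r])][argMax(output[r])]++;
        }
    }

    /** Fraction of examples on the diagonal of the confusion matrix. */
    double accuracy() {
        long correct = 0, total = 0;
        for (int a = 0; a < confusion.length; a++) {
            for (int p = 0; p < confusion.length; p++) {
                total += confusion[a][p];
                if (a == p) correct += confusion[a][p];
            }
        }
        return total == 0 ? 0.0 : (double) correct / total;
    }

    public static void main(String[] args) {
        SimpleEvaluation eval = new SimpleEvaluation(3);
        double[][] labels = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}, {0, 0, 1}};
        double[][] output = {{0.7, 0.2, 0.1}, {0.1, 0.8, 0.1}, {0.3, 0.3, 0.4}, {0.6, 0.1, 0.3}};
        eval.eval(labels, output); // the last example (actual 2) is predicted as 0
        System.out.println("accuracy = " + eval.accuracy()); // prints "accuracy = 0.75"
    }
}
```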
The evaluation printed after the final epoch:

```
Examples labeled as 0 classified by model as 0: 974 times
Examples labeled as 0 classified by model as 6: 2 times
Examples labeled as 0 classified by model as 7: 2 times
Examples labeled as 0 classified by model as 8: 1 times
Examples labeled as 0 classified by model as 9: 1 times
Examples labeled as 1 classified by model as 0: 1 times
Examples labeled as 1 classified by model as 1: 1128 times
Examples labeled as 1 classified by model as 2: 1 times
Examples labeled as 1 classified by model as 3: 2 times
Examples labeled as 1 classified by model as 5: 1 times
Examples labeled as 1 classified by model as 6: 2 times
Examples labeled as 2 classified by model as 2: 1026 times
Examples labeled as 2 classified by model as 4: 1 times
Examples labeled as 2 classified by model as 6: 1 times
Examples labeled as 2 classified by model as 7: 3 times
Examples labeled as 2 classified by model as 8: 1 times
Examples labeled as 3 classified by model as 0: 1 times
Examples labeled as 3 classified by model as 1: 1 times
Examples labeled as 3 classified by model as 2: 1 times
Examples labeled as 3 classified by model as 3: 998 times
Examples labeled as 3 classified by model as 5: 3 times
Examples labeled as 3 classified by model as 7: 1 times
Examples labeled as 3 classified by model as 8: 4 times
Examples labeled as 3 classified by model as 9: 1 times
Examples labeled as 4 classified by model as 2: 1 times
Examples labeled as 4 classified by model as 4: 973 times
Examples labeled as 4 classified by model as 6: 2 times
Examples labeled as 4 classified by model as 7: 1 times
Examples labeled as 4 classified by model as 9: 5 times
Examples labeled as 5 classified by model as 0: 2 times
Examples labeled as 5 classified by model as 3: 4 times
Examples labeled as 5 classified by model as 5: 882 times
Examples labeled as 5 classified by model as 6: 1 times
Examples labeled as 5 classified by model as 7: 1 times
Examples labeled as 5 classified by model as 8: 2 times
Examples labeled as 6 classified by model as 0: 4 times
Examples labeled as 6 classified by model as 1: 2 times
Examples labeled as 6 classified by model as 4: 1 times
Examples labeled as 6 classified by model as 5: 4 times
Examples labeled as 6 classified by model as 6: 945 times
Examples labeled as 6 classified by model as 8: 2 times
Examples labeled as 7 classified by model as 1: 5 times
Examples labeled as 7 classified by model as 2: 3 times
Examples labeled as 7 classified by model as 3: 1 times
Examples labeled as 7 classified by model as 7: 1016 times
Examples labeled as 7 classified by model as 8: 1 times
Examples labeled as 7 classified by model as 9: 2 times
Examples labeled as 8 classified by model as 0: 1 times
Examples labeled as 8 classified by model as 3: 1 times
Examples labeled as 8 classified by model as 5: 2 times
Examples labeled as 8 classified by model as 7: 2 times
Examples labeled as 8 classified by model as 8: 966 times
Examples labeled as 8 classified by model as 9: 2 times
Examples labeled as 9 classified by model as 3: 1 times
Examples labeled as 9 classified by model as 4: 2 times
Examples labeled as 9 classified by model as 5: 4 times
Examples labeled as 9 classified by model as 6: 1 times
Examples labeled as 9 classified by model as 7: 5 times
Examples labeled as 9 classified by model as 8: 3 times
Examples labeled as 9 classified by model as 9: 993 times

==========================Scores========================================
 Accuracy: 0.9901
 Precision: 0.99
 Recall: 0.99
 F1 Score: 0.99
========================================================================
[main] INFO cv.LenetMnistExample - ****************Example finished********************
```
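The summary scores can be cross-checked against the per-class counts: the diagonal (correct predictions) adds up to 9,901 of the 10,000 test images, i.e. an accuracy of 0.9901. The plain-Java sketch below copies the counts from the log into a confusion matrix and recomputes the scores. The class name `MnistScores` is hypothetical; computing precision and recall as macro averages over the ten digits, and F1 as the harmonic mean of those two averages, is an assumption about how the summary lines are formed, though it does reproduce the rounded 0.99 values.

```java
/** Hypothetical recomputation of the summary scores from the logged confusion counts. */
final class MnistScores {
    // confusion[actual][predicted], copied from the "Examples labeled as ..." lines
    static final int[][] CONFUSION = {
        {974,    0,    0,   0,   0,   0,   2,    2,   1,   1},
        {  1, 1128,    1,   2,   0,   1,   2,    0,   0,   0},
        {  0,    0, 1026,   0,   1,   0,   1,    3,   1,   0},
        {  1,    1,    1, 998,   0,   3,   0,    1,   4,   1},
        {  0,    0,    1,   0, 973,   0,   2,    1,   0,   5},
        {  2,    0,    0,   4,   0, 882,   1,    1,   2,   0},
        {  4,    2,    0,   0,   1,   4, 945,    0,   2,   0},
        {  0,    5,    3,   1,   0,   0,   0, 1016,   1,   2},
        {  1,    0,    0,   1,   0,   2,   0,    2, 966,   2},
        {  0,    0,    0,   1,   2,   4,   1,    5,   3, 993},
    };

    static double accuracy(int[][] m) {
        long correct = 0, total = 0;
        for (int a = 0; a < m.length; a++) {
            for (int p = 0; p < m.length; p++) {
                total += m[a][p];
                if (a == p) correct += m[a][p];
            }
        }
        return (double) correct / total;
    }

    /** Mean over classes of TP / (TP + FN): diagonal divided by row sum. */
    static double macroRecall(int[][] m) {
        double sum = 0;
        int classes = 0;
        for (int c = 0; c < m.length; c++) {
            long row = 0;
            for (int p = 0; p < m.length; p++) row += m[c][p];
            if (row > 0) { sum += (double) m[c][c] / row; classes++; } // skip classes with no examples
        }
        return sum / classes;
    }

    /** Mean over classes of TP / (TP + FP): diagonal divided by column sum. */
    static double macroPrecision(int[][] m) {
        double sum = 0;
        int classes = 0;
        for (int c = 0; c < m.length; c++) {
            long col = 0;
            for (int a = 0; a < m.length; a++) col += m[a][c];
            if (col > 0) { sum += (double) m[c][c] / col; classes++; } // skip never-predicted classes
        }
        return sum / classes;
    }

    static double f1(double precision, double recall) {
        return 2 * precision * recall / (precision + recall);
    }

    public static void main(String[] args) {
        double p = macroPrecision(CONFUSION);
        double r = macroRecall(CONFUSION);
        System.out.printf("Accuracy:  %.4f%n", accuracy(CONFUSION));
        System.out.printf("Precision: %.4f%n", p);
        System.out.printf("Recall:    %.4f%n", r);
        System.out.printf("F1 Score:  %.4f%n", f1(p, r));
    }
}
```

Per class, the weakest recall is for 9 (993 of 1,009) and the weakest precision is for 5 (882 of 896 predictions), which matches the largest off-diagonal counts in the log.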