Learning source: 日撸 Java 三百行 (Days 81-90, CNN convolutional neural networks), 闵帆's blog on CSDN
The code in this post comes from the CSDN article 日撸 Java 三百行 (Days 81-90, CNN convolutional neural networks).

I will use this code to build a deeper understanding of CNNs.
Convolutional Neural Networks (Code Part)

I. Reading and Storing the Dataset

1. Dataset description

Let us briefly describe the dataset that needs to be read.
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0
At first glance, isn't this just a collection of 0s and 1s? Now imagine these numbers as a picture; with a few tools they can be rendered as the image below.

The size of this image is 28 × 28. But doesn't the row contain one extra number? What is it supposed to mean? Look at the image carefully: it resembles the digit '0'. To check whether this idea is correct, let us look at another row of data.
0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,0,0,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3
Although the handwriting in the image is not standard, it can still be recognized as the digit '3', and the extra number at the end is exactly 3. We conclude that each row of the dataset represents one image, with '0' and '1' encoding its black-and-white pixels, and the last number of the row giving the digit shown in the image.

So reading this dataset means storing the pixels of each image in an array whose size equals the image size, and storing the digit the image represents in a separate value that serves as the image's label.
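The row format described above can be checked with a minimal sketch. The class and method names here are made up for illustration and are not part of the article's code: everything before the last comma is a pixel, and the final value is the label.

```java
import java.util.Arrays;

public class ParseDemo {
    // All values before the last one are pixels.
    static double[] parsePixels(String line) {
        String[] parts = line.split(",");
        double[] pixels = new double[parts.length - 1];
        for (int i = 0; i < parts.length - 1; i++) {
            pixels[i] = Double.parseDouble(parts[i]);
        }
        return pixels;
    }

    // The last value in the row is the class label.
    static int parseLabel(String line) {
        String[] parts = line.split(",");
        return Integer.parseInt(parts[parts.length - 1]);
    }

    public static void main(String[] args) {
        String line = "0,1,1,0,3";
        System.out.println(Arrays.toString(parsePixels(line)) + " -> label " + parseLabel(line));
    }
}
```

A real MNIST-style row would have 784 pixel values (28 × 28) before the label; the short row here just keeps the demo readable.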
2. Code

```java
package cnn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class Dataset {
    private List<Instance> instances;
    private int labelIndex;
    private double maxLabel = -1;

    public Dataset() {
        labelIndex = -1;
        instances = new ArrayList<>();
    }// Of the first constructor

    public Dataset(String paraFilename, String paraSplitSign, int paraLabelIndex) {
        instances = new ArrayList<>();
        labelIndex = paraLabelIndex;
        File tempFile = new File(paraFilename);
        try {
            BufferedReader tempReader = new BufferedReader(new FileReader(tempFile));
            String tempLine;
            while ((tempLine = tempReader.readLine()) != null) {
                String[] tempDatum = tempLine.split(paraSplitSign);
                if (tempDatum.length == 0) {
                    continue;
                } // Of if

                double[] tempData = new double[tempDatum.length];
                for (int i = 0; i < tempDatum.length; i++)
                    tempData[i] = Double.parseDouble(tempDatum[i]);
                Instance tempInstance = new Instance(tempData);
                append(tempInstance);
            } // Of while
            tempReader.close();
        } catch (IOException e) {
            e.printStackTrace();
            System.out.println("Unable to load " + paraFilename);
            System.exit(0);
        } // Of try
    }// Of the second constructor

    public void append(Instance paraInstance) {
        instances.add(paraInstance);
    }// Of append

    public void append(double[] paraAttributes, Double paraLabel) {
        instances.add(new Instance(paraAttributes, paraLabel));
    }// Of append

    public Instance getInstance(int paraIndex) {
        return instances.get(paraIndex);
    }// Of getInstance

    public int size() {
        return instances.size();
    }// Of size

    public double[] getAttributes(int paraIndex) {
        return instances.get(paraIndex).getAttributes();
    }// Of getAttributes

    public Double getLabel(int paraIndex) {
        return instances.get(paraIndex).getLabel();
    }// Of getLabel

    public static void main(String[] args) {
        Dataset tempData = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);
        Instance tempInstance = tempData.getInstance(0);
        System.out.println("The first instance is: " + tempInstance);
        System.out.println("The first instance label is: " + tempInstance.label);

        tempInstance = tempData.getInstance(1);
        System.out.println("The second instance is: " + tempInstance);
        System.out.println("The second instance label is: " + tempInstance.label);
    }// Of main

    public class Instance {
        private double[] attributes;
        private Double label;

        private Instance(double[] paraAttrs, Double paraLabel) {
            attributes = paraAttrs;
            label = paraLabel;
        }// Of the first constructor

        public Instance(double[] paraData) {
            if (labelIndex == -1) { // No label
                attributes = paraData;
            } else {
                label = paraData[labelIndex];
                if (label > maxLabel) { // It is a new label
                    maxLabel = label;
                } // Of if

                if (labelIndex == 0) { // The first column is the label
                    attributes = Arrays.copyOfRange(paraData, 1, paraData.length);
                } else { // The last column is the label
                    attributes = Arrays.copyOfRange(paraData, 0, paraData.length - 1);
                } // Of if
            } // Of if
        }// Of the second constructor

        public double[] getAttributes() {
            return attributes;
        }// Of getAttributes

        public Double getLabel() {
            if (labelIndex == -1)
                return null;
            return label;
        }// Of getLabel

        public String toString() {
            return Arrays.toString(attributes) + ", " + label;
        }// Of toString
    }// Of class Instance
}// Of class Dataset
```

II. Basic Operations on the Kernel Size

1. Operations
These operations process the kernel size, that is, the width and height of the convolution kernel.

One method divides the width and height by two integers, and throws an exception if either division is not exact. For example:
(4, 12) / (2, 3) -> (2, 4)
(2, 2) / (4, 6) -> Error
The other method subtracts two integers from the width and height, then adds 1. For example:
(4, 6) - (2, 2) + 1 -> (3, 5)

2. Code
```java
package cnn;

public class Size {
    public final int width;
    public final int height;

    public Size(int paraWidth, int paraHeight) {
        width = paraWidth;
        height = paraHeight;
    }// Of the first constructor

    public Size divide(Size paraScaleSize) {
        int resultWidth = width / paraScaleSize.width;
        int resultHeight = height / paraScaleSize.height;
        if (resultWidth * paraScaleSize.width != width
                || resultHeight * paraScaleSize.height != height) {
            throw new RuntimeException("Unable to divide " + this + " with " + paraScaleSize);
        } // Of if
        return new Size(resultWidth, resultHeight);
    }// Of divide

    public Size subtract(Size paraScaleSize, int paraAppend) {
        int resultWidth = width - paraScaleSize.width + paraAppend;
        int resultHeight = height - paraScaleSize.height + paraAppend;
        return new Size(resultWidth, resultHeight);
    }// Of subtract

    public String toString() {
        String resultString = "(" + width + ", " + height + ")";
        return resultString;
    }// Of toString

    public static void main(String[] args) {
        Size tempSize1 = new Size(4, 6);
        Size tempSize2 = new Size(2, 2);
        System.out.println("" + tempSize1 + " divide " + tempSize2 + " = "
                + tempSize1.divide(tempSize2));

        try {
            System.out.println("" + tempSize2 + " divide " + tempSize1 + " = "
                    + tempSize2.divide(tempSize1));
        } catch (Exception ee) {
            System.out.println("Error is: " + ee);
        } // Of try

        System.out.println("" + tempSize1 + " - " + tempSize2 + " + 1 = "
                + tempSize1.subtract(tempSize2, 1));
    }// Of main
}// Of class Size
```

III. Math Utilities

1. Utility functions
This class defines operators whose purpose is to apply an operation to every element during matrix computations. Some operate on a single matrix, for example subtracting each value from 1, or applying the Sigmoid function to each value. Others operate on two matrices, for example element-wise addition and subtraction.

Rotating a matrix by 180 degrees is just rotating it by 90 degrees twice. The formula for a 90-degree rotation is
matrix_new[col][n - row - 1] = matrix[row][col]

where the matrix is n × n and matrix_new is the result of the rotation.
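As a quick sanity check of the 90-degree rotation formula, here is a hypothetical rot90 that applies it directly, with the 180-degree rotation implemented as two 90-degree rotations (square matrices only, to keep the sketch short; these helper names are made up for illustration):

```java
public class RotateDemo {
    // Rotate a square matrix 90 degrees clockwise:
    // result[col][n - row - 1] = matrix[row][col]
    static double[][] rot90(double[][] matrix) {
        int n = matrix.length;
        double[][] result = new double[n][n];
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                result[col][n - row - 1] = matrix[row][col];
            }
        }
        return result;
    }

    // Rotating 180 degrees = rotating 90 degrees twice.
    static double[][] rot180(double[][] matrix) {
        return rot90(rot90(matrix));
    }

    public static void main(String[] args) {
        double[][] m = { { 1, 2 }, { 3, 4 } };
        double[][] r = rot180(m);
        // A 180-degree rotation reverses both the rows and the columns.
        System.out.println(r[0][0] + " " + r[0][1]); // 4.0 3.0
    }
}
```

Note that the article's MathUtils.rot180 gets the same result a different way: it flips the matrix horizontally and then vertically, which also works for non-square matrices.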
convnValid is the (valid) convolution operation; convnFull is its reverse operation.

scaleMatrix is mean pooling; kronecker is the reverse of pooling.
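A minimal sketch of these two directions for a fixed 2 × 2 window (the article's scaleMatrix and kronecker take an arbitrary Size; the helper names here are made up for illustration):

```java
public class PoolDemo {
    // Mean pooling with a 2x2 window: each output cell is the average of a 2x2 block.
    static double[][] meanPool2x2(double[][] m) {
        int sm = m.length / 2, sn = m[0].length / 2;
        double[][] out = new double[sm][sn];
        for (int i = 0; i < sm; i++) {
            for (int j = 0; j < sn; j++) {
                out[i][j] = (m[2 * i][2 * j] + m[2 * i][2 * j + 1]
                        + m[2 * i + 1][2 * j] + m[2 * i + 1][2 * j + 1]) / 4.0;
            }
        }
        return out;
    }

    // Kronecker expansion: copy each element into a 2x2 block (the reverse direction).
    static double[][] kron2x2(double[][] m) {
        double[][] out = new double[m.length * 2][m[0].length * 2];
        for (int i = 0; i < out.length; i++) {
            for (int j = 0; j < out[0].length; j++) {
                out[i][j] = m[i / 2][j / 2];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] m = { { 1, 2, 3, 4 }, { 5, 6, 7, 8 }, { 1, 1, 1, 1 }, { 1, 1, 1, 1 } };
        double[][] pooled = meanPool2x2(m);
        System.out.println(pooled[0][0] + " " + pooled[0][1]); // 3.5 5.5
    }
}
```

Pooling then expanding does not recover the original matrix, only its block averages, which is why the expansion is "the reverse of pooling" in shape rather than in value.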
2. Code

```java
package cnn;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

public class MathUtils {
    public interface Operator extends Serializable {
        double process(double value);
    }// Of interface Operator

    public static final Operator one_value = new Operator() {
        private static final long serialVersionUID = 3752139491940330714L;

        @Override
        public double process(double value) {
            return 1 - value;
        }// Of process
    };

    public static final Operator sigmoid = new Operator() {
        private static final long serialVersionUID = -1952718905019847589L;

        @Override
        public double process(double value) {
            return 1 / (1 + Math.pow(Math.E, -value));
        }// Of process
    };

    interface OperatorOnTwo extends Serializable {
        double process(double a, double b);
    }// Of interface OperatorOnTwo

    public static final OperatorOnTwo plus = new OperatorOnTwo() {
        private static final long serialVersionUID = -6298144029766839945L;

        @Override
        public double process(double a, double b) {
            return a + b;
        }// Of process
    };

    public static OperatorOnTwo multiply = new OperatorOnTwo() {
        private static final long serialVersionUID = -7053767821858820698L;

        @Override
        public double process(double a, double b) {
            return a * b;
        }// Of process
    };

    public static OperatorOnTwo minus = new OperatorOnTwo() {
        private static final long serialVersionUID = 7346065545555093912L;

        @Override
        public double process(double a, double b) {
            return a - b;
        }// Of process
    };

    public static void printMatrix(double[][] matrix) {
        for (double[] array : matrix) {
            String line = Arrays.toString(array);
            line = line.replaceAll(", ", "\t");
            System.out.println(line);
        } // Of for
        System.out.println();
    }// Of printMatrix

    public static double[][] cloneMatrix(final double[][] matrix) {
        final int m = matrix.length;
        int n = matrix[0].length;
        final double[][] outMatrix = new double[m][n];
        for (int i = 0; i < m; i++) {
            System.arraycopy(matrix[i], 0, outMatrix[i], 0, n);
        } // Of for i
        return outMatrix;
    }// Of cloneMatrix

    public static double[][] rot180(double[][] matrix) {
        matrix = cloneMatrix(matrix);
        int m = matrix.length;
        int n = matrix[0].length;
        // Flip horizontally.
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n / 2; j++) {
                double tmp = matrix[i][j];
                matrix[i][j] = matrix[i][n - 1 - j];
                matrix[i][n - 1 - j] = tmp;
            } // Of for j
        } // Of for i
        // Flip vertically.
        for (int j = 0; j < n; j++) {
            for (int i = 0; i < m / 2; i++) {
                double tmp = matrix[i][j];
                matrix[i][j] = matrix[m - 1 - i][j];
                matrix[m - 1 - i][j] = tmp;
            } // Of for i
        } // Of for j
        return matrix;
    }// Of rot180

    private static final Random myRandom = new Random(2);

    public static double[][] randomMatrix(int x, int y, boolean b) {
        double[][] matrix = new double[x][y];
        for (int i = 0; i < x; i++) {
            for (int j = 0; j < y; j++) {
                matrix[i][j] = (myRandom.nextDouble() - 0.05) / 10;
            } // Of for j
        } // Of for i
        return matrix;
    }// Of randomMatrix

    public static double[] randomArray(int len) {
        double[] data = new double[len];
        for (int i = 0; i < len; i++) {
            // data[i] = myRandom.nextDouble() / 10 - 0.05;
            data[i] = 0;
        } // Of for i
        return data;
    }// Of randomArray

    public static int[] randomPerm(int size, int batchSize) {
        Set<Integer> set = new HashSet<>();
        while (set.size() < batchSize) {
            set.add(myRandom.nextInt(size));
        } // Of while

        int[] randPerm = new int[batchSize];
        int i = 0;
        for (Integer value : set) {
            randPerm[i++] = value;
        } // Of for
        return randPerm;
    }// Of randomPerm

    public static double[][] matrixOp(final double[][] ma, Operator operator) {
        final int m = ma.length;
        int n = ma[0].length;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                ma[i][j] = operator.process(ma[i][j]);
            } // Of for j
        } // Of for i
        return ma;
    }// Of matrixOp

    public static double[][] matrixOp(final double[][] ma, final double[][] mb,
            final Operator operatorA, final Operator operatorB, OperatorOnTwo operator) {
        final int m = ma.length;
        int n = ma[0].length;
        if (m != mb.length || n != mb[0].length)
            throw new RuntimeException("ma.length:" + ma.length + " mb.length:" + mb.length);

        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                double a = ma[i][j];
                if (operatorA != null) {
                    a = operatorA.process(a);
                } // Of if
                double b = mb[i][j];
                if (operatorB != null) {
                    b = operatorB.process(b);
                } // Of if
                mb[i][j] = operator.process(a, b);
            } // Of for j
        } // Of for i
        return mb;
    }// Of matrixOp

    public static double[][] kronecker(final double[][] matrix, final Size scale) {
        final int m = matrix.length;
        int n = matrix[0].length;
        final double[][] outMatrix = new double[m * scale.width][n * scale.height];
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                for (int ki = i * scale.width; ki < (i + 1) * scale.width; ki++) {
                    for (int kj = j * scale.height; kj < (j + 1) * scale.height; kj++) {
                        outMatrix[ki][kj] = matrix[i][j];
                    } // Of for kj
                } // Of for ki
            } // Of for j
        } // Of for i
        return outMatrix;
    }// Of kronecker

    public static double[][] scaleMatrix(final double[][] matrix, final Size scale) {
        int m = matrix.length;
        int n = matrix[0].length;
        final int sm = m / scale.width;
        final int sn = n / scale.height;
        final double[][] outMatrix = new double[sm][sn];
        if (sm * scale.width != m || sn * scale.height != n)
            throw new RuntimeException("scale matrix");
        final int size = scale.width * scale.height;
        for (int i = 0; i < sm; i++) {
            for (int j = 0; j < sn; j++) {
                double sum = 0.0;
                for (int si = i * scale.width; si < (i + 1) * scale.width; si++) {
                    for (int sj = j * scale.height; sj < (j + 1) * scale.height; sj++) {
                        sum += matrix[si][sj];
                    } // Of for sj
                } // Of for si
                outMatrix[i][j] = sum / size;
            } // Of for j
        } // Of for i
        return outMatrix;
    }// Of scaleMatrix

    public static double[][] convnFull(double[][] matrix, final double[][] kernel) {
        int m = matrix.length;
        int n = matrix[0].length;
        final int km = kernel.length;
        final int kn = kernel[0].length;
        final double[][] extendMatrix = new double[m + 2 * (km - 1)][n + 2 * (kn - 1)];
        for (int i = 0; i < m; i++) {
            System.arraycopy(matrix[i], 0, extendMatrix[i + km - 1], kn - 1, n);
        } // Of for i
        return convnValid(extendMatrix, kernel);
    }// Of convnFull

    public static double[][] convnValid(final double[][] matrix, double[][] kernel) {
        int m = matrix.length;
        int n = matrix[0].length;
        final int km = kernel.length;
        final int kn = kernel[0].length;
        int kns = n - kn + 1;
        final int kms = m - km + 1;
        final double[][] outMatrix = new double[kms][kns];
        for (int i = 0; i < kms; i++) {
            for (int j = 0; j < kns; j++) {
                double sum = 0.0;
                for (int ki = 0; ki < km; ki++) {
                    for (int kj = 0; kj < kn; kj++)
                        sum += matrix[i + ki][j + kj] * kernel[ki][kj];
                } // Of for ki
                outMatrix[i][j] = sum;
            } // Of for j
        } // Of for i
        return outMatrix;
    }// Of convnValid

    public static double[][] convnValid(final double[][][][] matrix, int mapNoX,
            double[][][][] kernel, int mapNoY) {
        int m = matrix.length;
        int n = matrix[0][mapNoX].length;
        int h = matrix[0][mapNoX][0].length;
        int km = kernel.length;
        int kn = kernel[0][mapNoY].length;
        int kh = kernel[0][mapNoY][0].length;
        int kms = m - km + 1;
        int kns = n - kn + 1;
        int khs = h - kh + 1;
        if (matrix.length != kernel.length)
            throw new RuntimeException("length");
        final double[][][] outMatrix = new double[kms][kns][khs];
        for (int i = 0; i < kms; i++) {
            for (int j = 0; j < kns; j++)
                for (int k = 0; k < khs; k++) {
                    double sum = 0.0;
                    for (int ki = 0; ki < km; ki++) {
                        for (int kj = 0; kj < kn; kj++)
                            for (int kk = 0; kk < kh; kk++) {
                                sum += matrix[i + ki][mapNoX][j + kj][k + kk]
                                        * kernel[ki][mapNoY][kj][kk];
                            } // Of for kk
                    } // Of for ki
                    outMatrix[i][j][k] = sum;
                } // Of for k
        } // Of for i
        return outMatrix[0];
    }// Of convnValid

    public static double sigmoid(double x) {
        return 1 / (1 + Math.pow(Math.E, -x));
    }// Of sigmoid

    public static double sum(double[][] error) {
        int n = error[0].length;
        double sum = 0.0;
        for (double[] array : error) {
            for (int i = 0; i < n; i++) {
                sum += array[i];
            } // Of for i
        } // Of for
        return sum;
    }// Of sum

    public static double[][] sum(double[][][][] errors, int j) {
        int m = errors[0][j].length;
        int n = errors[0][j][0].length;
        double[][] result = new double[m][n];
        for (int mi = 0; mi < m; mi++) {
            for (int nj = 0; nj < n; nj++) {
                double sum = 0;
                for (double[][][] error : errors) {
                    sum += error[j][mi][nj];
                } // Of for
                result[mi][nj] = sum;
            } // Of for nj
        } // Of for mi
        return result;
    }// Of sum

    public static int getMaxIndex(double[] out) {
        double max = out[0];
        int index = 0;
        for (int i = 1; i < out.length; i++)
            if (out[i] > max) {
                max = out[i];
                index = i;
            } // Of if
        return index;
    }// Of getMaxIndex
}// Of class MathUtils
```
An enumeration class is defined to identify the type of each layer, such as the input layer and the convolution layers.
```java
package cnn;

public enum LayerTypeEnum {
    INPUT, CONVOLUTION, SAMPLING, OUTPUT;
}// Of enum LayerTypeEnum
```

IV. Network Structure and Parameters
Some utility functions are defined for a single layer, and the enumeration type LayerTypeEnum above is used to distinguish the different layers of the network, such as the input, convolution, and pooling (sampling) layers.
```java
package cnn;

public class CnnLayer {
    LayerTypeEnum type;
    int outMapNum;
    Size mapSize;
    Size kernelSize;
    Size scaleSize;
    int classNum = -1;
    private double[][][][] kernel;
    private double[] bias;
    private double[][][][] outMaps;
    private double[][][][] errors;
    private static int recordInBatch = 0;

    public CnnLayer(LayerTypeEnum paraType, int paraNum, Size paraSize) {
        type = paraType;
        switch (type) {
        case INPUT:
            outMapNum = 1;
            mapSize = paraSize; // No deep copy.
            break;
        case CONVOLUTION:
            outMapNum = paraNum;
            kernelSize = paraSize;
            break;
        case SAMPLING:
            scaleSize = paraSize;
            break;
        case OUTPUT:
            classNum = paraNum;
            mapSize = new Size(1, 1);
            outMapNum = classNum;
            break;
        default:
            System.out.println("Internal error occurred in AbstractLayer.java constructor.");
        }// Of switch
    }// Of the first constructor

    public void initKernel(int paraFrontMapNum) {
        kernel = new double[paraFrontMapNum][outMapNum][][];
        for (int i = 0; i < paraFrontMapNum; i++) {
            for (int j = 0; j < outMapNum; j++) {
                kernel[i][j] = MathUtils.randomMatrix(kernelSize.width, kernelSize.height, true);
            } // Of for j
        } // Of for i
    }// Of initKernel

    public void initOutputKernel(int paraFrontMapNum, Size paraSize) {
        kernelSize = paraSize;
        initKernel(paraFrontMapNum);
    }// Of initOutputKernel

    public void initBias() {
        bias = MathUtils.randomArray(outMapNum);
    }// Of initBias

    public void initErrors(int paraBatchSize) {
        errors = new double[paraBatchSize][outMapNum][mapSize.width][mapSize.height];
    }// Of initErrors

    public void initOutMaps(int paraBatchSize) {
        outMaps = new double[paraBatchSize][outMapNum][mapSize.width][mapSize.height];
    }// Of initOutMaps

    public static void prepareForNewBatch() {
        recordInBatch = 0;
    }// Of prepareForNewBatch

    public static void prepareForNewRecord() {
        recordInBatch++;
    }// Of prepareForNewRecord

    public void setMapValue(int paraMapNo, int paraX, int paraY, double paraValue) {
        outMaps[recordInBatch][paraMapNo][paraX][paraY] = paraValue;
    }// Of setMapValue

    public void setMapValue(int paraMapNo, double[][] paraOutMatrix) {
        outMaps[recordInBatch][paraMapNo] = paraOutMatrix;
    }// Of setMapValue

    public Size getMapSize() {
        return mapSize;
    }// Of getMapSize

    public void setMapSize(Size paraMapSize) {
        mapSize = paraMapSize;
    }// Of setMapSize

    public LayerTypeEnum getType() {
        return type;
    }// Of getType

    public int getOutMapNum() {
        return outMapNum;
    }// Of getOutMapNum

    public void setOutMapNum(int paraOutMapNum) {
        outMapNum = paraOutMapNum;
    }// Of setOutMapNum

    public Size getKernelSize() {
        return kernelSize;
    }// Of getKernelSize

    public Size getScaleSize() {
        return scaleSize;
    }// Of getScaleSize

    public double[][] getMap(int paraIndex) {
        return outMaps[recordInBatch][paraIndex];
    }// Of getMap

    public double[][] getKernel(int paraFrontMap, int paraOutMap) {
        return kernel[paraFrontMap][paraOutMap];
    }// Of getKernel

    public void setError(int paraMapNo, int paraMapX, int paraMapY, double paraValue) {
        errors[recordInBatch][paraMapNo][paraMapX][paraMapY] = paraValue;
    }// Of setError

    public void setError(int paraMapNo, double[][] paraMatrix) {
        errors[recordInBatch][paraMapNo] = paraMatrix;
    }// Of setError

    public double[][] getError(int paraMapNo) {
        return errors[recordInBatch][paraMapNo];
    }// Of getError

    public double[][][][] getErrors() {
        return errors;
    }// Of getErrors

    public void setKernel(int paraLastMapNo, int paraMapNo, double[][] paraKernel) {
        kernel[paraLastMapNo][paraMapNo] = paraKernel;
    }// Of setKernel

    public double getBias(int paraMapNo) {
        return bias[paraMapNo];
    }// Of getBias

    public void setBias(int paraMapNo, double paraValue) {
        bias[paraMapNo] = paraValue;
    }// Of setBias

    public double[][][][] getMaps() {
        return outMaps;
    }// Of getMaps

    public double[][] getError(int paraRecordId, int paraMapNo) {
        return errors[paraRecordId][paraMapNo];
    }// Of getError

    public double[][] getMap(int paraRecordId, int paraMapNo) {
        return outMaps[paraRecordId][paraMapNo];
    }// Of getMap

    public int getClassNum() {
        return classNum;
    }// Of getClassNum

    public double[][][][] getKernel() {
        return kernel;
    }// Of getKernel
}// Of class CnnLayer
```
Another layer of encapsulation is placed on top of the CnnLayer class so that the layers of the network can be created more conveniently.
```java
package cnn;

import java.util.ArrayList;
import java.util.List;

public class LayerBuilder {
    private List<CnnLayer> layers;

    public LayerBuilder() {
        layers = new ArrayList<>();
    }// Of the first constructor

    public LayerBuilder(CnnLayer paraLayer) {
        this();
        layers.add(paraLayer);
    }// Of the second constructor

    public void addLayer(CnnLayer paraLayer) {
        layers.add(paraLayer);
    }// Of addLayer

    public CnnLayer getLayer(int paraIndex) throws RuntimeException {
        if (paraIndex >= layers.size()) {
            throw new RuntimeException("CnnLayer " + paraIndex + " is out of range: "
                    + layers.size() + ".");
        } // Of if
        return layers.get(paraIndex);
    }// Of getLayer

    public CnnLayer getOutputLayer() {
        return layers.get(layers.size() - 1);
    }// Of getOutputLayer

    public int getNumLayers() {
        return layers.size();
    }// Of getNumLayers
}// Of class LayerBuilder
```

V. Building the Network

1. Forward propagation
The basics of forward propagation were covered earlier, so here is a brief recap.

An image passes through the convolution kernels to produce feature maps; the feature maps are then pooled by the chosen pooling layer; finally an activation function is applied, and the activated output becomes the input of the next convolution layer.

After repeated convolution, pooling, and activation, the data enters the fully connected layer. The fully connected layer also involves a convolution-like step: it turns the m × n feature maps into a 1 × n vector, which is then processed and normalized by the Softmax function. The index of the largest value in this vector indicates the most likely class.
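For illustration, here is a small softmax-plus-argmax sketch. Note that, strictly speaking, the code in this article squashes each output unit with a sigmoid and then takes the argmax, so treat this block as showing the normalization idea rather than the article's exact implementation; the class and method names are made up:

```java
public class SoftmaxDemo {
    // Softmax: exponentiate and normalize so the outputs sum to 1.
    // Subtracting the max first is a standard numerical-stability trick.
    static double[] softmax(double[] x) {
        double max = x[0];
        for (double v : x)
            max = Math.max(max, v);
        double sum = 0;
        double[] out = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            out[i] = Math.exp(x[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < x.length; i++)
            out[i] /= sum;
        return out;
    }

    // The index of the largest value is the predicted class.
    static int argmax(double[] x) {
        int index = 0;
        for (int i = 1; i < x.length; i++)
            if (x[i] > x[index])
                index = i;
        return index;
    }

    public static void main(String[] args) {
        double[] scores = { 1.0, 3.0, 0.5 };
        double[] p = softmax(scores);
        System.out.println("Predicted class: " + argmax(p));
    }
}
```

Since softmax is monotone, the argmax of the normalized vector equals the argmax of the raw scores, which is why sigmoid-plus-argmax in the article's code picks the same class.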
2. Backpropagation

Backpropagation is a familiar topic: since the initial kernels are random, the loss function must be used to find better kernels.

Backpropagation first updates the fully connected layer. This step is similar to backpropagation in an ANN, i.e., updating the weights.

Next comes the pooling layer, whose update is the simplest. Take max pooling as an example: suppose the pooled value is 6 and the error propagated back is +1; then the value before pooling is updated to 6 + 1 = 7.
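The code in this article actually uses mean pooling, where a common convention is to spread each pooled error evenly over its pooling window; a minimal sketch for a 2 × 2 window (the helper name is made up, and the article's own code combines the expanded error with the sigmoid derivative rather than dividing by the window size):

```java
public class PoolBackpropDemo {
    // Spread each pooled-layer error evenly over its 2x2 pooling window.
    static double[][] upsampleError(double[][] error) {
        int m = error.length, n = error[0].length;
        double[][] out = new double[m * 2][n * 2];
        for (int i = 0; i < m * 2; i++) {
            for (int j = 0; j < n * 2; j++) {
                out[i][j] = error[i / 2][j / 2] / 4.0; // each cell gets 1/4 of the error
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] err = { { 1.0 } };
        double[][] up = upsampleError(err);
        System.out.println(up[0][0]); // each of the 4 cells holds 0.25
    }
}
```

Dividing by the window size keeps the total error mass unchanged, mirroring how mean pooling averaged the forward values.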
The most troublesome part is the convolution layer; I still have not fully worked out the derivation of its formulas. Roughly speaking, a two-dimensional formula is derived for the 2-D case and then generalized to the multi-dimensional case of the network.

The Zhihu article 卷积神经网络(CNN)反向传播算法推导 gives a detailed derivation and explanation.
4. Code

```java
package cnn;

import java.util.Arrays;

import cnn.Dataset.Instance;
import cnn.MathUtils.Operator;

public class FullCnn {
    private static double ALPHA = 0.85;
    public static double LAMBDA = 0;
    private static LayerBuilder layerBuilder;
    private int batchSize;
    private Operator divideBatchSize;
    private Operator multiplyAlpha;
    private Operator multiplyLambda;

    public FullCnn(LayerBuilder paraLayerBuilder, int paraBatchSize) {
        layerBuilder = paraLayerBuilder;
        batchSize = paraBatchSize;
        setup();
        initOperators();
    }// Of the first constructor

    private void initOperators() {
        divideBatchSize = new Operator() {
            private static final long serialVersionUID = 7424011281732651055L;

            @Override
            public double process(double value) {
                return value / batchSize;
            }// Of process
        };

        multiplyAlpha = new Operator() {
            private static final long serialVersionUID = 5761368499808006552L;

            @Override
            public double process(double value) {
                return value * ALPHA;
            }// Of process
        };

        multiplyLambda = new Operator() {
            private static final long serialVersionUID = 4499087728362870577L;

            @Override
            public double process(double value) {
                return value * (1 - LAMBDA * ALPHA);
            }// Of process
        };
    }// Of initOperators

    public void setup() {
        CnnLayer tempInputLayer = layerBuilder.getLayer(0);
        tempInputLayer.initOutMaps(batchSize);

        for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
            CnnLayer tempLayer = layerBuilder.getLayer(i);
            CnnLayer tempFrontLayer = layerBuilder.getLayer(i - 1);
            int tempFrontMapNum = tempFrontLayer.getOutMapNum();
            switch (tempLayer.getType()) {
            case INPUT:
                // Should not be input. Maybe an error should be thrown out.
                break;
            case CONVOLUTION:
                tempLayer.setMapSize(
                        tempFrontLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
                tempLayer.initKernel(tempFrontMapNum);
                tempLayer.initBias();
                tempLayer.initErrors(batchSize);
                tempLayer.initOutMaps(batchSize);
                break;
            case SAMPLING:
                tempLayer.setOutMapNum(tempFrontMapNum);
                tempLayer.setMapSize(tempFrontLayer.getMapSize().divide(tempLayer.getScaleSize()));
                tempLayer.initErrors(batchSize);
                tempLayer.initOutMaps(batchSize);
                break;
            case OUTPUT:
                tempLayer.initOutputKernel(tempFrontMapNum, tempFrontLayer.getMapSize());
                tempLayer.initBias();
                tempLayer.initErrors(batchSize);
                tempLayer.initOutMaps(batchSize);
                break;
            }// Of switch
        } // Of for i
    }// Of setup

    private void forward(Instance instance) {
        setInputLayerOutput(instance);
        for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
            CnnLayer tempCurrentLayer = layerBuilder.getLayer(l);
            CnnLayer tempLastLayer = layerBuilder.getLayer(l - 1);
            switch (tempCurrentLayer.getType()) {
            case CONVOLUTION:
            case OUTPUT:
                setConvolutionOutput(tempCurrentLayer, tempLastLayer);
                break;
            case SAMPLING:
                setSampOutput(tempCurrentLayer, tempLastLayer);
                break;
            default:
                break;
            }// Of switch
        } // Of for l
    }// Of forward

    private void setInputLayerOutput(Instance paraRecord) {
        CnnLayer tempInputLayer = layerBuilder.getLayer(0);
        Size tempMapSize = tempInputLayer.getMapSize();
        double[] tempAttributes = paraRecord.getAttributes();
        if (tempAttributes.length != tempMapSize.width * tempMapSize.height)
            throw new RuntimeException("input record does not match the map size.");

        for (int i = 0; i < tempMapSize.width; i++) {
            for (int j = 0; j < tempMapSize.height; j++) {
                tempInputLayer.setMapValue(0, i, j, tempAttributes[tempMapSize.height * i + j]);
            } // Of for j
        } // Of for i
    }// Of setInputLayerOutput

    private void setConvolutionOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
        final int lastMapNum = paraLastLayer.getOutMapNum();
        // Attention: paraLayer.getOutMapNum() may not be right.
        for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
            double[][] tempSumMatrix = null;
            for (int i = 0; i < lastMapNum; i++) {
                double[][] lastMap = paraLastLayer.getMap(i);
                double[][] kernel = paraLayer.getKernel(i, j);
                if (tempSumMatrix == null) {
                    // On the first map.
                    tempSumMatrix = MathUtils.convnValid(lastMap, kernel);
                } else {
                    // Sum up convolution maps.
                    tempSumMatrix = MathUtils.matrixOp(MathUtils.convnValid(lastMap, kernel),
                            tempSumMatrix, null, null, MathUtils.plus);
                } // Of if
            } // Of for i

            // Activation.
            final double bias = paraLayer.getBias(j);
            tempSumMatrix = MathUtils.matrixOp(tempSumMatrix, new Operator() {
                private static final long serialVersionUID = 2469461972825890810L;

                @Override
                public double process(double value) {
                    return MathUtils.sigmoid(value + bias);
                }// Of process
            });

            paraLayer.setMapValue(j, tempSumMatrix);
        } // Of for j
    }// Of setConvolutionOutput

    private void setSampOutput(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
        // Attention: paraLayer.outMapNum may not be right.
        for (int i = 0; i < paraLayer.outMapNum; i++) {
            double[][] lastMap = paraLastLayer.getMap(i);
            Size scaleSize = paraLayer.getScaleSize();
            double[][] sampMatrix = MathUtils.scaleMatrix(lastMap, scaleSize);
            paraLayer.setMapValue(i, sampMatrix);
        } // Of for i
    }// Of setSampOutput

    public void train(Dataset paraDataset, int paraRounds) {
        for (int t = 0; t < paraRounds; t++) {
            System.out.println("Iteration: " + t);
            int tempNumEpochs = paraDataset.size() / batchSize;
            if (paraDataset.size() % batchSize != 0)
                tempNumEpochs++;

            double tempNumCorrect = 0;
            int tempCount = 0;
            for (int i = 0; i < tempNumEpochs; i++) {
                int[] tempRandomPerm = MathUtils.randomPerm(paraDataset.size(), batchSize);
                CnnLayer.prepareForNewBatch();

                for (int index : tempRandomPerm) {
                    boolean isRight = train(paraDataset.getInstance(index));
                    if (isRight)
                        tempNumCorrect++;
                    tempCount++;
                    CnnLayer.prepareForNewRecord();
                } // Of for index

                updateParameters();
                if (i % 50 == 0) {
                    System.out.print("..");
                    if (i + 50 > tempNumEpochs)
                        System.out.println();
                } // Of if
            } // Of for i

            double p = 1.0 * tempNumCorrect / tempCount;
            if (t % 10 == 1 && p > 0.96) {
                ALPHA = 0.001 + ALPHA * 0.9;
                // logger.info("Set alpha = {}.", ALPHA);
            } // Of if
            System.out.println("Training precision: " + p);
            // logger.info("Precision: {}/{}={}.", right, count, p);
        } // Of for t
    }// Of train

    private boolean train(Instance paraRecord) {
        forward(paraRecord);
        boolean result = backPropagation(paraRecord);
        return result;
    }// Of train

    private boolean backPropagation(Instance paraRecord) {
        boolean result = setOutputLayerErrors(paraRecord);
        setHiddenLayerErrors();
        return result;
    }// Of backPropagation

    private void updateParameters() {
        for (int l = 1; l < layerBuilder.getNumLayers(); l++) {
            CnnLayer layer = layerBuilder.getLayer(l);
            CnnLayer lastLayer = layerBuilder.getLayer(l - 1);
            switch (layer.getType()) {
            case CONVOLUTION:
            case OUTPUT:
                updateKernels(layer, lastLayer);
                updateBias(layer, lastLayer);
                break;
            default:
                break;
            }// Of switch
        } // Of for l
    }// Of updateParameters

    private void updateBias(final CnnLayer paraLayer, CnnLayer paraLastLayer) {
        final double[][][][] errors = paraLayer.getErrors();
        // Attention: getOutMapNum() may not be correct.
        for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
            double[][] error = MathUtils.sum(errors, j);
            double deltaBias = MathUtils.sum(error) / batchSize;
            double bias = paraLayer.getBias(j) + ALPHA * deltaBias;
            paraLayer.setBias(j, bias);
        } // Of for j
    }// Of updateBias

    private void updateKernels(final CnnLayer paraLayer, final CnnLayer paraLastLayer) {
        int tempLastMapNum = paraLastLayer.getOutMapNum();
        // Attention: getOutMapNum() may not be right.
        for (int j = 0; j < paraLayer.getOutMapNum(); j++) {
            for (int i = 0; i < tempLastMapNum; i++) {
                double[][] tempDeltaKernel = null;
                for (int r = 0; r < batchSize; r++) {
                    double[][] error = paraLayer.getError(r, j);
                    if (tempDeltaKernel == null)
                        tempDeltaKernel = MathUtils.convnValid(paraLastLayer.getMap(r, i), error);
                    else {
                        tempDeltaKernel = MathUtils.matrixOp(
                                MathUtils.convnValid(paraLastLayer.getMap(r, i), error),
                                tempDeltaKernel, null, null, MathUtils.plus);
                    } // Of if
                } // Of for r

                tempDeltaKernel = MathUtils.matrixOp(tempDeltaKernel, divideBatchSize);
                double[][] kernel = paraLayer.getKernel(i, j);
                tempDeltaKernel = MathUtils.matrixOp(kernel, tempDeltaKernel, multiplyLambda,
                        multiplyAlpha, MathUtils.plus);
                paraLayer.setKernel(i, j, tempDeltaKernel);
            } // Of for i
        } // Of for j
    }// Of updateKernels

    private void setHiddenLayerErrors() {
        for (int l = layerBuilder.getNumLayers() - 2; l > 0; l--) {
            CnnLayer layer = layerBuilder.getLayer(l);
            CnnLayer nextLayer = layerBuilder.getLayer(l + 1);
            switch (layer.getType()) {
            case SAMPLING:
                setSamplingErrors(layer, nextLayer);
                break;
            case CONVOLUTION:
                setConvolutionErrors(layer, nextLayer);
                break;
            default:
                break;
            }// Of switch
        } // Of for l
    }// Of setHiddenLayerErrors

    private void setSamplingErrors(final CnnLayer paraLayer, final CnnLayer paraNextLayer) {
        int tempNextMapNum = paraNextLayer.getOutMapNum();
        // Attention: getOutMapNum() may not be correct.
        for (int i = 0; i < paraLayer.getOutMapNum(); i++) {
            double[][] sum = null;
            for (int j = 0; j < tempNextMapNum; j++) {
                double[][] nextError = paraNextLayer.getError(j);
                double[][] kernel = paraNextLayer.getKernel(i, j);
                if (sum == null) {
                    sum = MathUtils.convnFull(nextError, MathUtils.rot180(kernel));
                } else {
                    sum = MathUtils.matrixOp(
                            MathUtils.convnFull(nextError, MathUtils.rot180(kernel)), sum, null,
                            null, MathUtils.plus);
                } // Of if
            } // Of for j
            paraLayer.setError(i, sum);
        } // Of for i
    }// Of setSamplingErrors

    private void setConvolutionErrors(final CnnLayer paraLayer, final CnnLayer paraNextLayer) {
        for (int m = 0; m < paraLayer.getOutMapNum(); m++) {
            Size tempScale = paraNextLayer.getScaleSize();
            double[][] tempNextLayerErrors = paraNextLayer.getError(m);
            double[][] tempMap = paraLayer.getMap(m);
            double[][] tempOutMatrix = MathUtils.matrixOp(tempMap, MathUtils.cloneMatrix(tempMap),
                    null, MathUtils.one_value, MathUtils.multiply);
            tempOutMatrix = MathUtils.matrixOp(tempOutMatrix,
                    MathUtils.kronecker(tempNextLayerErrors, tempScale), null, null,
                    MathUtils.multiply);
            paraLayer.setError(m, tempOutMatrix);
        } // Of for m
    }// Of setConvolutionErrors

    private boolean setOutputLayerErrors(Instance paraRecord) {
        CnnLayer tempOutputLayer = layerBuilder.getOutputLayer();
        int tempMapNum = tempOutputLayer.getOutMapNum();

        double[] tempTarget = new double[tempMapNum];
        double[] tempOutMaps = new double[tempMapNum];
        for (int m = 0; m < tempMapNum; m++) {
            double[][] outmap = tempOutputLayer.getMap(m);
            tempOutMaps[m] = outmap[0][0];
        } // Of for m

        int tempLabel = paraRecord.getLabel().intValue();
        tempTarget[tempLabel] = 1;

        for (int m = 0; m < tempMapNum; m++) {
            tempOutputLayer.setError(m, 0, 0,
                    tempOutMaps[m] * (1 - tempOutMaps[m]) * (tempTarget[m] - tempOutMaps[m]));
        } // Of for m
        return tempLabel == MathUtils.getMaxIndex(tempOutMaps);
    }// Of setOutputLayerErrors

    public void setup(int paraBatchSize) {
        CnnLayer tempInputLayer = layerBuilder.getLayer(0);
        tempInputLayer.initOutMaps(paraBatchSize);

        for (int i = 1; i < layerBuilder.getNumLayers(); i++) {
            CnnLayer tempLayer = layerBuilder.getLayer(i);
            CnnLayer tempLastLayer = layerBuilder.getLayer(i - 1);
            int tempLastMapNum = tempLastLayer.getOutMapNum();
            switch (tempLayer.getType()) {
            case INPUT:
                break;
            case CONVOLUTION:
                tempLayer.setMapSize(
                        tempLastLayer.getMapSize().subtract(tempLayer.getKernelSize(), 1));
                tempLayer.initKernel(tempLastMapNum);
                tempLayer.initBias();
                tempLayer.initErrors(paraBatchSize);
                tempLayer.initOutMaps(paraBatchSize);
                break;
            case SAMPLING:
                tempLayer.setOutMapNum(tempLastMapNum);
                tempLayer.setMapSize(tempLastLayer.getMapSize().divide(tempLayer.getScaleSize()));
                tempLayer.initErrors(paraBatchSize);
                tempLayer.initOutMaps(paraBatchSize);
                break;
            case OUTPUT:
                tempLayer.initOutputKernel(tempLastMapNum, tempLastLayer.getMapSize());
                tempLayer.initBias();
                tempLayer.initErrors(paraBatchSize);
                tempLayer.initOutMaps(paraBatchSize);
                break;
            }// Of switch
        } // Of for i
    }// Of setup

    public int[] predict(Dataset paraDataset) {
        System.out.println("Predicting ...");
        CnnLayer.prepareForNewBatch();
        int[] resultPredictions = new int[paraDataset.size()];
        double tempCorrect = 0.0;
        Instance tempRecord;
        for (int i = 0; i < paraDataset.size(); i++) {
            tempRecord = paraDataset.getInstance(i);
            forward(tempRecord);
            CnnLayer outputLayer = layerBuilder.getOutputLayer();

            int tempMapNum = outputLayer.getOutMapNum();
            double[] tempOut = new double[tempMapNum];
            for (int m = 0; m < tempMapNum; m++) {
                double[][] outmap = outputLayer.getMap(m);
                tempOut[m] = outmap[0][0];
            } // Of for m

            resultPredictions[i] = MathUtils.getMaxIndex(tempOut);
            if (resultPredictions[i] == tempRecord.getLabel().intValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        System.out.println("Accuracy: " + tempCorrect / paraDataset.size());
        return resultPredictions;
    }// Of predict

    public static void main(String[] args) {
        LayerBuilder builder = new LayerBuilder();
        // Input layer, the maps are 28*28.
        builder.addLayer(new CnnLayer(LayerTypeEnum.INPUT, -1, new Size(28, 28)));
        // Convolution output has size 24*24, 24 = 28 + 1 - 5.
        builder.addLayer(new CnnLayer(LayerTypeEnum.CONVOLUTION, 6, new Size(5, 5)));
        // Sampling output has size 12*12, 12 = 24 / 2.
        builder.addLayer(new CnnLayer(LayerTypeEnum.SAMPLING, -1, new Size(2, 2)));
        // Convolution output has size 8*8, 8 = 12 + 1 - 5.
        builder.addLayer(new CnnLayer(LayerTypeEnum.CONVOLUTION, 12, new Size(5, 5)));
        // Sampling output has size 4*4, 4 = 8 / 2.
        builder.addLayer(new CnnLayer(LayerTypeEnum.SAMPLING, -1, new Size(2, 2)));
        // Output layer, digits 0 - 9.
        builder.addLayer(new CnnLayer(LayerTypeEnum.OUTPUT, 10, null));

        // Construct the full CNN.
        FullCnn tempCnn = new FullCnn(builder, 10);
        Dataset tempTrainingSet = new Dataset("D:/Work/Data/sampledata/train.format", ",", 784);
        // Train the model.
        tempCnn.train(tempTrainingSet, 10);
        // tempCnn.predict(tempTrainingSet);
    }// Of main
}// Of class FullCnn
```

Summary
Convolutional neural networks are easy to understand in concept, but actually implementing a framework turned out to be painful and difficult for me.
First, there is the derivation of the mathematical formulas for backpropagation. I know gradient descent and matrix differentiation, but only from isolated exercises; when it came time to apply them in a real implementation, I could not find my way in.
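As a concrete instance of such a derivation, the output-layer error computed in `setOutputLayerErrors` above follows directly from the chain rule, assuming (as the code's formula suggests) a sigmoid activation and a squared-error loss:

$$
E = \frac{1}{2}\sum_m (t_m - o_m)^2, \qquad o_m = \sigma(x_m), \qquad \sigma'(x) = \sigma(x)\,(1 - \sigma(x)),
$$

so the negative gradient with respect to the pre-activation $x_m$ is

$$
-\frac{\partial E}{\partial x_m} = (t_m - o_m)\,\sigma'(x_m) = o_m\,(1 - o_m)\,(t_m - o_m),
$$

which is exactly the value the code stores via `setError(m, 0, 0, tempOutMaps[m] * (1 - tempOutMaps[m]) * (tempTarget[m] - tempOutMaps[m]))`.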
Second, there is the coding itself, whether in the math utility class or the matrix utilities; the matrix rotation part in particular made no sense to me on a first reading.
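The matrix rotation is less mysterious than it first appears: rotating a kernel by 180 degrees is just flipping it along both axes, which is what `MathUtils.rot180` does before the full convolution in `setSamplingErrors`. A minimal standalone sketch of the idea (the class and method names here are illustrative, not the original implementation):

```java
import java.util.Arrays;

public class Rot180Demo {
    // Rotate a matrix by 180 degrees: element (i, j) moves to
    // (rows - 1 - i, cols - 1 - j). This is equivalent to flipping
    // the kernel both horizontally and vertically.
    public static double[][] rot180(double[][] matrix) {
        int rows = matrix.length;
        int cols = matrix[0].length;
        double[][] result = new double[rows][cols];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                result[rows - 1 - i][cols - 1 - j] = matrix[i][j];
            }
        }
        return result;
    }

    public static void main(String[] args) {
        double[][] kernel = { { 1, 2 }, { 3, 4 } };
        double[][] rotated = rot180(kernel);
        // prints [[4.0, 3.0], [2.0, 1.0]]
        System.out.println(Arrays.deepToString(rotated));
    }
}
```

The rotation is needed because back-propagating an error map through a convolution is itself a (full) convolution with the flipped kernel.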
If all we had to do was plug formulas into code, these problems would solve themselves. But things rarely go that smoothly: all of this is something I have to genuinely feel out, derive, and implement for myself.
The road ahead is long and winding, yet I will keep searching, high and low.