Wednesday, November 29, 2017

Refactoring My Previous Deep Learning Library Ideas

In some previous posts I looked at how I could implement a simple deep learning library to gain a better understanding of the inner workings of auto differentiation, weight sharing, and a good back propagation implementation from the deep learning book.

In this post I will present my ideas for a refactored version. The main point is that the graph representation I implemented in the first two posts made it easy to build the computation graph, but the code to run back propagation, especially with weight sharing, was quite messy. So this refactored version first creates a computation tree as before, except that the actual computations are decoupled from the tree. The tree is later translated into the instructions and parent/child relationships used in the book version of the back propagation algorithm.

In the first step we write the model setup code. For example, a neural network for MNIST classification:

As in TensorFlow, I define placeholders, variables, and computation nodes, along with their names.
Furthermore, the variables get a shape and an initialisation method. In the background, this method constructs a computation tree rooted at a5 using the following DSL.
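The original snippets are not reproduced here, so the following is only a minimal sketch of what such a setup and DSL could look like. The `Node` class, the helpers `placeholder`, `variable` and `op`, the layer sizes, and the node names other than a5 are all my own assumptions for illustration, not the library's actual API.

```python
from dataclasses import dataclass, field
from typing import List, Optional, Tuple


@dataclass
class Node:
    """One node of the computation tree (hypothetical layout)."""
    name: str
    instruction: str                          # e.g. "placeholder", "variable", "matmul"
    shape: Optional[Tuple[int, int]] = None   # only variables need a shape here
    init: Optional[str] = None                # name of the initialisation method
    inputs: List["Node"] = field(default_factory=list)


def placeholder(name):
    return Node(name, "placeholder")


def variable(name, shape, init="normal"):
    return Node(name, "variable", shape, init)


def op(name, instruction, *inputs):
    return Node(name, instruction, inputs=list(inputs))


# Model setup: an assumed two-layer classifier for 28x28 MNIST images.
x = placeholder("x")
w1 = variable("w1", (784, 100))
b1 = variable("b1", (1, 100))
w2 = variable("w2", (100, 10))
b2 = variable("b2", (1, 10))
a1 = op("a1", "matmul", x, w1)
a2 = op("a2", "add", a1, b1)
a3 = op("a3", "sigmoid", a2)
a4 = op("a4", "matmul", a3, w2)
a5 = op("a5", "add", a4, b2)   # root of the tree


def names(node):
    """List node names depth first, inputs before the node that uses them."""
    out = []
    for child in node.inputs:
        out.extend(names(child))
    out.append(node.name)
    return out
```

Note that nothing is computed at this point. The tree only records names, instruction names and shapes, which is what keeps the computations decoupled from the structure.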

Each node keeps track of its name, an instruction name, and a shape, if needed.
We can then convert this tree into a neural network ready to be used in the back propagation code from the original post. We build four maps:
  1. Mapping from node names to computations (like add, multiply, ...)
  2. Mapping from node names to their parent node names
  3. Mapping from node names to their child node names
  4. Mapping from node names that are weights to matrices.
All computations, such as add or multiply, implement the forward pass through the variable as well as the backward pass.
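The conversion into the four maps could be sketched roughly as follows. This is an illustration under several assumptions: the tree is written as nested tuples `(name, instruction, shape, inputs)`, "parents" means the inputs of a node and "children" means the nodes that consume it (my reading of the book's convention), and the first map stores plain instruction names where the real code would store objects with a forward and a backward pass.

```python
import random

# A tiny tree as nested tuples: (name, instruction, shape or None, inputs).
tree = ("a2", "add", None, [
    ("a1", "multiply", None, [
        ("x", "placeholder", None, []),
        ("w", "variable", (2, 3), []),
    ]),
    ("b", "variable", (1, 3), []),
])


def to_maps(root):
    """Flatten the tree into the four name-keyed maps."""
    computations, parents, children, weights = {}, {}, {}, {}

    def visit(node):
        name, instruction, shape, inputs = node
        if name in parents:            # already seen, e.g. a shared weight
            return
        parents[name] = [inp[0] for inp in inputs]
        children.setdefault(name, [])
        if instruction == "variable":
            rows, cols = shape
            # Assumed initialisation: small Gaussian noise.
            weights[name] = [[random.gauss(0.0, 0.1) for _ in range(cols)]
                             for _ in range(rows)]
        elif instruction != "placeholder":
            computations[name] = instruction
        for inp in inputs:
            visit(inp)
            children[inp[0]].append(name)   # record this node as a consumer

    visit(root)
    return computations, parents, children, weights


computations, parents, children, weights = to_maps(tree)
```

Because a node that was already visited is skipped but still gets its consumer recorded, a weight used in two places ends up with one matrix and two children, which is the case that made the earlier versions messy.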

Using a class I called TensorStore that maps from names to matrices, I implemented the forward and the backward pass as follows:
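As the original listing is not shown here, the following is a simplified sketch of the idea rather than the actual implementation. Only the name `TensorStore` comes from the post. Its methods, the `Add` and `Multiply` classes, the hand-written evaluation order, and the use of scalars in place of matrices are assumptions made to keep the example short.

```python
class TensorStore:
    """Maps names to values. Scalars stand in for matrices in this sketch."""

    def __init__(self):
        self.values = {}

    def get(self, name):
        return self.values[name]

    def put(self, name, value):
        self.values[name] = value

    def accumulate(self, name, value):
        # Summing is what makes weight sharing work in the backward pass.
        self.values[name] = self.values.get(name, 0.0) + value


class Add:
    def forward(self, a, b):
        return a + b

    def backward(self, a, b, grad):
        return [grad, grad]


class Multiply:
    def forward(self, a, b):
        return a * b

    def backward(self, a, b, grad):
        return [grad * b, grad * a]


# y = (x * w) * w + b, where the weight w is shared by two nodes.
computations = {"a1": Multiply(), "a2": Multiply(), "y": Add()}
parents = {"a1": ["x", "w"], "a2": ["a1", "w"], "y": ["a2", "b"]}
order = ["a1", "a2", "y"]   # topological order, written by hand here


def forward(store):
    for name in order:
        args = [store.get(p) for p in parents[name]]
        store.put(name, computations[name].forward(*args))
    return store.get(order[-1])


def backward(store):
    grads = TensorStore()
    grads.put(order[-1], 1.0)            # d(output) / d(output)
    for name in reversed(order):
        args = [store.get(p) for p in parents[name]]
        local = computations[name].backward(*args, grads.get(name))
        for parent, g in zip(parents[name], local):
            grads.accumulate(parent, g)
    return grads


store = TensorStore()
for name, value in {"x": 2.0, "w": 3.0, "b": 1.0}.items():
    store.put(name, value)
out = forward(store)
grads = backward(store)
```

With x = 2, w = 3 and b = 1, the function is y = x * w^2 + b, so the gradient for w should be 2 * x * w. The shared weight receives one contribution from each node that uses it, and the store simply adds them up.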

The code can be found along with the other versions on GitHub. Examples for the spiral dataset as well as the MNIST dataset are included.

Of course the code is not at all up to the popular libraries in terms of speed or functionality, so for me it is more an ongoing investigation into how these libraries are implemented and what I can reproduce :). On MNIST, the given network achieves between 96% and 98% accuracy, depending on the starting condition.