Thursday, February 16, 2017

AI: Implement your own Automatic Differentiation Algorithm For Neural Networks



Modern deep learning libraries such as TensorFlow provide automatic differentiation in order to
compute the gradients in each neural network layer. While the backpropagation algorithm
for feed-forward networks is well known, the gradients for new or more interesting architectures
are often hard to work out by hand. This is where automatic differentiation can help. In this post
I will describe the algorithm and provide a simple Scala implementation that might help
to understand how automatic differentiation can be implemented. I also 'cheat' in the sense that
I use a linear algebra library that interfaces to native libraries, which in turn use BLAS and LAPACK, so my performance depends on this wrapper. For playing around, I guess this is sufficient.
The code can be found on GitHub.


Automatic Differentiation

In my opinion, the easiest way to understand automatic differentiation is through a computation
graph. In a computation graph, variables, operations and functions are represented as nodes, while edges represent the inputs and outputs of those functions and operations.

Let us consider the example computation:


  • W * x + b
  • with W = 2
  • with x = 3
  • with b = 1

This results in the following computation tree. To get the value of this formula, we request the value of the root node, which evaluates its children recursively. The result (here 2 * 3 + 1 = 7) is then assembled on the way back up. This is equivalent to finding all paths from the root to the input variables and, along the way, computing all the intermediate results needed for the final value.
We can then propagate derivatives backwards from the output to the inputs. For example,
if we measure an error of 2 at the top, we push it down the tree to all children.
Since we apply the chain rule, f(g(x))' = f'(g(x)) * g'(x),
we multiply the gradient flowing from the top with the local gradient at the current node.
When adding two numbers, the local gradient is 1, so the gradient from the top flows through
unchanged to both children. Let the gradient from the top be called delta.
When multiplying two numbers, x and w, the gradients flowing towards x and w are:

  • dx  = delta * w
  • dw = delta * x


So the inputs get swapped and multiplied with the incoming gradient. The following image shows the forward pass in green and the backward pass in red.
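A minimal sketch of both passes on this example, using plain Scala doubles instead of matrices. The types `Expr`, `Var`, `Add`, `Mul` and the functions `eval` and `backward` are hypothetical illustrations, not part of the post's code; gradients for a variable are accumulated per name in a map:

```scala
// Hypothetical scalar expression tree for W * x + b (not the post's classes).
sealed trait Expr
case class Var(name: String, v: Double) extends Expr
case class Add(a: Expr, b: Expr) extends Expr
case class Mul(a: Expr, b: Expr) extends Expr

// Forward pass: evaluate the children recursively, then combine them.
def eval(e: Expr): Double = e match {
  case Var(_, v) => v
  case Add(a, b) => eval(a) + eval(b)
  case Mul(a, b) => eval(a) * eval(b)
}

// Backward pass: push delta down the tree and sum the gradients per variable name.
def backward(e: Expr, delta: Double, grads: Map[String, Double]): Map[String, Double] = e match {
  case Var(n, _) => grads.updated(n, grads.getOrElse(n, 0.0) + delta)
  // Addition: local gradient 1, delta flows unchanged to both children.
  case Add(a, b) => backward(b, delta, backward(a, delta, grads))
  // Multiplication: each input receives delta times the other input.
  case Mul(a, b) => backward(b, delta * eval(a), backward(a, delta * eval(b), grads))
}

val expr  = Add(Mul(Var("W", 2.0), Var("x", 3.0)), Var("b", 1.0))
val value = eval(expr)
val grads = backward(expr, 2.0, Map.empty)

println(value) // 7.0
println(s"W: ${grads("W")}, x: ${grads("x")}, b: ${grads("b")}") // W: 6.0, x: 4.0, b: 2.0
```

Here the product node receives the full delta of 2 from the addition, and each input of the multiplication gets that delta times the other input.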



Before we go on: if a node has two outgoing edges in the forward pass, we sum all the incoming gradients during backpropagation. This does not happen in plain feed-forward neural networks, but it does happen with weight sharing, for example in convolutional or recurrent neural networks.
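A minimal sketch of why the gradients are summed, assuming a single scalar x that feeds both inputs of one multiplication node (y = x * x); the variable names are illustrative only:

```scala
// x feeds both inputs of a multiplication node: y = x * x.
val x = 3.0
val y = x * x // forward pass: 9.0

val delta = 1.0 // gradient arriving at y

// The multiplication node sends delta * (other input) along each of its input edges.
val fromLeftEdge  = delta * x
val fromRightEdge = delta * x

// x has two outgoing edges in the forward pass, so its incoming gradients are summed.
val dx = fromLeftEdge + fromRightEdge

println(dx) // 6.0, matching d(x * x)/dx = 2x at x = 3
```

Taking only one of the two messages would give 3.0, half of the true derivative.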

Now we can turn to the matrix computations needed for neural networks. Adding two matrices works just like the scalar case from before: since addition is an element-wise operation, the gradient (now a matrix) simply flows unchanged to both children. In general, element-wise operations behave like their scalar counterparts.

For the matrix (dot) product x * w of two matrices x and w, the gradients become:


  • dx  = delta * w.t
  • dw = x.t * delta
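A self-contained sketch of these two rules, using plain `Array[Array[Double]]` matrices instead of the linear algebra library. The helpers `matmul` and `transpose` and the loss used for the finite-difference check are assumptions made for illustration: the check uses L = sum(delta :* (x * w)), whose gradient with respect to x * w is exactly delta.

```scala
// Matrices as row-major arrays; helper names are hypothetical.
type M = Array[Array[Double]]

def matmul(a: M, b: M): M =
  Array.tabulate(a.length, b(0).length) { (i, j) =>
    a(i).indices.map(k => a(i)(k) * b(k)(j)).sum
  }

def transpose(a: M): M = Array.tabulate(a(0).length, a.length)((i, j) => a(j)(i))

val x: M     = Array(Array(1.0, 2.0, 3.0))                                // 1 x 3 input
val w: M     = Array(Array(0.1, 0.2), Array(0.3, 0.4), Array(0.5, 0.6))  // 3 x 2 weights
val delta: M = Array(Array(1.0, -2.0))                                   // 1 x 2 gradient from above

// Backward rules for z = x * w.
val dx = matmul(delta, transpose(w)) // 1 x 3, same shape as x
val dw = matmul(transpose(x), delta) // 3 x 2, same shape as w

// Loss whose gradient with respect to z = x * w is exactly delta.
def loss(wm: M): Double = {
  val z = matmul(x, wm)
  (for (i <- z.indices; j <- z(i).indices) yield delta(i)(j) * z(i)(j)).sum
}

// Central finite differences of the loss with respect to every entry of w.
val eps = 1e-6
val numericDw: M = Array.tabulate(w.length, w(0).length) { (i, j) =>
  val plus = w.map(_.clone())
  plus(i)(j) += eps
  val minus = w.map(_.clone())
  minus(i)(j) -= eps
  (loss(plus) - loss(minus)) / (2 * eps)
}

println(dx.map(_.mkString(" ")).mkString("\n"))
println(dw.map(_.mkString(" ")).mkString("\n"))
```

Each gradient has the same shape as the matrix it belongs to, which is a quick way to remember on which side the transpose goes.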

The only thing left to build a neural network is the gradient of the activation function.
When using sigmoid units, the forward and backward pass for this node, given the input x and the incoming gradient delta, are:

  • fwd: y(x) = 1.0 / (1.0 + exp(-x))
  • bwd: dx = delta * (1.0 - y(x)) * y(x)
Again, this is an element-wise operation, so we simply apply the formula to each element of the matrix.
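A small sketch of the sigmoid rules applied element by element to a 2 x 2 array, with the backward step reusing the forward output y, as the node does with its cached value. The function names and the finite-difference comparison are illustrative additions:

```scala
// Element-wise sigmoid: fwd y = 1 / (1 + exp(-x)).
def sigmoid(x: Double): Double = 1.0 / (1.0 + math.exp(-x))

// Backward step for one element, given the incoming delta and the cached forward output y.
def sigmoidBwd(delta: Double, y: Double): Double = delta * (1.0 - y) * y

val input = Array(Array(-1.0, 0.0), Array(0.5, 2.0))
val delta = Array(Array(1.0, 1.0), Array(1.0, 1.0)) // gradient arriving from above

// Apply both rules independently to every matrix element.
val fwd = input.map(_.map(sigmoid))
val bwd = Array.tabulate(2, 2)((i, j) => sigmoidBwd(delta(i)(j), fwd(i)(j)))

// Central finite differences of the sigmoid, for comparison with bwd (delta is 1 here).
val eps = 1e-6
val numeric = input.map(_.map(v => (sigmoid(v + eps) - sigmoid(v - eps)) / (2 * eps)))

println(fwd.map(_.mkString(" ")).mkString("\n"))
println(bwd.map(_.mkString(" ")).mkString("\n"))
```

At x = 0 the forward value is 0.5 and the local derivative is (1 - 0.5) * 0.5 = 0.25, the largest slope the sigmoid has.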

A quick and dirty implementation

To implement this behaviour, we first define a general computation node:

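import breeze.linalg.DenseMatrix

// DMat is assumed to be a type alias defined elsewhere in the project,
// for example in a package object: type DMat = DenseMatrix[Double]
abstract class ComputationNode {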

  var valueOpt: Option[DMat]    = None
  var gradientOpt: Option[DMat] = None

  def gradient = gradientOpt.get

  def value = valueOpt.get

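  // x is a by-name parameter, so it is only evaluated when no value is cached yet;
  // later calls return the stored value without recomputing the subgraph.
  def assignVal(x: => DMat): DMat = {
    if (valueOpt.isEmpty) {
      valueOpt = Some(x)
    }
    value
  }

  // The first incoming gradient is stored, all later ones are added to it.
  def collectGradient(x: DMat): Unit = {
    if (gradientOpt.isEmpty) gradientOpt = Some(x)
    else gradientOpt = Some(x + gradient)
  }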

  def reset: Unit

  def fwd: DMat

  def bwd(error: DMat): Unit

}

Each node holds the value computed during the forward pass and the gradient collected during the backward pass.
Both are represented as Scala options that are None by default, so that a value does not have to be recomputed:
the first time a value is computed, we store it in the option, and when it is needed again we return the stored value.
For the gradients, the first incoming gradient initialises the option and all later ones are added to it.
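A scalar sketch of the same caching and accumulation idea, with `Double` in place of `DMat`. The classes `Node`, `Const` and `Add` and the `evaluations` counter are hypothetical. One detail worth noting: `assignVal` takes a by-name parameter (`=> Double`) here, so the expression passed to it is only evaluated when nothing is cached yet:

```scala
// Scalar analogue of ComputationNode; names are hypothetical.
abstract class Node {
  var valueOpt: Option[Double]    = None
  var gradientOpt: Option[Double] = None
  def value: Double    = valueOpt.get
  def gradient: Double = gradientOpt.get

  // By-name parameter: compute only runs if no value is stored yet.
  def assignVal(compute: => Double): Double = {
    if (valueOpt.isEmpty) valueOpt = Some(compute)
    value
  }

  // The first gradient initialises the option, later ones are added to it.
  def collectGradient(g: Double): Unit =
    gradientOpt = Some(gradientOpt.fold(g)(_ + g))

  def fwd: Double
  def bwd(error: Double): Unit
}

// Counts how often a leaf value is actually computed.
var evaluations = 0

class Const(v: Double) extends Node {
  def fwd: Double = assignVal { evaluations += 1; v }
  def bwd(error: Double): Unit = collectGradient(error)
}

class Add(a: Node, b: Node) extends Node {
  def fwd: Double = assignVal(a.fwd + b.fwd)
  def bwd(error: Double): Unit = { collectGradient(error); a.bwd(error); b.bwd(error) }
}

val a     = new Const(2.0)
val total = new Add(a, new Const(3.0))
total.fwd
total.fwd      // second call hits the cache, the leaves are not evaluated again
total.bwd(1.0)
total.bwd(0.5) // a second incoming gradient is summed with the first

println(s"value=${total.value} evaluations=$evaluations grad(a)=${a.gradient}")
// value=5.0 evaluations=2 grad(a)=1.5
```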

UPDATE: If you actually want to use weight sharing, it is important to have a mechanism
that only sends the backward message once all incoming gradients have been received. This is not implemented here; I will discuss weight sharing in a later post.
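One possible way to add such a mechanism, sketched under assumptions on plain doubles: each node counts its consumers (outgoing edges in the forward pass) when the graph is built, and only propagates its summed gradient once that many gradients have arrived. The classes and fields below are hypothetical and not part of the post's code:

```scala
// Hypothetical scalar nodes that wait for all incoming gradients before propagating.
abstract class Node {
  var consumers = 0          // outgoing edges in the forward pass
  private var received = 0   // gradients received so far in this backward pass
  var grad = 0.0             // summed gradient
  var messages = 0           // number of backward messages that reached this node

  def value: Double
  protected def propagate(g: Double): Unit // send g on to the children

  def bwd(g: Double): Unit = {
    messages += 1
    grad += g
    received += 1
    // Only forward the summed gradient once every consumer has reported.
    if (received >= consumers) propagate(grad)
  }
}

class Leaf(v: Double) extends Node {
  def value: Double = v
  protected def propagate(g: Double): Unit = ()
}

class Add(a: Node, b: Node) extends Node {
  a.consumers += 1; b.consumers += 1
  def value: Double = a.value + b.value
  protected def propagate(g: Double): Unit = { a.bwd(g); b.bwd(g) }
}

class Mul(a: Node, b: Node) extends Node {
  a.consumers += 1; b.consumers += 1
  def value: Double = a.value * b.value
  protected def propagate(g: Double): Unit = { a.bwd(g * b.value); b.bwd(g * a.value) }
}

val a = new Leaf(1.0)
val b = new Leaf(2.0)
val h = new Add(a, b) // h = a + b = 3
val y = new Mul(h, h) // y = h * h = 9, so h has two consumers

y.bwd(1.0) // the root has no consumers, so it propagates immediately
println(s"y=${y.value} dh=${h.grad} da=${a.grad} messages(a)=${a.messages}")
// y=9.0 dh=6.0 da=6.0 messages(a)=1
```

With the counter, h waits for both of its gradients and sends a single message to a and b instead of one per incoming gradient. The counters would also need to be reset between backward passes, which this sketch leaves out.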

The forward function computes the result up to that node, and the backward function
distributes the gradient to the node's children. The reset function clears the cached values and gradients
(setting them back to None) so that the graph can be reused.

Some Computation Nodes For Tensors

Now we are ready to implement the actual computation nodes. All of them
extend the abstract class from above. For a basic neural network we need the addition of two matrices, the inner (matrix) product and the sigmoid function, since one layer of a feed-forward neural network is given by sigmoid(x * W + b). Using the linear algebra library, this is actually quite simple.

In an addition node, the forward pass simply adds the results of the two children (two matrices), and the backward pass passes the error back unchanged to both of them.

case class Addition(x: ComputationNode, y: ComputationNode) extends ComputationNode {

  def fwd = assignVal(x.fwd + y.fwd)

  def bwd(error: DMat) = {
    collectGradient(error)
    x.bwd(error)
    y.bwd(error)
  }

  def reset = {
    gradientOpt = None
    valueOpt    = None
    x.reset
    y.reset
  }

}

In an inner product node, the forward pass computes the matrix product of the two children's results. The backward pass sends each child the error multiplied with the transposed output of the other child, as in the dx and dw rules above.

case class InnerProduct(x: ComputationNode, y: ComputationNode) extends ComputationNode {

  def fwd = {
    assignVal(x.fwd * y.fwd)
  }

  def bwd(error: DMat) = {
    collectGradient(error)
    x.bwd(error * y.fwd.t)
    y.bwd(x.fwd.t * error)
  }

  def reset = {
    gradientOpt = None
    valueOpt    = None
    x.reset
    y.reset
  }

}

The sigmoid node behaves as expected: we pass down the error flowing from the top, multiplied element-wise with the derivative of the sigmoid.

import breeze.numerics.sigmoid

case class Sigmoid(x: ComputationNode) extends ComputationNode {

  def fwd = {
    assignVal(sigmoid(x.fwd))
  }

  def bwd(error: DMat) = {
    collectGradient(error)
    val ones = DenseMatrix.ones[Double](value.rows, value.cols)
    val derivative = (ones - value) :* value
    x.bwd(error :* derivative)
  }

  def reset = {
    gradientOpt = None
    valueOpt    = None
    x.reset
  }

}

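Last but not least, we need a class representing a tensor, such as a weight or an input.
During the forward pass we simply return the value of the tensor, and during the backward
pass we sum up all incoming gradients. There is an additional method called ':=' (Scala allows symbolic method names, so this reads like an operator) to set the value of a tensor. Note that reset only clears the gradient of a tensor and keeps its value, so the weights survive between training steps.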

class Tensor extends ComputationNode {

  def fwd = value

  def bwd(error: DMat) = collectGradient(error)

  def :=(x: DMat) = valueOpt = Some(x)

  def reset = {
    gradientOpt = None
  }

}

For example, using this implementation we can build a simple logistic regressor:

val x = new Tensor
val w = new Tensor
val b = new Tensor

w := DenseMatrix.rand[Double](100, 1)
b := DenseMatrix.zeros[Double](1, 1)

val logistic = Sigmoid(Addition(InnerProduct(x, w), b))

// input (1 x 100) and label (1 x 1) are assumed to be defined elsewhere
x := input
val result = logistic.fwd
val error  = -(label - result)

logistic.bwd(error)

w := w.value - 0.01 * w.gradient
b := b.value - 0.01 * b.gradient
logistic.reset // clear cached values and gradients before the next example
In the example, we first initialize some tensors and then define the computation graph. After setting the input, we compute the result and push the error back through the graph; the error -(label - result) is the derivative of the squared error 0.5 * (result - label)^2 with respect to the result. Finally, we update the weights and the bias. This can be seen as one step of stochastic gradient descent.
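The same step can be sketched without the matrix library and repeated over a toy dataset. The OR data, the zero initialisation, the learning rate of 0.5 and the number of epochs are assumptions chosen so the example trains quickly, not values from the post:

```scala
// Toy data (the OR function); all settings here are illustrative assumptions.
val inputs = Array(Array(0.0, 0.0), Array(0.0, 1.0), Array(1.0, 0.0), Array(1.0, 1.0))
val labels = Array(0.0, 1.0, 1.0, 1.0)

def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

val w    = Array(0.0, 0.0) // weights (a 2 x 1 column in the post's matrix form)
var b    = 0.0             // bias
val rate = 0.5

def predict(x: Array[Double]): Double =
  sigmoid(x.indices.map(i => x(i) * w(i)).sum + b)

for (_ <- 1 to 10000; n <- inputs.indices) {
  val x      = inputs(n)
  val result = predict(x)                      // forward pass
  val error  = -(labels(n) - result)           // same error as in the post
  val delta  = error * (1.0 - result) * result // backward through the sigmoid node
  // The addition node passes delta unchanged to b; the product node sends x.t * delta to w.
  for (i <- w.indices) w(i) -= rate * x(i) * delta
  b -= rate * delta
}

val predictions = inputs.map(x => if (predict(x) > 0.5) 1.0 else 0.0)
println(predictions.mkString(", "))
```

Each loop iteration corresponds to one fwd / bwd / update cycle of the graph above, written out by hand for a single neuron.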

Implementing a Generic Feed Forward Layer and a Small Domain Specific Language

To make defining computations easier, we now implement a small domain specific language for building the graph, using operator overloading. We define a wrapper class around computation nodes that lets us write expressions such as sigmoid(Computation(x) * w + b). The class simply holds the computation built so far. If we call the '+' method, we combine the computation so far with the new node under an Addition node as the parent and return a new Computation; '*' does the same with an InnerProduct node. In this way we can also implement a wrapper for a feed-forward neural network layer.

class Computation(val x: ComputationNode) {

  def +(y: ComputationNode): Computation = new Computation(Addition(x, y))

  def *(y: ComputationNode): Computation = new Computation(InnerProduct(x, y))

}

object Computation {

  def apply(x: ComputationNode): Computation = new Computation(x)

  def sigmoid(x: Computation): Computation = new Computation(Sigmoid(x.x))

}
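A scalar sketch of the same wrapper pattern: the operators on `Computation` build new nodes rather than computing numbers, and the graph is only evaluated when `fwd` is called. The `Node` and `Leaf` types are hypothetical simplifications of `ComputationNode` and `Tensor`:

```scala
// Hypothetical scalar nodes mirroring the post's class names.
sealed trait Node { def fwd: Double }
case class Leaf(v: Double) extends Node { def fwd: Double = v }
case class Addition(x: Node, y: Node) extends Node { def fwd: Double = x.fwd + y.fwd }
case class InnerProduct(x: Node, y: Node) extends Node { def fwd: Double = x.fwd * y.fwd }
case class Sigmoid(x: Node) extends Node { def fwd: Double = 1.0 / (1.0 + math.exp(-x.fwd)) }

// Wrapper whose operators build graph nodes instead of computing values.
class Computation(val x: Node) {
  def +(y: Node): Computation = new Computation(Addition(x, y))
  def *(y: Node): Computation = new Computation(InnerProduct(x, y))
}

object Computation {
  def apply(x: Node): Computation = new Computation(x)
  def sigmoid(c: Computation): Computation = new Computation(Sigmoid(c.x))
}

val x = Leaf(3.0)
val w = Leaf(2.0)
val b = Leaf(-6.0)
val graph: Node = Computation.sigmoid(Computation(x) * w + b).x

println(graph)     // Sigmoid(Addition(InnerProduct(Leaf(3.0),Leaf(2.0)),Leaf(-6.0)))
println(graph.fwd) // 0.5, since 3 * 2 - 6 = 0
```

Because `*` and `+` keep Scala's usual precedence, `Computation(x) * w + b` builds the product first and then the addition.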


case class FeedForwardLayerSigmoid(x: ComputationNode, w: Tensor, b: Tensor) extends ComputationNode {

  final val Logit = Computation(x) * w + b

  final val Layer = Computation.sigmoid(Logit).x

  def fwd = Layer.fwd

  def bwd(error: DMat) = Layer.bwd(error)

  def update(rate: Double): Unit = {
    w := w.value - (rate * w.gradient)
    b := b.value - (rate * b.gradient)
    x match {
      case layer: FeedForwardLayerSigmoid => layer.update(rate) // recurse into the layer below
      case _                              => ()                 // input tensor: nothing to update
    }
  }

  def reset = Layer.reset

}
The feed-forward layer is itself a computation node that specifies the computation for the whole layer. In the forward pass we compute the result of the complete layer, and in the backward pass we propagate the error through the whole layer. I also included an update function that recursively 'learns' the weights of this layer and all layers below it using stochastic gradient descent. Now we can train the network with the following program:


val nn = FeedForwardLayerSigmoid(FeedForwardLayerSigmoid(input, W1, B1), W2, B2)
// repeat for all examples 
input := x
val prediction = nn.fwd
val error      = -(correct - prediction)
nn.bwd(error)
nn.update(0.01)
nn.reset
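For readers without the linear algebra library at hand, here is a self-contained sketch of the same fwd / bwd / update loop for two sigmoid layers, written with explicit loops. The XOR data, hidden size, random seed, initialisation range, learning rate and epoch count are all assumptions for illustration; since nothing is cached, there is no equivalent of `nn.reset`:

```scala
import scala.util.Random

// Illustrative settings, not taken from the post.
val rnd  = new Random(42)
val nIn  = 2
val nHid = 8
val w1   = Array.fill(nIn, nHid)(rnd.nextDouble() * 2 - 1)
val b1   = Array.fill(nHid)(0.0)
val w2   = Array.fill(nHid)(rnd.nextDouble() * 2 - 1) // single output unit
var b2   = 0.0
val rate = 0.5

val data = Seq(
  (Array(0.0, 0.0), 0.0), (Array(0.0, 1.0), 1.0),
  (Array(1.0, 0.0), 1.0), (Array(1.0, 1.0), 0.0))

def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

// Forward pass through both layers: hidden activations and the output.
def forward(x: Array[Double]): (Array[Double], Double) = {
  val h   = Array.tabulate(nHid)(j => sigmoid((0 until nIn).map(i => x(i) * w1(i)(j)).sum + b1(j)))
  val out = sigmoid((0 until nHid).map(j => h(j) * w2(j)).sum + b2)
  (h, out)
}

// Squared error summed over the dataset.
def loss: Double = data.map { case (x, t) => 0.5 * math.pow(forward(x)._2 - t, 2) }.sum

val initialLoss = loss
for (_ <- 1 to 10000; (x, t) <- data) {
  val (h, out) = forward(x)                     // fwd
  val dOut = (out - t) * out * (1.0 - out)      // error through the output sigmoid
  // bwd into the hidden layer, using w2 before it is updated.
  val dHid = Array.tabulate(nHid)(j => dOut * w2(j) * h(j) * (1.0 - h(j)))
  // update, as in FeedForwardLayerSigmoid.update
  for (j <- 0 until nHid) w2(j) -= rate * h(j) * dOut
  b2 -= rate * dOut
  for (i <- 0 until nIn; j <- 0 until nHid) w1(i)(j) -= rate * x(i) * dHid(j)
  for (j <- 0 until nHid) b1(j) -= rate * dHid(j)
}
val finalLoss = loss

println(f"loss before: $initialLoss%.4f, after: $finalLoss%.4f")
```

All gradients are computed from the pre-update weights before any weight changes, which matches calling `nn.bwd` fully before `nn.update`.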

Experiments

For the first experiment I use the synthetic spiral dataset, a dataset with three classes that are not linearly separable.


I use a two-layer neural network with the following initialisation:

input :=  DenseMatrix.rand[Double](1,   2)
W1    := (DenseMatrix.rand[Double](2, 100) - 0.5) * 0.01
B1    :=  DenseMatrix.zeros[Double](1, 100)
W2    := (DenseMatrix.rand[Double](100, 3) - 0.5) * 0.01
B2    :=  DenseMatrix.zeros[Double](1,   3)

I get an accuracy of 95% to 98%, so this seems to work. Next, I tried to classify MNIST, again with a two-layer neural network. The hidden layer has 100 units.

x   := DenseMatrix.rand[Double](1,   784)
w1  := (DenseMatrix.rand[Double](784, 100) - 0.5) * 0.001
b1  := DenseMatrix.zeros[Double](1,   100)
w2  := (DenseMatrix.rand[Double](100, 10)  - 0.5) * 0.001
b2  := DenseMatrix.zeros[Double](1,   10)

I got an accuracy of 97%. This is far from the state of the art; however, for a simple two-layer neural network it is not bad. If I remove one layer, I get an accuracy of about 92%, which matches the result reported in the TensorFlow beginners tutorial. I also visualised the weights of the first layer.



One more thing: this library is really not optimized for speed. The big players such as Theano, TensorFlow, Caffe and DL4J are much faster and more optimized. With this post I just hope to clarify the concept.

