Tuesday, June 7, 2011

Back Propagation

In a recent post on neural networks using R, I described neural networks and presented a visualization produced in R. I have also described a multilayer perceptron as a weighted average, or ensemble, of logits. But how are the weights in each hidden layer's logistic activation function (or any activation function, for other network architectures) estimated? How are the weights in the combination functions estimated? Neural networks can be estimated using back propagation, described in Hastie as 'a generic approach to minimizing R(θ) (the cost function) by gradient descent.'

Given a neural network with inputs X, hidden layers composed of hidden units Z, and an output T used to predict some target y, we can represent the network schematically (simplifying the notation in Hastie by omitting subscripts and summations):

X -> Z -> T

Z = σ(α₀ + αᵀx)
T = β₀ + βᵀZ
f(X) = g(T)   [1]

where σ is the activation function (here, the logistic function) and g is the output function.

Given weights {α₀, α, β₀, β}, find the values that minimize the specified error function:

R(θ) = ∑∑ (y - f(x))²   [2]   (note: a number of possible error functions may be used)
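
To make [1] and [2] concrete, here is a minimal sketch of the forward pass and the sum-of-squares error in R, for a single hidden layer and one linear output (g = the identity). The sigmoid activation, the function names (forward, R_theta), and the toy data are my own illustrative assumptions rather than code from the earlier post:

sigma <- function(v) 1 / (1 + exp(-v))            # logistic activation, sigma in [1]

forward <- function(x, alpha0, alpha, beta0, beta) {
  Z <- as.vector(sigma(alpha0 + alpha %*% x))     # Z = sigma(alpha0 + alpha'x)
  Tout <- beta0 + sum(beta * Z)                   # T = beta0 + beta'Z
  list(Z = Z, f = Tout)                           # g is the identity here, so f(X) = T
}

# sum-of-squares error R(theta) over a data set, equation [2]
R_theta <- function(X, y, alpha0, alpha, beta0, beta) {
  preds <- apply(X, 1, function(x) forward(x, alpha0, alpha, beta0, beta)$f)
  sum((y - preds)^2)
}

# toy illustration with 2 inputs and 3 hidden units
set.seed(1)
X <- matrix(rnorm(20), ncol = 2)
y <- X[, 1] - X[, 2] + rnorm(10, sd = 0.1)
alpha  <- matrix(rnorm(6), nrow = 3)
alpha0 <- rnorm(3)
beta   <- rnorm(3)
beta0  <- 0
R_theta(X, y, alpha0, alpha, beta0, beta)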


Back propagation equations:

s = σ'(αᵀx) β δ   [3]

Using these errors, the derivatives of R with respect to the weights can be written as:

∂R/∂β = δZ   [4]
∂R/∂α = sx   [5]

Gradient Descent Update:

βʳ⁺¹ = βʳ - γ ∂R/∂β   [6]

αʳ⁺¹ = αʳ - γ ∂R/∂α   [7]

where γ is the learning rate and r indexes the iteration.
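
Continuing the sketch in R, the backward pass and a single gradient-descent update for one observation (x, y) follow almost directly from [3] through [7]. The output-layer error δ = -2(y - f(x)) assumes the squared-error criterion [2] with an identity output g; the function name backprop_update and the default learning rate are illustrative:

sigma <- function(v) 1 / (1 + exp(-v))

backprop_update <- function(x, y, alpha0, alpha, beta0, beta, gamma = 0.01) {
  # forward pass for this observation, equation [1]
  Z <- as.vector(sigma(alpha0 + alpha %*% x))
  f <- beta0 + sum(beta * Z)

  # output-layer error implied by the squared-error criterion [2] with g = identity
  delta <- -2 * (y - f)

  # back propagation equation [3]: s = sigma'(alpha'x) * beta * delta,
  # using sigma'(v) = sigma(v)(1 - sigma(v)) = Z(1 - Z)
  s <- Z * (1 - Z) * beta * delta

  # derivatives [4] and [5], plus the corresponding bias terms
  dR_dbeta   <- delta * Z
  dR_dbeta0  <- delta
  dR_dalpha  <- outer(s, x)          # one row of derivatives per hidden unit
  dR_dalpha0 <- s

  # gradient-descent updates [6] and [7]
  list(alpha0 = alpha0 - gamma * dR_dalpha0,
       alpha  = alpha  - gamma * dR_dalpha,
       beta0  = beta0  - gamma * dR_dbeta0,
       beta   = beta   - gamma * dR_dbeta)
}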

Algorithm:

Forward pass: use the initial (or current) weight guesses to compute f(X) via [1] and the output-layer errors δ implied by the error function [2].

Backward pass: 'back propagate' the output-layer errors via the back propagation equation [3] to obtain s. Both sets of errors, δ and s, are used to form the derivative terms in [4] and [5], which are then used to update the weight estimates via the gradient-descent equations [6] and [7].
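
Putting the two passes together, a naive batch gradient-descent training loop in R might look like the sketch below, again assuming sigmoid hidden units, a single identity output, and squared error. The helper name train_nn, the random initial weights, and the fixed number of iterations are illustrative choices, not a prescription:

train_nn <- function(X, y, M = 3, gamma = 0.01, n_iter = 100) {
  sigma <- function(v) 1 / (1 + exp(-v))
  p <- ncol(X)

  # initial weight guesses (small random values)
  alpha  <- matrix(rnorm(M * p, sd = 0.1), nrow = M)
  alpha0 <- rnorm(M, sd = 0.1)
  beta   <- rnorm(M, sd = 0.1)
  beta0  <- 0

  errors <- numeric(n_iter)

  for (r in seq_len(n_iter)) {
    # accumulate gradients over the whole training set
    g_alpha  <- matrix(0, M, p)
    g_alpha0 <- numeric(M)
    g_beta   <- numeric(M)
    g_beta0  <- 0

    for (i in seq_len(nrow(X))) {
      x <- X[i, ]
      Z <- as.vector(sigma(alpha0 + alpha %*% x))   # forward pass, equation [1]
      f <- beta0 + sum(beta * Z)
      delta <- -2 * (y[i] - f)                      # output-layer error from [2]
      s <- Z * (1 - Z) * beta * delta               # back propagation, equation [3]
      g_beta   <- g_beta   + delta * Z              # [4]
      g_beta0  <- g_beta0  + delta
      g_alpha  <- g_alpha  + outer(s, x)            # [5]
      g_alpha0 <- g_alpha0 + s
      errors[r] <- errors[r] + (y[i] - f)^2         # running value of R(theta)
    }

    # gradient-descent updates, equations [6] and [7]
    beta   <- beta   - gamma * g_beta
    beta0  <- beta0  - gamma * g_beta0
    alpha  <- alpha  - gamma * g_alpha
    alpha0 <- alpha0 - gamma * g_alpha0
  }

  list(alpha = alpha, alpha0 = alpha0, beta = beta, beta0 = beta0, errors = errors)
}

# e.g. fit <- train_nn(X, y); plot(fit$errors, type = "l") to watch R(theta) fall with each iteration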

In Predictive Modeling with SAS Enterprise Miner by Sarma, the following basic description of back propagation is given:

Specify an error function E.

1) First iteration: set initial weights and use them to evaluate E.
2) Second iteration: change the weights by a small amount such that the error is reduced.
- Repeat until convergence.

As Sarma explains, each iteration produces a new set of weights, so if it takes 100 iterations to converge, 100 possible models are specified, giving 100 sets of weights. The best iteration can then be chosen by calculating E on validation data.
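
Assuming a training loop like the sketch above that records E on validation data after every iteration, choosing the best iteration is just a matter of finding the minimum. The function name and the toy error values here are purely illustrative:

# choose the iteration whose weights give the smallest validation error
pick_best_iteration <- function(val_errors) {
  best <- which.min(val_errors)
  cat(sprintf("best iteration: %d (validation E = %.3f)\n", best, val_errors[best]))
  best
}

# toy illustration: validation error falls, then rises as the model starts to overfit
pick_best_iteration(c(5.0, 3.2, 2.1, 1.8, 1.9, 2.4))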
