Truncated BPTT in Deep Learning
Mitesh M. Khapra
Mitesh M. Khapra CS7015 (Deep Learning) : Lecture 14
Module 14.1: Sequence Learning Problems
In feedforward and convolutional neural networks, the size of the input was always fixed.
For example, we fed fixed-size (32 × 32) images to convolutional neural networks for image classification.
[Figure: a 32 × 32 image fed to a convolutional network that classifies it as apple, bus, car, ...]
Similarly, in word2vec, we fed a fixed window (k) of words to the network.
[Figure: word2vec taking the context words "he" and "sat" (through W_context) and predicting P(chair | sat, he), P(man | sat, he), P(on | sat, he), P(he | sat, he)]
Further, each input to the network was independent of the previous or future inputs.
For example, the computations, outputs and decisions for two successive images are completely independent of each other.
In many applications the input is not of a fixed size.
Further, successive inputs may not be independent of each other.
For example, consider the task of auto-completion.
Given the first character ‘d’, you want to predict the next character ‘e’, and so on.
[Figure: the inputs d, e, e, p fed one character per timestep, with the predicted characters e, e, p, <stop>]
Notice a few things.
First, successive inputs are no longer independent (while predicting ‘e’, you would want to know what the previous input was in addition to the current input).
Second, the length of the inputs and the number of predictions you need to make are not fixed (for example, “learn”, “deep”, “machine” have different numbers of characters).
Third, each network (the orange-blue-green structure in the figure) is performing the same task (input: character, output: character).
These are known as sequence learning problems.
We need to look at a sequence of (dependent) inputs and produce an output (or outputs).
Each input corresponds to one time step.
Let us look at some more examples of such problems.
Consider the task of predicting the part-of-speech tag (noun, adverb, adjective, verb) of each word in a sentence.
[Figure: the sentence "man is a social animal" tagged noun, verb, article, adjective, noun]
Once we see an adjective (social), we are almost sure that the next word should be a noun (animal).
Thus the current output (noun) depends on the current input as well as the previous input.
Further, the size of the input is not fixed (sentences could have an arbitrary number of words).
Notice that here we are interested in producing an output at each time step.
Each network is performing the same task (input: word, output: tag).
Sometimes we may not be interested in producing an output at every stage.
Instead, we would look at the full sequence and then produce an output.
For example, consider the task of predicting the polarity of a movie review.
[Figure: the review "The movie was boring and long" read one word per timestep; the outputs after the first five words are marked "don’t care" and only the final output (+/−) is used]
The prediction clearly does not depend only on the last word but also on some words which appear before it.
Here again we could think that the network is performing the same task at each step (input: word, output: +/−), but it’s just that we don’t care about intermediate outputs.
Sequences could be composed of anything (not just words).
For example, a video could be treated as a sequence of images.
We may want to look at the entire sequence and detect the activity being performed.
[Figure: a sequence of video frames of the activity "Surya Namaskar"]
Module 14.2: Recurrent Neural Networks
How do we model such tasks involving sequences?
Wishlist:
Account for dependence between inputs
Account for a variable number of inputs
Make sure that the function executed at each time step is the same
We will focus on each of these to arrive at a model for dealing with sequences.
What is the function being executed at each time step?
$$s_i = \sigma(U x_i + b)$$
$$y_i = O(V s_i + c)$$
where $i$ is the timestep.
[Figure: two copies of the network; each $x_i$ is mapped through $U$ to $s_i$ and through $V$ to $y_i$, for $i = 1, 2$]
Since we want the same function to be executed at each timestep, we should share the same network (i.e., the same parameters at each timestep).
This parameter sharing also ensures that the network becomes agnostic to the length (size) of the input.
Since we are simply going to compute the same function (with the same parameters) at each timestep, the number of timesteps doesn't matter.
We just create multiple copies of the network and execute them at each timestep.
[Figure: copies of the network for timesteps $1, \ldots, n$, all sharing $U$ and $V$]
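A minimal sketch of this idea in Python, assuming plain lists for vectors and matrices, a logistic $\sigma$, and softmax for the output function $O$ (the slide leaves $O$ unspecified). The helper names `init_params`, `step` and `run` are hypothetical. One fixed set of parameters $U, V, b, c$ is applied at every timestep, so sequences of any length can be processed; without a recurrent connection, however, each output depends only on its own input.

```python
import math
import random

def matvec(M, v):
    """Matrix-vector product for lists of lists."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(z):
    m = max(z)  # subtract the max for numerical stability
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def init_params(n, d, k, seed=0):
    """Hypothetical helper: U is d x n, V is k x d, b has d entries, c has k."""
    rng = random.Random(seed)
    U = [[rng.gauss(0.0, 0.5) for _ in range(n)] for _ in range(d)]
    V = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(k)]
    return U, V, [0.0] * d, [0.0] * k

def step(x, U, V, b, c):
    """One timestep: s_i = sigma(U x_i + b), y_i = softmax(V s_i + c)."""
    s = [sigmoid(a + bi) for a, bi in zip(matvec(U, x), b)]
    return softmax([o + ci for o, ci in zip(matvec(V, s), c)])

def run(xs, params):
    """Apply the same step (same parameters) at every timestep."""
    return [step(x, *params) for x in xs]

params = init_params(n=3, d=4, k=2)
ys_short = run([[1, 0, 0], [0, 1, 0]], params)
ys_long = run([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [0, 1, 1]], params)
```

Because `step` never sees earlier inputs, `ys_short[0]` and `ys_long[0]` are identical: the parameter sharing handles variable length, but the dependence between inputs is still missing.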
How do we account for dependence between inputs?
Let us first see an infeasible way of doing this.
At each timestep we will feed all the previous inputs to the network.
[Figure: at timestep $i$ the network receives $x_1, \ldots, x_i$ and produces $y_i$, for $i = 1, \ldots, 4$]
Is this okay?
No, it violates the other two items on our wishlist.
How? Let us see.
First, the function being computed at each timestep is now different:
$$y_1 = f_1(x_1), \quad y_2 = f_2(x_1, x_2), \quad y_3 = f_3(x_1, x_2, x_3)$$
Second, the network is now sensitive to the length of the sequence: a sequence of length 10 would need $f_1, \ldots, f_{10}$, whereas a sequence of length 100 would need $f_1, \ldots, f_{100}$.
The solution is to add a recurrent connection in the network:
$$s_i = \sigma(U x_i + W s_{i-1} + b)$$
$$y_i = O(V s_i + c)$$
or
$$y_i = f(x_i, s_{i-1}, W, U, V, b, c)$$
$s_i$ is the state of the network at timestep $i$.
The parameters are $W, U, V, c, b$, which are shared across timesteps.
The same network (and parameters) can be used to compute $y_1, y_2, \ldots, y_{10}$ or $y_{100}$.
[Figure: the unrolled network; each $x_i$ enters through $U$, consecutive states are connected through $W$, and each $s_i$ produces $y_i$ through $V$]
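A minimal sketch of the recurrent forward pass. Beyond the equations above, it assumes the column-vector convention ($U$ is $d \times n$, $W$ is $d \times d$, $V$ is $k \times d$), a logistic $\sigma$, softmax for $O$, and $s_0 = 0$; the function name `rnn_forward` is hypothetical. The same five parameters are used for a sequence of length 3 and for one of length 100.

```python
import math
import random

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def rnn_forward(xs, U, W, V, b, c):
    """s_i = sigma(U x_i + W s_{i-1} + b), y_i = softmax(V s_i + c), with s_0 = 0."""
    s = [0.0] * len(W)
    ys = []
    for x in xs:
        pre = [u + w + bi for u, w, bi in zip(matvec(U, x), matvec(W, s), b)]
        s = [sigmoid(a) for a in pre]
        ys.append(softmax([o + ci for o, ci in zip(matvec(V, s), c)]))
    return ys

rng = random.Random(1)
n, d, k = 3, 4, 2
U = [[rng.gauss(0.0, 0.5) for _ in range(n)] for _ in range(d)]
W = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(d)]
V = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(k)]
b, c = [0.0] * d, [0.0] * k

seq = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
ys = rnn_forward(seq, U, W, V, b, c)
# Change only x_1: through the recurrence, y_3 changes as well.
ys_changed = rnn_forward([[0, 0, 1]] + seq[1:], U, W, V, b, c)
# The same parameters handle a sequence of length 100.
ys_100 = rnn_forward(seq * 33 + seq[:1], U, W, V, b, c)
```

Unlike the per-timestep network without recurrence, $y_3$ now reacts to a change in $x_1$, while $y_1$ still depends only on $x_1$.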
This can be represented more compactly:
[Figure: a single cell mapping $x_i$ to $s_i$ to $y_i$, with a recurrent loop $W$ on $s_i$]
Let us revisit the sequence learning problems that we saw earlier.
We now have recurrent connections between time steps, which account for dependence between inputs.
[Figure: the earlier examples (auto-completion of "deep", part-of-speech tagging of "man is a social animal", activity recognition for "Surya Namaskar", and the polarity of "The movie was boring and long") redrawn with recurrent connections]
Module 14.3: Backpropagation through time
Before proceeding, let us look at the dimensions of the parameters carefully:
$x_i \in \mathbb{R}^n$ ($n$-dimensional input)
$s_i \in \mathbb{R}^d$ ($d$-dimensional state)
$y_i \in \mathbb{R}^k$ (say $k$ classes)
$U \in \mathbb{R}^{d \times n}$
$V \in \mathbb{R}^{k \times d}$
$W \in \mathbb{R}^{d \times d}$
These are the shapes for which $U x_i$, $W s_{i-1}$ and $V s_i$ are defined when $x_i$ and $s_i$ are column vectors.
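As a quick consistency check (a worked example, not from the slides), the shapes compose as follows for column vectors, and the number of parameters does not depend on the number of timesteps $T$. With hypothetical sizes $n = 4$, $d = 3$, $k = 4$, the count is $12 + 9 + 12 + 3 + 4 = 40$.

```latex
\underbrace{U}_{d\times n}\,\underbrace{x_i}_{n\times 1}
+ \underbrace{W}_{d\times d}\,\underbrace{s_{i-1}}_{d\times 1}
+ \underbrace{b}_{d\times 1} \in \mathbb{R}^{d},
\qquad
\underbrace{V}_{k\times d}\,\underbrace{s_i}_{d\times 1}
+ \underbrace{c}_{k\times 1} \in \mathbb{R}^{k}

\#\text{parameters} = dn + d^2 + kd + d + k \quad (\text{independent of } T)
```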
How do we train this network? (Ans: using backpropagation)
Let us understand this with a concrete example.
Suppose we consider our task of auto-completion (predicting the next character).
For simplicity, we assume that there are only 4 characters in our vocabulary (d, e, p, <stop>).
At each timestep we want to predict one of these 4 characters.
What is a suitable output function for this task? (softmax)
What is a suitable loss function for this task? (cross entropy)
[Figure: the RNN unrolled over the inputs d, e, e, p, with targets e, e, p, <stop>]
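A small sketch of how the training pairs, the softmax output and the cross-entropy loss could be set up for this toy task. Encoding characters as one-hot vectors over the vocabulary is our assumption (the slides do not fix an input encoding), and the names `VOCAB`, `one_hot`, `make_pairs` and `cross_entropy` are hypothetical.

```python
import math

VOCAB = ["d", "e", "p", "<stop>"]

def one_hot(ch):
    """Assumed input encoding: a 1 at the index of ch in VOCAB."""
    return [1.0 if v == ch else 0.0 for v in VOCAB]

def make_pairs(word):
    """Inputs are the characters of word; each target is the next character, ending in <stop>."""
    chars = list(word)
    return chars, chars[1:] + ["<stop>"]

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def cross_entropy(probs, target):
    """-log of the probability assigned to the true character."""
    return -math.log(probs[VOCAB.index(target)])

inputs, targets = make_pairs("deep")
xs = [one_hot(ch) for ch in inputs]
uniform = softmax([0.0, 0.0, 0.0, 0.0])  # equal scores give equal probabilities
```

With a uniform prediction over the 4 characters, the per-step loss is $\log 4 \approx 1.386$, a handy reference point for an untrained model.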
Suppose we initialize $U, V, W$ randomly and the network predicts the probabilities as shown, while the true probabilities are as shown:

        t = 1            t = 2            t = 3            t = 4
        Predicted True   Predicted True   Predicted True   Predicted True
d       0.2       0      0.2       0      0.2       0      0.2       0
e       0.7       1      0.7       1      0.1       0      0.1       0
p       0.1       0      0.1       0      0.7       1      0.7       0
<stop>  0.1       0      0.1       0      0.1       0      0.1       1

[Figure: the unrolled network over the inputs d, e, e, p, with losses $L_1(\theta), \ldots, L_4(\theta)$ on top of the outputs $y_1, \ldots, y_4$]
We need to answer two questions:
What is the total loss made by the model?
How do we backpropagate this loss and update the parameters ($\theta = \{U, V, W, b, c\}$) of the network?
The total loss is simply the sum of the loss over all timesteps:
$$L(\theta) = \sum_{t=1}^{T} L_t(\theta)$$
$$L_t(\theta) = -\log(y_{tc})$$
where $y_{tc}$ is the predicted probability of the true character at timestep $t$, and $T$ is the number of timesteps.
[Figure: the same unrolled network and table of predicted and true probabilities]
For backpropagation we need to compute the gradients w.r.t. $W, U, V, b, c$.
Let us see how to do that.
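Plugging the predicted and true distributions from the table of the running example (targets e, e, p, <stop>) into these formulas gives a concrete number. Only the probability assigned to the true character enters each $L_t(\theta)$. A minimal sketch, assuming natural logarithms; `step_losses` is a hypothetical name.

```python
import math

# Predicted distributions, one per timestep, as listed in the table.
predicted = [
    {"d": 0.2, "e": 0.7, "p": 0.1, "<stop>": 0.1},
    {"d": 0.2, "e": 0.7, "p": 0.1, "<stop>": 0.1},
    {"d": 0.2, "e": 0.1, "p": 0.7, "<stop>": 0.1},
    {"d": 0.2, "e": 0.1, "p": 0.7, "<stop>": 0.1},
]
# True next characters for the input "d e e p".
targets = ["e", "e", "p", "<stop>"]

def step_losses(predicted, targets):
    """L_t = -log(y_tc), where c is the true character at timestep t."""
    return [-math.log(y[c]) for y, c in zip(predicted, targets)]

losses = step_losses(predicted, targets)
total = sum(losses)  # 3 * (-log 0.7) + (-log 0.1), roughly 3.373
```

The last timestep dominates the total, since the model gives the true character <stop> a probability of only 0.1.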
Let us consider $\frac{\partial L(\theta)}{\partial V}$ ($V$ is a matrix, so ideally we should write $\nabla_V L(\theta)$):
$$\frac{\partial L(\theta)}{\partial V} = \sum_{t=1}^{T} \frac{\partial L_t(\theta)}{\partial V}$$
Each term in the summation is simply the derivative of the loss w.r.t. the weights in the output layer.
We have already seen how to do this when we studied backpropagation.
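For a softmax output with cross-entropy loss, the standard output-layer result is $\partial L_t / \partial (V s_t + c) = \hat{y}_t - e_{c_t}$, where $e_{c_t}$ is the one-hot vector of the true character, so $\partial L_t / \partial V = (\hat{y}_t - e_{c_t})\, s_t^\top$ under the column-vector convention ($V \in \mathbb{R}^{k \times d}$). Since $V$ does not affect the states, they can be held fixed. The sketch below uses made-up states and hypothetical names, sums these terms over $t$, and checks the result against central finite differences.

```python
import math
import random

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def logits(V, c, s):
    """V s + c for one timestep."""
    return [sum(v * x for v, x in zip(row, s)) + ci for row, ci in zip(V, c)]

def loss(V, c, states, targets):
    """L = sum_t -log softmax(V s_t + c)[target_t], with the states s_t held fixed."""
    return -sum(math.log(softmax(logits(V, c, s))[t]) for s, t in zip(states, targets))

def grad_V(V, c, states, targets):
    """dL/dV = sum_t (y_hat_t - onehot(target_t)) s_t^T."""
    G = [[0.0] * len(V[0]) for _ in V]
    for s, tgt in zip(states, targets):
        y = softmax(logits(V, c, s))
        for p in range(len(V)):
            delta = y[p] - (1.0 if p == tgt else 0.0)
            for q in range(len(s)):
                G[p][q] += delta * s[q]
    return G

def numeric_grad_V(V, c, states, targets, eps=1e-6):
    """Central finite differences, one entry of V at a time."""
    G = [[0.0] * len(V[0]) for _ in V]
    for p in range(len(V)):
        for q in range(len(V[0])):
            Vp, Vm = [row[:] for row in V], [row[:] for row in V]
            Vp[p][q] += eps
            Vm[p][q] -= eps
            G[p][q] = (loss(Vp, c, states, targets) - loss(Vm, c, states, targets)) / (2 * eps)
    return G

rng = random.Random(0)
d, k = 3, 4
V = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(k)]
c = [0.0] * k
states = [[rng.random() for _ in range(d)] for _ in range(4)]  # made-up s_1..s_4
targets = [1, 1, 2, 3]  # indices of e, e, p, <stop> in (d, e, p, <stop>)
G = grad_V(V, c, states, targets)
G_num = numeric_grad_V(V, c, states, targets)
```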
Let us consider the derivative $\frac{\partial L(\theta)}{\partial W}$:
$$\frac{\partial L(\theta)}{\partial W} = \sum_{t=1}^{T} \frac{\partial L_t(\theta)}{\partial W}$$
By the chain rule of derivatives, we know that $\frac{\partial L_t(\theta)}{\partial W}$ is obtained by summing gradients along all the paths from $L_t(\theta)$ to $W$.
What are the paths connecting $L_t(\theta)$ to $W$?
Let us see this by considering $L_4(\theta)$.
$L_4(\theta)$ depends on $s_4$
$s_4$ in turn depends on $s_3$ and $W$
$s_3$ in turn depends on $s_2$ and $W$
$s_2$ in turn depends on $s_1$ and $W$
$s_1$ in turn depends on $s_0$ and $W$, where $s_0$ is a constant starting state.
[Figure: the chain $s_0 \to s_1 \to s_2 \to s_3 \to s_4 \to L_4(\theta)$, with $W$ entering every transition]
What we have here is an ordered network.
In an ordered network, each state variable is computed one at a time in a specified order (first $s_1$, then $s_2$, and so on).
Now we have
$$\frac{\partial L_4(\theta)}{\partial W} = \frac{\partial L_4(\theta)}{\partial s_4}\,\frac{\partial s_4}{\partial W}$$
We have already seen how to compute $\frac{\partial L_4(\theta)}{\partial s_4}$ when we studied backprop.
But how do we compute $\frac{\partial s_4}{\partial W}$?
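Recall that
$$s_4 = \sigma(U x_4 + W s_3 + b)$$
so $s_4$ depends on $W$ both directly (explicitly) and through $s_3$, which itself depends on $W$ (implicitly). We write $\frac{\partial^+ s_4}{\partial W}$ for the explicit part, i.e., the derivative taken while treating $s_3$ as a constant.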
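$$\begin{aligned}
\frac{\partial s_4}{\partial W}
&= \underbrace{\frac{\partial^+ s_4}{\partial W}}_{\text{explicit}} + \underbrace{\frac{\partial s_4}{\partial s_3}\frac{\partial s_3}{\partial W}}_{\text{implicit}} \\
&= \frac{\partial^+ s_4}{\partial W} + \frac{\partial s_4}{\partial s_3}\Big[\underbrace{\frac{\partial^+ s_3}{\partial W}}_{\text{explicit}} + \underbrace{\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial W}}_{\text{implicit}}\Big] \\
&= \frac{\partial^+ s_4}{\partial W} + \frac{\partial s_4}{\partial s_3}\frac{\partial^+ s_3}{\partial W} + \frac{\partial s_4}{\partial s_3}\frac{\partial s_3}{\partial s_2}\Big[\frac{\partial^+ s_2}{\partial W} + \frac{\partial s_2}{\partial s_1}\frac{\partial s_1}{\partial W}\Big] \\
&= \frac{\partial^+ s_4}{\partial W} + \frac{\partial s_4}{\partial s_3}\frac{\partial^+ s_3}{\partial W} + \frac{\partial s_4}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial^+ s_2}{\partial W} + \frac{\partial s_4}{\partial s_3}\frac{\partial s_3}{\partial s_2}\frac{\partial s_2}{\partial s_1}\Big[\frac{\partial^+ s_1}{\partial W}\Big]
\end{aligned}$$
The last bracket contains only the explicit term because $s_0$ is a constant, so $\frac{\partial s_1}{\partial W} = \frac{\partial^+ s_1}{\partial W}$.
For simplicity, we will short-circuit some of the paths, writing for example $\frac{\partial s_4}{\partial s_3}\frac{\partial s_3}{\partial s_2} = \frac{\partial s_4}{\partial s_2}$ (and $\frac{\partial s_4}{\partial s_4} = I$):
$$\frac{\partial s_4}{\partial W} = \frac{\partial s_4}{\partial s_4}\frac{\partial^+ s_4}{\partial W} + \frac{\partial s_4}{\partial s_3}\frac{\partial^+ s_3}{\partial W} + \frac{\partial s_4}{\partial s_2}\frac{\partial^+ s_2}{\partial W} + \frac{\partial s_4}{\partial s_1}\frac{\partial^+ s_1}{\partial W} = \sum_{k=1}^{4}\frac{\partial s_4}{\partial s_k}\frac{\partial^+ s_k}{\partial W}$$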
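Finally we have
$$\frac{\partial L_4(\theta)}{\partial W} = \frac{\partial L_4(\theta)}{\partial s_4}\,\frac{\partial s_4}{\partial W}, \qquad \frac{\partial s_4}{\partial W} = \sum_{k=1}^{4}\frac{\partial s_4}{\partial s_k}\frac{\partial^+ s_k}{\partial W}$$
$$\therefore\; \frac{\partial L_t(\theta)}{\partial W} = \frac{\partial L_t(\theta)}{\partial s_t}\sum_{k=1}^{t}\frac{\partial s_t}{\partial s_k}\frac{\partial^+ s_k}{\partial W}$$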
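The formula can be turned into code by walking $k$ from $t$ down to $1$ while carrying the row vector $\frac{\partial L_t}{\partial s_t}\frac{\partial s_t}{\partial s_k}$: at each step we add its contraction with the explicit term $\frac{\partial^+ s_k}{\partial W}$, then multiply by $\frac{\partial s_k}{\partial s_{k-1}}$. This sketch assumes a logistic $\sigma$ (so $\sigma'(a_k) = s_k(1 - s_k)$), softmax outputs with $\frac{\partial L_t}{\partial s_t} = (\hat{y}_t - e_{c_t})^\top V$, one-hot inputs for the toy word "deep", $s_0 = 0$, and the element-wise explicit term $\frac{\partial^+ s_{kp}}{\partial W_{qr}} = \sigma'(a_{kp})\, s_{k-1,r}$ if $p = q$ (0 otherwise) worked out in Module 14.5. All names are hypothetical; the result is checked against finite differences of the total loss.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def softmax(z):
    m = max(z)
    e = [math.exp(v - m) for v in z]
    total = sum(e)
    return [v / total for v in e]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def forward(W, U, V, b, c, xs):
    """States s[0..T] with s[0] = s_0 = 0, and softmax outputs y[0..T-1] (y[t-1] is y_t)."""
    s, y = [[0.0] * len(W)], []
    for x in xs:
        a = [ws + ux + bi for ws, ux, bi in zip(matvec(W, s[-1]), matvec(U, x), b)]
        s.append([sigmoid(v) for v in a])
        y.append(softmax([o + ci for o, ci in zip(matvec(V, s[-1]), c)]))
    return s, y

def total_loss(W, U, V, b, c, xs, targets):
    _, y = forward(W, U, V, b, c, xs)
    return -sum(math.log(yt[tgt]) for yt, tgt in zip(y, targets))

def bptt_grad_W(W, U, V, b, c, xs, targets):
    """dL/dW = sum_t dL_t/ds_t * sum_{k=1..t} ds_t/ds_k * d+s_k/dW."""
    d = len(W)
    s, y = forward(W, U, V, b, c, xs)
    G = [[0.0] * d for _ in range(d)]
    for t in range(1, len(xs) + 1):
        # dL_t/ds_t = (y_t - onehot(target_t))^T V, a row vector of length d.
        err = [yp - (1.0 if p == targets[t - 1] else 0.0) for p, yp in enumerate(y[t - 1])]
        g = [sum(err[p] * V[p][q] for p in range(len(V))) for q in range(d)]
        for k in range(t, 0, -1):
            # g holds dL_t/ds_t * ds_t/ds_k; scale by sigma'(a_k) = s_k (1 - s_k).
            gs = [g[p] * s[k][p] * (1.0 - s[k][p]) for p in range(d)]
            for q in range(d):
                for r in range(d):
                    G[q][r] += gs[q] * s[k - 1][r]  # contraction with d+s_k/dW
            # One step back in time: multiply by ds_k/ds_{k-1} = diag(sigma'(a_k)) W.
            g = [sum(gs[p] * W[p][r] for p in range(d)) for r in range(d)]
    return G

def numeric_grad_W(W, U, V, b, c, xs, targets, eps=1e-6):
    """Central finite differences of the total loss, one entry of W at a time."""
    d = len(W)
    G = [[0.0] * d for _ in range(d)]
    for q in range(d):
        for r in range(d):
            Wp, Wm = [row[:] for row in W], [row[:] for row in W]
            Wp[q][r] += eps
            Wm[q][r] -= eps
            G[q][r] = (total_loss(Wp, U, V, b, c, xs, targets)
                       - total_loss(Wm, U, V, b, c, xs, targets)) / (2 * eps)
    return G

vocab = ["d", "e", "p", "<stop>"]
xs = [[1.0 if v == ch else 0.0 for v in vocab] for ch in "deep"]  # one-hot inputs
targets = [vocab.index(ch) for ch in ["e", "e", "p", "<stop>"]]
rng = random.Random(0)
n, d, k = 4, 3, 4
W = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(d)]
U = [[rng.gauss(0.0, 0.5) for _ in range(n)] for _ in range(d)]
V = [[rng.gauss(0.0, 0.5) for _ in range(d)] for _ in range(k)]
b, c = [0.0] * d, [0.0] * k
G = bptt_grad_W(W, U, V, b, c, xs, targets)
G_num = numeric_grad_W(W, U, V, b, c, xs, targets)
```

Carrying the row vector backwards avoids ever forming $\frac{\partial s_t}{\partial s_k}$ or the $d \times d \times d$ tensor explicitly, at the cost of an inner loop whose length grows with $t$.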
Module 14.4: The problem of Exploding and Vanishing Gradients
We will now focus on $\frac{\partial s_t}{\partial s_k}$ and highlight an important problem in training RNNs using backpropagation through time (BPTT).
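We are interested in $\frac{\partial s_j}{\partial s_{j-1}}$.
Let $a_j = [a_{j1}, a_{j2}, \ldots, a_{jd}]$ be the pre-activation at timestep $j$ (i.e., $a_j = U x_j + W s_{j-1} + b$), so that $s_j = \sigma(a_j)$ elementwise. Then
$$\frac{\partial s_j}{\partial s_{j-1}} = \frac{\partial s_j}{\partial a_j}\,\frac{\partial a_j}{\partial s_{j-1}}$$
Since $s_{jp} = \sigma(a_{jp})$ depends only on $a_{jp}$, the first factor is diagonal:
$$\frac{\partial s_j}{\partial a_j} = \begin{bmatrix} \frac{\partial s_{j1}}{\partial a_{j1}} & 0 & \cdots & 0 \\ 0 & \frac{\partial s_{j2}}{\partial a_{j2}} & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \frac{\partial s_{jd}}{\partial a_{jd}} \end{bmatrix} = \begin{bmatrix} \sigma'(a_{j1}) & 0 & \cdots & 0 \\ 0 & \sigma'(a_{j2}) & \cdots & 0 \\ \vdots & \vdots & \ddots & \vdots \\ 0 & 0 & \cdots & \sigma'(a_{jd}) \end{bmatrix} = \mathrm{diag}(\sigma'(a_j))$$
and $\frac{\partial a_j}{\partial s_{j-1}} = W$, so
$$\frac{\partial s_j}{\partial s_{j-1}} = \mathrm{diag}(\sigma'(a_j))\,W$$
We are interested in the magnitude of $\frac{\partial s_j}{\partial s_{j-1}}$: if it is small (large), then $\frac{\partial s_t}{\partial s_k}$, and hence $\frac{\partial L_t}{\partial W}$, will vanish (explode).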
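A quick numerical check of $\frac{\partial s_j}{\partial s_{j-1}} = \mathrm{diag}(\sigma'(a_j))\,W$, assuming a logistic $\sigma$ and standing in a fixed vector `drive` (a hypothetical name) for $U x_j + b$. Entry $(p, r)$ of the Jacobian is $\partial s_{jp} / \partial s_{j-1,r}$; the analytic form is compared with central finite differences.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def next_state(W, s_prev, drive):
    """s_j = sigma(W s_{j-1} + drive), where drive stands for U x_j + b."""
    d = len(W)
    return [sigmoid(sum(W[p][r] * s_prev[r] for r in range(d)) + drive[p]) for p in range(d)]

def analytic_jacobian(W, s_prev, drive):
    """diag(sigma'(a_j)) W, using sigma'(a) = sigma(a) (1 - sigma(a))."""
    s = next_state(W, s_prev, drive)
    return [[s[p] * (1.0 - s[p]) * W[p][r] for r in range(len(W))] for p in range(len(W))]

def numeric_jacobian(W, s_prev, drive, eps=1e-6):
    """Perturb one component of s_{j-1} at a time (central differences)."""
    d = len(W)
    J = [[0.0] * d for _ in range(d)]
    for r in range(d):
        plus, minus = s_prev[:], s_prev[:]
        plus[r] += eps
        minus[r] -= eps
        sp, sm = next_state(W, plus, drive), next_state(W, minus, drive)
        for p in range(d):
            J[p][r] = (sp[p] - sm[p]) / (2 * eps)
    return J

rng = random.Random(0)
d = 3
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d)]
s_prev = [rng.random() for _ in range(d)]
drive = [rng.uniform(-1.0, 1.0) for _ in range(d)]
J = analytic_jacobian(W, s_prev, drive)
J_num = numeric_jacobian(W, s_prev, drive)
```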
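$$\frac{\partial s_j}{\partial s_{j-1}} = \mathrm{diag}(\sigma'(a_j))\,W$$
$$\left\|\frac{\partial s_j}{\partial s_{j-1}}\right\| \le \left\|\mathrm{diag}(\sigma'(a_j))\right\|\,\|W\|$$
For the usual squashing nonlinearities (logistic sigmoid, tanh), the derivative $\sigma'(a_j)$ is bounded:
$$\sigma'(a_j) \le \frac{1}{4} = \gamma \quad [\text{if } \sigma \text{ is logistic}]$$
$$\sigma'(a_j) \le 1 = \gamma \quad [\text{if } \sigma \text{ is tanh}]$$
Hence, writing $\lambda$ for an upper bound on $\|W\|$,
$$\left\|\frac{\partial s_j}{\partial s_{j-1}}\right\| \le \gamma\,\|W\| \le \gamma\lambda$$
Since
$$\frac{\partial s_t}{\partial s_k} = \prod_{j=k+1}^{t}\frac{\partial s_j}{\partial s_{j-1}},$$
we get
$$\left\|\frac{\partial s_t}{\partial s_k}\right\| \le \prod_{j=k+1}^{t}\gamma\lambda = (\gamma\lambda)^{t-k}$$
If $\gamma\lambda < 1$, the gradient will vanish.
If $\gamma\lambda > 1$, the gradient could explode.
This is known as the problem of vanishing/exploding gradients.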
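The bound can be checked numerically. This sketch assumes a logistic $\sigma$ ($\gamma = 1/4$), measures matrices with the induced $\infty$-norm (maximum absolute row sum; it is submultiplicative, so the argument above applies to it), takes $\lambda = \|W\|_\infty$, and stands in a fixed random vector $u_j$ for $U x_j + b$. With entries of $W$ in $[-0.5, 0.5]$ and $d = 4$, $\lambda \le 2$, so $\gamma\lambda \le 0.5$ and over 30 factors the bound falls below $10^{-8}$. All names are hypothetical.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matmul(A, B):
    return [[sum(A[i][m] * B[m][j] for m in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]

def inf_norm(M):
    """Induced infinity-norm: the maximum absolute row sum."""
    return max(sum(abs(v) for v in row) for row in M)

def step_jacobians(W, drives):
    """Run s_j = sigma(W s_{j-1} + u_j) from s_0 = 0; return J_j = diag(sigma'(a_j)) W for each j."""
    d = len(W)
    s, jacs = [0.0] * d, []
    for u in drives:
        s = [sigmoid(sum(W[p][r] * s[r] for r in range(d)) + u[p]) for p in range(d)]
        jacs.append([[s[p] * (1.0 - s[p]) * W[p][r] for r in range(d)] for p in range(d)])
    return jacs

def ds_t_ds_k(jacs, k, t):
    """ds_t/ds_k = J_t J_{t-1} ... J_{k+1}; jacs[j - 1] holds J_j."""
    d = len(jacs[0])
    P = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    for j in range(k + 1, t + 1):
        P = matmul(jacs[j - 1], P)
    return P

rng = random.Random(0)
d, k, t = 4, 1, 31
W = [[rng.uniform(-0.5, 0.5) for _ in range(d)] for _ in range(d)]
drives = [[rng.uniform(-1.0, 1.0) for _ in range(d)] for _ in range(t)]
gamma, lam = 0.25, inf_norm(W)
actual = inf_norm(ds_t_ds_k(step_jacobians(W, drives), k, t))
bound = (gamma * lam) ** (t - k)
```

Rescaling $W$ so that $\gamma\lambda > 1$ makes the bound grow instead; as the slide says, the gradient then only *could* explode, because a growing upper bound says nothing definite about the actual product.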
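One simple way of avoiding this is to use truncated backpropagation, where we restrict the product to $\tau$ ($< t - k$) terms: when computing $\frac{\partial L_t}{\partial W}$, we keep only the terms with $t - k \le \tau$, so that no product $\frac{\partial s_t}{\partial s_k}$ has more than $\tau$ factors.
[Figure: the unrolled network over $x_1, \ldots, x_n$ with the loss $L_t$ at one timestep]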
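A minimal sketch of truncated BPTT for a scalar state (so every Jacobian is a number), assuming $s_k = \sigma(w\,s_{k-1} + u\,x_k + b)$ with a logistic $\sigma$ and $s_0 = 0$. Following the reading above, truncation keeps only the terms with $t - k \le \tau$ in $\sum_k \frac{\partial s_t}{\partial s_k}\frac{\partial^+ s_k}{\partial w}$; `tau=None` gives the full BPTT sum. Multiplying the result by $\partial L_t / \partial s_t$ gives $\partial L_t / \partial w$. Names and numbers are hypothetical.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def forward(w, u, b, xs):
    """Return [s_0, s_1, ..., s_T] for s_k = sigma(w s_{k-1} + u x_k + b), s_0 = 0."""
    s = [0.0]
    for x in xs:
        s.append(sigmoid(w * s[-1] + u * x + b))
    return s

def ds_t_dw(w, u, b, xs, t, tau=None):
    """Sum over k of (ds_t/ds_k) * (d+s_k/dw), keeping only k >= t - tau if tau is given."""
    s = forward(w, u, b, xs)
    k_min = 1 if tau is None else max(1, t - tau)
    total, carry = 0.0, 1.0  # carry = ds_t/ds_k, starting at k = t
    for k in range(t, k_min - 1, -1):
        sp = s[k] * (1.0 - s[k])        # sigma'(a_k)
        total += carry * sp * s[k - 1]  # explicit term d+s_k/dw = sigma'(a_k) s_{k-1}
        carry *= sp * w                 # ds_k/ds_{k-1} = sigma'(a_k) w
    return total

w, u, b = 1.5, 1.0, -0.5
xs = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 0.0, 0.0]
t = len(xs)
full = ds_t_dw(w, u, b, xs, t)
truncated = ds_t_dw(w, u, b, xs, t, tau=2)
eps = 1e-6
numeric = (forward(w + eps, u, b, xs)[t] - forward(w - eps, u, b, xs)[t]) / (2 * eps)
```

Here `full` matches the finite-difference derivative, while `truncated` drops the contributions from $k < t - 2$; that difference is the price paid for bounding the length of every product.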
Module 14.5: Some Gory Details
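$$\underbrace{\frac{\partial L_t(\theta)}{\partial W}}_{\in\, \mathbb{R}^{d \times d}} = \underbrace{\frac{\partial L_t(\theta)}{\partial s_t}}_{\in\, \mathbb{R}^{1 \times d}} \sum_{k=1}^{t} \underbrace{\frac{\partial s_t}{\partial s_k}}_{\in\, \mathbb{R}^{d \times d}} \underbrace{\frac{\partial^+ s_k}{\partial W}}_{\in\, \mathbb{R}^{d \times d \times d}}$$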
We just look at one element of this $\frac{\partial^+ s_k}{\partial W}$ tensor:
$\frac{\partial^+ s_{kp}}{\partial W_{qr}}$ is the $(p, q, r)$-th element of the 3d tensor.
$$a_k = W s_{k-1} + U x_k + b$$
$$s_k = \sigma(a_k)$$
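Ignoring the terms $U x_k + b$, which do not depend on $W$, we have $a_k = W s_{k-1}$, i.e.,
$$\begin{bmatrix} a_{k1} \\ a_{k2} \\ \vdots \\ a_{kp} \\ \vdots \\ a_{kd} \end{bmatrix} = \begin{bmatrix} W_{11} & W_{12} & \cdots & W_{1d} \\ W_{21} & W_{22} & \cdots & W_{2d} \\ \vdots & \vdots & \ddots & \vdots \\ W_{p1} & W_{p2} & \cdots & W_{pd} \\ \vdots & \vdots & \ddots & \vdots \\ W_{d1} & W_{d2} & \cdots & W_{dd} \end{bmatrix} \begin{bmatrix} s_{k-1,1} \\ s_{k-1,2} \\ \vdots \\ s_{k-1,p} \\ \vdots \\ s_{k-1,d} \end{bmatrix}$$
$$a_{kp} = \sum_{i=1}^{d} W_{pi}\, s_{k-1,i}, \qquad s_{kp} = \sigma(a_{kp})$$
$$\frac{\partial a_{kp}}{\partial W_{qr}} = \frac{\partial \sum_{i=1}^{d} W_{pi}\, s_{k-1,i}}{\partial W_{qr}} = \begin{cases} s_{k-1,r} & \text{if } p = q \\ 0 & \text{otherwise} \end{cases}$$
since only the term with $i = r$ contains $W_{qr}$, and only when $p = q$.
$$\frac{\partial^+ s_{kp}}{\partial W_{qr}} = \frac{\partial s_{kp}}{\partial a_{kp}}\,\frac{\partial a_{kp}}{\partial W_{qr}} = \sigma'(a_{kp})\,\frac{\partial a_{kp}}{\partial W_{qr}} = \begin{cases} \sigma'(a_{kp})\, s_{k-1,r} & \text{if } p = q \\ 0 & \text{otherwise} \end{cases}$$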
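The element-wise result can be checked numerically, and it also shows why the $d \times d \times d$ tensor never needs to be stored: contracting it with a row vector $g$ (for example $\frac{\partial L_t}{\partial s_t}\frac{\partial s_t}{\partial s_k}$) over the index $p$ keeps only the $p = q$ slice, giving the $d \times d$ matrix with entries $g_q\,\sigma'(a_{kq})\,s_{k-1,r}$. This sketch assumes a logistic $\sigma$, a fixed hypothetical `drive` standing in for $U x_k + b$, and holds $s_{k-1}$ constant, which is what the explicit derivative $\partial^+$ means.

```python
import math
import random

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def explicit_state(W, s_prev, drive):
    """s_k = sigma(W s_{k-1} + drive) with s_{k-1} and drive (for U x_k + b) held fixed."""
    d = len(W)
    return [sigmoid(sum(W[p][i] * s_prev[i] for i in range(d)) + drive[p]) for p in range(d)]

def explicit_tensor(W, s_prev, drive):
    """T[p][q][r] = d+s_kp/dW_qr = sigma'(a_kp) s_{k-1,r} if p == q, else 0."""
    d = len(W)
    s = explicit_state(W, s_prev, drive)
    return [[[s[p] * (1.0 - s[p]) * s_prev[r] if p == q else 0.0 for r in range(d)]
             for q in range(d)] for p in range(d)]

def numeric_tensor(W, s_prev, drive, eps=1e-6):
    """Central finite differences, perturbing one entry W_qr at a time."""
    d = len(W)
    T = [[[0.0] * d for _ in range(d)] for _ in range(d)]
    for q in range(d):
        for r in range(d):
            Wp, Wm = [row[:] for row in W], [row[:] for row in W]
            Wp[q][r] += eps
            Wm[q][r] -= eps
            sp, sm = explicit_state(Wp, s_prev, drive), explicit_state(Wm, s_prev, drive)
            for p in range(d):
                T[p][q][r] = (sp[p] - sm[p]) / (2 * eps)
    return T

rng = random.Random(0)
d = 3
W = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(d)]
s_prev = [rng.random() for _ in range(d)]
drive = [rng.uniform(-1.0, 1.0) for _ in range(d)]
T = explicit_tensor(W, s_prev, drive)
T_num = numeric_tensor(W, s_prev, drive)

# Contract with a row vector g over the first index p: the result is a d x d matrix.
g = [rng.uniform(-1.0, 1.0) for _ in range(d)]
contracted = [[sum(g[p] * T[p][q][r] for p in range(d)) for r in range(d)] for q in range(d)]
s = explicit_state(W, s_prev, drive)
outer = [[g[q] * (s[q] * (1.0 - s[q]) * s_prev[r]) for r in range(d)] for q in range(d)]
```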