Published: 06.12.2023
Vision Retention Networks
Introduction
Retention is a mechanism recently proposed in Retentive Network: A Successor to Transformer for Large Language Models by Sun et al., whose core idea is to carry out a computation similar to attention while being much more computationally efficient. It has become a recurring pattern that researchers from other fields of machine learning take inspiration from the progress made in NLP and try to adapt NLP solutions to a different problem. It was thus only a matter of time before we heard of Retention in the field of Computer Vision.
ViR: Vision Retention Networks by Ali Hatamizadeh, Michael Ranzinger and Jan Kautz was the first to apply Retention in a CV model. I recently had a great time re-implementing the paper and digging into Retention, so I thought I would share what I have learned. You can find my re-implementation at brianpulfer/vision-Retention-networks or at brianpulfer/papersreimplementations.
Attention
Before we dig into ViR, we need to learn what Retention is. But before we learn what Retention is, a quick recap of Attention.
Attention is a mechanism that allows a model to learn relationships between elements of the input. The meaning of a word can be completely altered by the surrounding words. A red pixel in an image might come from a tomato, a Ferrari or a balloon. Only the combination with neighbouring pixels gives it meaning. It is thus important for models to be able to learn the interplay of elements in the input sequence. That is where Attention comes in, and this is how it works:
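Given an input sequence $X \in \mathbb{R}^{N \times d}$ of $N$ elements, attention computes Queries, Keys and Values for each element of the sequence as follows:

$$Q = XW_Q, \quad K = XW_K, \quad V = XW_V$$

where $W_Q, W_K, W_V \in \mathbb{R}^{d \times d}$ are learnable parameters. The output for each element of the sequence is a weighted sum of the values, where the weights are the (scaled) dot products between the query and the keys, normalized with a softmax:

$$\text{Attention}(X) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$

and $\text{softmax}$ is applied row-wise.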
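To make the shapes concrete, here is a minimal single-head sketch of these formulas. It uses NumPy rather than the PyTorch of the implementation further down, just to stay dependency-light; the helper names `softmax` and `attention` and the random weights are illustrative only. Note the `(N, N)` score matrix: it is the source of the quadratic cost discussed next.

```python
import numpy as np


def softmax(z, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def attention(x, w_q, w_k, w_v):
    """Single-head scaled dot-product attention on a sequence x of shape (N, d)."""
    q, k, v = x @ w_q, x @ w_k, x @ w_v   # each (N, d)
    d = q.shape[-1]
    scores = q @ k.T / np.sqrt(d)         # (N, N): one score per pair of elements
    weights = softmax(scores, axis=-1)    # row-wise softmax: each row sums to 1
    return weights @ v, weights


rng = np.random.default_rng(0)
N, d = 6, 4
x = rng.standard_normal((N, d))
w_q, w_k, w_v = (rng.standard_normal((d, d)) for _ in range(3))
out, weights = attention(x, w_q, w_k, w_v)
print(out.shape, weights.shape)  # (6, 4) (6, 6)
```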
Ever since the Attention Is All You Need paper, this mechanism has proven empirically to be very powerful for learning relationships between elements of a sequence. It has been used in virtually all contexts (NLP, CV, TTS, ...), and it has become a de facto standard for many tasks.
Then why get rid of it?
There is only one issue that has researchers a bit troubled: the complexity of attention is $O(N^2)$ in the sequence length $N$ (easily seen when computing $QK^T$, an $N \times N$ matrix), meaning that for an input sequence twice as long, computing attention takes four times as much time.
Considerable effort has gone into solving this issue, with variants such as Linear Attention and Efficient Attention that try to replicate the mechanism at a lower computational cost.
Retention
Retention works recurrently, just like recurrent neural networks. At each step, it reads the input to update an inner state matrix, uses the inner state to compute an output, and passes the inner state onward. Here is the RECURRENT formulation of Retention:

$$S_n = \alpha S_{n-1} + k_n^T v_n$$

$$o_n = \frac{q_n S_n}{\sqrt{d}}$$

where $S_n \in \mathbb{R}^{d \times d}$ is the inner state at step $n$ (with $S_0 = 0$), $k_n$ is the key and $v_n$ is the value of the current ($n$-th) element in the sequence (row vectors, so $k_n^T v_n \in \mathbb{R}^{d \times d}$). Needless to say, $q_n, k_n, v_n$ are linear projections of the $n$-th sequence element $x_n$. Finally, $\alpha \in (0, 1)$ is a constant that exponentially decays older key-value products.

In words: $S_n$ contains the state in the form of a (decayed) sum of all key-value products seen so far. The output $o_n$ is obtained by fetching the desired value (a mixture of values) using the current query $q_n$.
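The recurrence fits in a few lines. Here is a minimal NumPy sketch for a single head, assuming `q`, `k`, `v` are already projected and using the $\sqrt{d}$ scaling of the implementation further down; the function name `retention_recurrent` is hypothetical.

```python
import numpy as np


def retention_recurrent(q, k, v, alpha):
    """Recurrent Retention over q, k, v of shape (N, d); returns outputs of shape (N, d)."""
    N, d = q.shape
    S = np.zeros((d, d))  # empty inner state before the first element
    outputs = []
    for n in range(N):
        # S_n = alpha * S_{n-1} + k_n^T v_n (outer product of two row vectors)
        S = alpha * S + np.outer(k[n], v[n])
        # o_n = q_n S_n / sqrt(d): query the accumulated key-value products
        outputs.append(q[n] @ S / np.sqrt(d))
    return np.stack(outputs)


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 5, 4))
out = retention_recurrent(q, k, v, alpha=0.9)
print(out.shape)  # (5, 4)
```

At the first step the state holds a single product, so $o_1 = (q_1 \cdot k_1)\, v_1 / \sqrt{d}$: a scaled copy of the first value.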
This is literally all there is to Retention! What is special about it is that it can also be computed with a PARALLEL formulation, just like Attention. The formula to compute all outputs at once is the following:

$$\text{Retention}(X) = \left(\frac{QK^T}{\sqrt{d}} \odot D\right)V, \qquad D_{nm} = \begin{cases} \alpha^{n-m} & n \ge m \\ 0 & n < m \end{cases}$$

Looks familiar, right? In fact, we do everything exactly as for Attention, except that we do not apply a row-wise softmax and we always apply (element-wise) $D$, a lower-triangular matrix that simultaneously deals with causal masking (no contribution from future elements in the sequence) and applies the exponential decay given by $\alpha$.
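A matching NumPy sketch of the parallel formulation, again for one head with pre-projected `q`, `k`, `v` and $\sqrt{d}$ scaling; `decay_matrix` and `retention_parallel` are hypothetical names. The last lines check numerically that it agrees with the recurrent update $S_n = \alpha S_{n-1} + k_n^T v_n$, $o_n = q_n S_n / \sqrt{d}$.

```python
import numpy as np


def decay_matrix(N, alpha):
    """D[n, m] = alpha ** (n - m) if n >= m else 0: causal mask and exponential decay in one matrix."""
    lag = np.arange(N)[:, None] - np.arange(N)[None, :]
    return np.where(lag >= 0, alpha ** np.clip(lag, 0, None), 0.0)


def retention_parallel(q, k, v, alpha):
    """Parallel Retention for one head: (Q K^T / sqrt(d) * D) V, with no softmax."""
    N, d = q.shape
    return (q @ k.T / np.sqrt(d) * decay_matrix(N, alpha)) @ v


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 5, 4))
parallel = retention_parallel(q, k, v, alpha=0.9)

# Same outputs via the recurrent update, one element at a time
S, recurrent = np.zeros((4, 4)), []
for n in range(5):
    S = 0.9 * S + np.outer(k[n], v[n])
    recurrent.append(q[n] @ S / np.sqrt(4))
print(np.allclose(parallel, np.stack(recurrent)))  # True

D3 = decay_matrix(3, 0.5)  # [[1, 0, 0], [0.5, 1, 0], [0.25, 0.5, 1]]
```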
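The key takeaway here is that if we get rid of the softmax operator, we unlock the recurrent formulation, where we can simply carry over what we computed before to compute the next output. Without the softmax, each output is a linear function of the products $k_m^T v_m$, so these can be summed once into the state and reused by every later query; with the softmax, each weight $\exp(q_n k_m^T / \sqrt{d})$ depends jointly on the query and the key and cannot be pre-accumulated in this way.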
However, processing sequences recurrently sucks! That is exactly why we generally prefer Transformers over RNNs: the Transformer's time complexity might be quadratic in sequence length, but at least we can process everything in parallel. With a recurrent formulation, we need to compute the $n$-th output before we can compute the $(n+1)$-th, while our GPUs sit idle.
Then why care about a recurrent formulation?
The real ✨magic✨ happens when we use a hybrid of the parallel and recurrent formulations. It turns out that we can break the input sequence into multiple chunks, process each chunk with the parallel formulation, and then aggregate the results with a cross-chunk recurrent computation. This means that as soon as the sequence becomes prohibitively long for the parallel formulation (quadratic in $N$), we can split it into chunks of size $B$, process those in parallel (quadratic in the chunk size only!) and finally combine the cross-chunk information recurrently (linear in $N$). The real gain is thus obtained for very long sequences.
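A rough operation count shows where the saving comes from. Assuming a head dimension $d$, a sequence length $N$ and a chunk size $B$ that divides $N$ (constants dropped):

$$\underbrace{O(N^2 d)}_{\text{parallel}} \quad \text{vs.} \quad \underbrace{\tfrac{N}{B}\, O(B^2 d)}_{\text{inner-chunk}} + \underbrace{\tfrac{N}{B}\, O(B d^2)}_{\text{cross-chunk}} = O(NBd + Nd^2)$$

For example, with $N = 256$ tokens and $B = 16$, the parallel formulation builds one $256 \times 256$ score matrix (65,536 entries), while the chunkwise one builds 16 matrices of size $16 \times 16$ (4,096 entries in total).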
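Here is the CHUNKWISE RECURRENT formulation of Retention, where $Q_{[i]}, K_{[i]}, V_{[i]}$ are the queries, keys and values of the $i$-th chunk of size $B$:

$$R_i = K_{[i]}^T \left(V_{[i]} \odot \zeta\right) + \alpha^B R_{i-1}, \qquad \zeta_j = \alpha^{B-j-1}$$

$$\text{Retention}(X_{[i]}) = \underbrace{\left(\frac{Q_{[i]} K_{[i]}^T}{\sqrt{d}} \odot D\right) V_{[i]}}_{\text{Inner-Chunk}} + \underbrace{\left(\frac{Q_{[i]} R_{i-1}}{\sqrt{d}}\right) \odot \xi}_{\text{Cross-Chunk}}, \qquad \xi_j = \alpha^{j+1}$$

with $R_{-1} = 0$, $D$ the $B \times B$ decay matrix of the parallel formulation, and $\zeta$, $\xi$ applied row-wise (row $j = 0, \dots, B-1$ of the chunk is scaled by $\zeta_j$, respectively $\xi_j$).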
The math looks scary, but really we are just applying the parallel computation for all chunks and, once we have the Inner-Chunk parts, we can merge them using the recurrent formulation.
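Here is a minimal NumPy sketch of the chunkwise formulation for one head, assuming pre-projected `q`, `k`, `v`, $\sqrt{d}$ scaling, and a chunk size `B` that divides the sequence length (no padding). For readability it handles chunks one by one in a Python loop; the PyTorch implementation below computes all inner-chunk terms in one batched operation instead. The function name is hypothetical.

```python
import numpy as np


def retention_chunkwise(q, k, v, alpha, B):
    """Chunkwise Retention for one head; q, k, v have shape (N, d) and B must divide N."""
    N, d = q.shape
    assert N % B == 0, "this sketch assumes the chunk size divides N (no padding)"
    j = np.arange(B)
    lag = j[:, None] - j[None, :]
    D = np.where(lag >= 0, alpha ** np.clip(lag, 0, None), 0.0)  # B x B decay matrix
    zeta = alpha ** (B - j - 1)  # decay of each key-value product up to the end of its chunk
    xi = alpha ** (j + 1)        # decay of the earlier chunks' state at each row of a chunk
    R = np.zeros((d, d))         # state carried across chunks (R_{-1} = 0)
    outputs = []
    for i in range(N // B):
        qi, ki, vi = (t[i * B:(i + 1) * B] for t in (q, k, v))
        inner = (qi @ ki.T / np.sqrt(d) * D) @ vi          # parallel form inside the chunk
        cross = (qi @ R / np.sqrt(d)) * xi[:, None]       # contribution of all earlier chunks
        outputs.append(inner + cross)
        R = ki.T @ (vi * zeta[:, None]) + alpha ** B * R  # recurrent update across chunks
    return np.concatenate(outputs)


rng = np.random.default_rng(0)
q, k, v = rng.standard_normal((3, 12, 4))

# Reference: the full parallel formulation over all 12 tokens
n = np.arange(12)
lag = n[:, None] - n[None, :]
D_full = np.where(lag >= 0, 0.9 ** np.clip(lag, 0, None), 0.0)
parallel = (q @ k.T / np.sqrt(4) * D_full) @ v

chunked = retention_chunkwise(q, k, v, alpha=0.9, B=4)
print(np.allclose(parallel, chunked))  # True
```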
Comparison of Attention and Retention
Time complexity
Given the previous state, Retention can compute the next token in $O(1)$ time with respect to the sequence length (the state always has size $d \times d$), whereas Attention has no such state: it needs all past keys and values to predict the next token, which costs $O(N)$ per token.
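The difference in per-token cost can be illustrated with two hypothetical decoding-step functions (the names and the dict-based cache are illustrative, not from any library): the attention step stacks and scores every cached key and value, while the Retention step only touches a $d \times d$ state.

```python
import numpy as np


def attention_step(q, k, v, cache):
    """One causal attention decoding step with a KV cache: work grows with the cache length."""
    cache["K"].append(k)
    cache["V"].append(v)
    K, V = np.stack(cache["K"]), np.stack(cache["V"])  # (n, d) each: all past keys and values
    scores = K @ q / np.sqrt(q.shape[0])
    w = np.exp(scores - scores.max())
    return (w / w.sum()) @ V


def retention_step(q, k, v, S, alpha):
    """One Retention decoding step: only the fixed-size (d, d) state is needed."""
    S = alpha * S + np.outer(k, v)
    return q @ S / np.sqrt(q.shape[0]), S


rng = np.random.default_rng(0)
d, steps = 4, 10
cache, S = {"K": [], "V": []}, np.zeros((d, d))
for _ in range(steps):
    q, k, v = rng.standard_normal((3, d))
    attention_step(q, k, v, cache)
    _, S = retention_step(q, k, v, S, alpha=0.9)
print(len(cache["K"]), S.shape)  # 10 (4, 4)
```

After ten steps the attention cache holds ten keys (and would keep growing), while the Retention state is still a single $4 \times 4$ matrix.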
Recurrent formulation
Attention does not need to be causal, whereas Retention is causal by construction (this is what makes its recurrent formulation possible). This is perfectly fine for causal decoder transformers, where we do not want current tokens to attend to future tokens anyway. However, in computer vision we mostly use the encoder type of transformer, so it is not entirely clear what impact forcing a causal relationship might have on a task where there seemingly is none.
Personal observation: Because Retention compresses all past key-value products into a single fixed-size state, I believe that it is probably not as powerful a mechanism as Attention. Perhaps this loss of expressivity is not a big deal for text and/or images, especially compared to the gains in time complexity, but it is still something to keep in mind. It might very well be that, like other alternatives to Attention before it, Retention fails to become a de facto standard due to worse performance. What is certain is that Retention enables faster inference and, for very long sequences, even faster training, while being quite similar to Attention.
Vision Retention Networks
Vision Retention Networks are a minor yet important variation of the Vision Transformer. I have previously written about how Vision Transformers (ViT) work, but in short, a ViT breaks an image into many distinct non-overlapping patches (for example, a 16x16 grid of 14x14-pixel patches for a 224x224 image), which are then flattened and treated as a sequence. An encoder transformer then processes the sequence without any causal masking, and its output is used for downstream tasks.
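As an illustration of the patch-splitting step, here is a NumPy sketch (the function name `patchify` is hypothetical). In the implementation below, splitting and the linear projection of each patch happen in one go through a `Conv2d` whose kernel size and stride both equal the patch size; the sketch only shows the splitting and flattening, in the same row-by-row order.

```python
import numpy as np


def patchify(image, patch_size):
    """Split an (H, W, C) image into non-overlapping patches, returned as a (num_patches, P*P*C) sequence.

    Patches are ordered row by row: top to bottom, left to right.
    """
    H, W, C = image.shape
    P = patch_size
    assert H % P == 0 and W % P == 0, "image size must be divisible by the patch size"
    patches = image.reshape(H // P, P, W // P, P, C)  # (rows, P, cols, P, C)
    patches = patches.transpose(0, 2, 1, 3, 4)         # (rows, cols, P, P, C)
    return patches.reshape(-1, P * P * C)


image = np.random.default_rng(0).random((224, 224, 3))
sequence = patchify(image, patch_size=14)
print(sequence.shape)  # (256, 588)
```

A 224x224 RGB image thus becomes a sequence of 256 patches, each flattened to 14 · 14 · 3 = 588 numbers.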
The ViT is thus just a stack of encoder blocks, where each block sequentially applies an Attention block and an MLP block. In ViR, we get rid of the Attention block and swap in a Retention block instead.
Personal observation: It must be noted that because Retention works in a recurrent manner by definition, this is a big shift from ViT! While a ViT sees the whole image in one go, a ViR virtually reads the image from left to right, top to bottom. This is potentially a drawback of ViR compared to ViT, since it might not make sense to introduce causality in images.
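To see what this reading order implies, here is a small sketch (the helper `visible_patches` is hypothetical) that marks which patches of a grid a causal model has already read when it processes the patch at a given position, assuming patches are flattened row by row.

```python
import numpy as np


def visible_patches(grid_size, row, col):
    """Boolean (grid_size, grid_size) map of the patches read so far at position (row, col).

    Assumes patches are flattened row by row (top to bottom, left to right) and that
    the model is causal, so a position only sees itself and earlier positions.
    """
    index = np.arange(grid_size * grid_size).reshape(grid_size, grid_size)
    return index <= row * grid_size + col


mask = visible_patches(4, row=1, col=2)
print(mask.astype(int))
# [[1 1 1 1]
#  [1 1 1 0]
#  [0 0 0 0]
#  [0 0 0 0]]
```

The patch at row 1, column 2 has read everything above it and to its left in the same row, but neither its right neighbour at (1, 3) nor the patch directly below it.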
Because Retention reads the image in sequence, if we want our model to be a classifier, we need to use an output that comes after all tokens have been seen. To do so, we append a learnable [CLS] token at the end of the sequence and use its output for classification. Notice that in a regular ViT, the [CLS] token is typically placed at the beginning of the sequence (although for a regular ViT this does not really make a difference, since there is no causal masking).
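A sketch of why the [CLS] token goes at the end, under a simplifying assumption: the queries, keys and values are the tokens themselves (identity projections) instead of learned projections, and all names are hypothetical. The ViR class in the implementation below returns the patch tokens and does not include this step.

```python
import numpy as np


def causal_retention(x, alpha):
    """Parallel Retention with identity projections (q = k = v = x); x has shape (N, d)."""
    N, d = x.shape
    lag = np.arange(N)[:, None] - np.arange(N)[None, :]
    D = np.where(lag >= 0, alpha ** np.clip(lag, 0, None), 0.0)
    return (x @ x.T / np.sqrt(d) * D) @ x


rng = np.random.default_rng(0)
patches = rng.standard_normal((16, 8))                 # 16 patch tokens of dimension 8
cls_token = rng.standard_normal((1, 8))                # stands in for the learnable [CLS] embedding
tokens = np.concatenate([patches, cls_token], axis=0)  # [CLS] appended at the END
out = causal_retention(tokens, alpha=0.9)
cls_output = out[-1]                                   # what a classification head would read

# Perturbing the first patch changes the [CLS] output: it has "seen" the whole image.
perturbed = tokens.copy()
perturbed[0] += 1.0
print(np.allclose(cls_output, causal_retention(perturbed, alpha=0.9)[-1]))  # False

# The patch outputs never see the [CLS] token, which comes after them.
print(np.allclose(out[:-1], causal_retention(patches, alpha=0.9)))  # True
```

Placed at the front instead, the [CLS] output would only depend on the [CLS] token itself.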
Implementation
Here is my full re-implementation of a ViR:
```python
import torch
import torch.nn as nn


class ViRModes:
    PARALLEL = "parallel"
    RECURRENT = "recurrent"
    CHUNKWISE = "chunkwise"


class Retention(nn.Module):
    def __init__(
        self,
        embed_dim,
        max_len,
        alpha,
        mode=ViRModes.PARALLEL,
        chunk_size=20,
    ):
        super(Retention, self).__init__()
        self.dim = embed_dim
        self.max_len = max_len
        self.chunk_size = chunk_size
        self.alpha = alpha
        self.mode = mode

        # Useful buffers
        self.register_buffer("dim_sqrt", torch.tensor(embed_dim**0.5))

        # decay_mask[n, m] = alpha ** (n - m) on and below the diagonal, 0 above it
        indices = torch.arange(max_len).reshape(1, -1)
        self.register_buffer(
            "decay_mask",
            (alpha ** (indices.t() - indices)).tril(),
        )

        self.register_buffer("causal_mask", torch.ones(max_len, max_len).tril())
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)

    def forward_parallel(self, x):
        # Getting queries, keys, values
        bs, sl, d = x.shape
        qkv = self.qkv(x)
        q, k, v = torch.chunk(qkv, 3, dim=-1)

        # Causal and decay masking
        M = (self.causal_mask[:sl, :sl] * self.decay_mask[:sl, :sl]).repeat(bs, 1, 1)

        # Retention
        out = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v

        return out

    def forward_recurrent(self, x, state=None):
        batch_size, length, dim = x.shape

        all_outputs = []
        # Start from an empty state unless one is passed in
        if state is None:
            state = torch.zeros(batch_size, dim, dim, device=x.device)
        for i in range(length):
            xi = x[:, i]
            q, k, v = self.qkv(xi).chunk(3, dim=-1)

            # S_n = alpha * S_{n-1} + k_n^T v_n, then o_n = q_n S_n / sqrt(d)
            state = self.alpha * state + k.unsqueeze(-1) @ v.unsqueeze(1)
            out = q.unsqueeze(1) @ state / self.dim_sqrt
            # squeeze(1) rather than squeeze(): keeps the batch dimension when batch_size == 1
            all_outputs.append(out.squeeze(1))

        x = torch.stack(all_outputs, dim=1)
        return x

    def forward_chunkwise(self, x, chunk_size=None):
        # Getting queries, keys, values
        if chunk_size is None:
            chunk_size = self.chunk_size

        bs, sl, d = x.shape

        # Adding dummy tokens to make the sequence length divisible by chunk_size
        if sl % chunk_size != 0:
            x = torch.cat(
                [x, torch.zeros(bs, chunk_size - sl % chunk_size, d, device=x.device)],
                dim=1,
            )
        n_chunks = x.shape[1] // chunk_size

        # Running all chunks in parallel
        x = x.reshape(bs, n_chunks, chunk_size, d)
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        M = (
            self.causal_mask[:chunk_size, :chunk_size]
            * self.decay_mask[:chunk_size, :chunk_size]
        ).repeat(bs, n_chunks, 1, 1)

        inner_chunk = (q @ k.transpose(-1, -2) / self.dim_sqrt * M) @ v

        # Updating outputs with chunk-wise recurrent
        # retention_mask[j] = alpha ** (chunk_size - j - 1): decay up to the end of the chunk
        retention_mask = (
            torch.tensor(
                [self.alpha ** (chunk_size - i - 1) for i in range(chunk_size)],
                device=x.device,
            )
            .repeat(bs, d, 1)
            .transpose(-1, -2)
        )

        # cross_mask[j] = alpha ** (j + 1): decay of the previous chunks' state at row j
        cross_mask = (
            torch.tensor(
                [self.alpha ** (i + 1) for i in range(chunk_size)], device=x.device
            )
            .repeat(bs, n_chunks, d, 1)
            .transpose(-1, -2)
        )

        # states[:, i] is the state accumulated over chunks 0..i-1 (zero for the first chunk)
        states = torch.zeros(bs, n_chunks, d, d, device=x.device)
        for i in range(1, n_chunks):
            chunk_state = k[:, i - 1].transpose(-1, -2) @ (v[:, i - 1] * retention_mask)
            states[:, i] = chunk_state + states[:, i - 1] * self.alpha**chunk_size

        cross_chunk = (q @ states) / self.dim_sqrt * cross_mask

        # Combining inner and cross chunk
        out = inner_chunk + cross_chunk

        # Removing dummy tokens
        out = out.flatten(1, 2)[:, :sl]
        return out

    def forward(self, x, state=None, mode=None, chunk_size=None):
        # mode defaults to None so that the mode given at construction time is used
        if mode is None:
            mode = self.mode

        if mode == ViRModes.PARALLEL:
            return self.forward_parallel(x)
        elif mode == ViRModes.RECURRENT:
            return self.forward_recurrent(x, state)
        elif mode == ViRModes.CHUNKWISE:
            return self.forward_chunkwise(x, chunk_size)
        else:
            raise ValueError(f"Unknown mode {mode}")


class MultiHeadRetention(nn.Module):
    def __init__(
        self,
        heads,
        embed_dim,
        max_len,
        alphas=None,
        mode=ViRModes.PARALLEL,
        chunk_size=20,
    ):
        super(MultiHeadRetention, self).__init__()
        self.n_heads = heads
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.head_dim = embed_dim // heads
        self.mode = mode
        self.chunk_size = chunk_size

        # Default decays: one alpha per head, closer to 1 (longer memory) for later heads
        if alphas is None:
            alphas = [1 - 2 ** (-5 - i) for i in range(heads)]
        self.alphas = alphas

        assert len(alphas) == heads, "Number of alphas must match number of heads"

        assert (
            embed_dim % heads == 0
        ), "Embedding dimension must be divisible by the number of heads"

        # Each head processes its own slice of the embedding dimension
        self.heads = nn.ModuleList(
            [
                Retention(embed_dim // heads, max_len, alpha, mode, chunk_size)
                for alpha in alphas
            ]
        )
        self.ln = nn.LayerNorm(embed_dim)
        self.gelu = nn.GELU()
        self.linear = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mode=None, chunk_size=None):
        if mode is None:
            mode = self.mode

        if chunk_size is None:
            chunk_size = self.chunk_size

        out = torch.cat(
            [
                head(
                    x[:, :, i * self.head_dim : (i + 1) * self.head_dim],
                    mode=mode,
                    chunk_size=chunk_size,
                )
                for i, head in enumerate(self.heads)
            ],
            dim=-1,
        )
        return self.linear(self.gelu(self.ln(out)))


class MLP(nn.Module):
    def __init__(self, embed_dim, hidden_dim=None):
        super(MLP, self).__init__()

        if hidden_dim is None:
            hidden_dim = 4 * embed_dim

        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, embed_dim)
        self.gelu = nn.GELU()

    def forward(self, x):
        return self.linear2(self.gelu(self.linear1(x)))


class ViRBlock(nn.Module):
    def __init__(
        self,
        heads,
        embed_dim,
        max_len,
        alphas=None,
        mode=ViRModes.PARALLEL,
        chunk_size=20,
        dropout=0.1,
    ):
        super(ViRBlock, self).__init__()
        self.mode = mode
        self.chunk_size = chunk_size

        self.ln1 = nn.LayerNorm(embed_dim)
        self.retention = MultiHeadRetention(
            heads, embed_dim, max_len, alphas, mode, chunk_size
        )
        self.ln2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mode=None, chunk_size=None):
        if mode is None:
            mode = self.mode

        if chunk_size is None:
            chunk_size = self.chunk_size

        # Pre-norm residual blocks: Retention, then MLP
        x = (
            self.dropout1(self.retention(self.ln1(x), mode=mode, chunk_size=chunk_size))
            + x
        )
        x = self.dropout2(self.mlp(self.ln2(x))) + x
        return x


class ViR(nn.Module):
    def __init__(
        self,
        patch_size=14,
        depth=12,
        heads=12,
        embed_dim=768,
        max_len=256,
        alphas=None,
        mode=ViRModes.CHUNKWISE,
        chunk_size=256,
        dropout=0.1,
    ):
        super(ViR, self).__init__()

        # Local parameters
        self.out_dim = 10
        self.patch_size = patch_size
        self.depth = depth
        self.heads = heads
        self.embed_dim = embed_dim
        self.max_len = max_len
        self.alphas = alphas
        self.mode = mode
        self.chunk_size = chunk_size

        # Embeddings
        self.patch_embed = nn.Conv2d(
            3, embed_dim, (patch_size, patch_size), stride=(patch_size, patch_size)
        )
        self.pos_embed = nn.Parameter(torch.randn(1, max_len, embed_dim))

        # ViR blocks
        self.blocks = nn.ModuleList(
            [
                ViRBlock(heads, embed_dim, max_len, alphas, mode, chunk_size, dropout)
                for _ in range(depth)
            ]
        )

        # Head
        self.ln = nn.LayerNorm(embed_dim)

    def forward(self, x, mode=None, chunk_size=None, reshape=False):
        if mode is None:
            mode = self.mode

        if chunk_size is None:
            chunk_size = self.chunk_size

        # Patch embedding (patches flattened row by row), positional embedding
        x = self.patch_embed(x).permute(0, 2, 3, 1).flatten(1, 2)
        bs, sl = x.shape[:2]
        x = x + self.pos_embed.repeat(bs, 1, 1)[:, :sl]

        # Blocks
        for block in self.blocks:
            x = block(x, mode=mode, chunk_size=chunk_size)

        # Layer Norm
        x = self.ln(x)

        # Reshape
        if reshape:
            ps = int(x.shape[1] ** 0.5)
            x = x.reshape(bs, ps, ps, self.embed_dim).permute(0, 3, 1, 2)

        return x


if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    x = torch.randn(16, 3, 224, 224).to(device)
    model = ViR(depth=12, heads=3, embed_dim=192).eval().to(device)

    with torch.no_grad():
        y1 = model(x, mode=ViRModes.PARALLEL)
        y2 = model(x, mode=ViRModes.RECURRENT)
        y3 = model(x, mode=ViRModes.CHUNKWISE, chunk_size=20)

    assert torch.allclose(
        y1, y2, atol=1e-5
    ), "Parallel and recurrent modes should give the same output"

    assert torch.allclose(
        y1, y3, atol=1e-5
    ), "Parallel and chunkwise modes should give the same output"
```
It feels like I should comment on these 300+ lines, but really there is nothing that is not already covered by the formulas. The only thing worth mentioning is that the chunk size might not divide the sequence length $N$, so one can add some dummy tokens at the end of the sequence such that its length becomes divisible by the chunk size (a sort of padding).
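The padding step can be sketched as follows (NumPy, hypothetical helper name `pad_to_multiple`, batch-first `(batch, N, d)` layout as in the implementation above). Because Retention is causal, the dummy tokens at the end cannot influence the real tokens before them, so slicing the output back to the first $N$ positions gives the exact result.

```python
import numpy as np


def pad_to_multiple(x, chunk_size):
    """Append zero 'dummy' tokens so that the length of x (batch, N, d) is divisible by chunk_size."""
    bs, n, d = x.shape
    remainder = n % chunk_size
    if remainder == 0:
        return x
    padding = np.zeros((bs, chunk_size - remainder, d), dtype=x.dtype)
    return np.concatenate([x, padding], axis=1)


x = np.ones((2, 257, 8))  # e.g. a batch of 2 sequences of 257 tokens
padded = pad_to_multiple(x, chunk_size=20)
print(padded.shape)  # (2, 260, 8)
```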
I also found it key for performance to actually perform the computations for all chunks in parallel (by reshaping the sequence into a batch of chunks); re-using the forward_parallel function sequentially on each chunk is not enough.
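To illustrate the batching, here is a NumPy sketch of only the inner-chunk term, with made-up sizes. Once the sequence is reshaped to `(batch, n_chunks, chunk_size, d)`, as `forward_chunkwise` does, a single batched matrix product replaces the per-chunk loop and gives the same result. The actual speed-up depends on the hardware; on a GPU, one large batched kernel typically keeps the device far busier than many small sequential ones.

```python
import numpy as np

rng = np.random.default_rng(0)
bs, n_chunks, B, d, alpha = 2, 5, 4, 8, 0.9
q, k, v = rng.standard_normal((3, bs, n_chunks * B, d))

# B x B decay matrix shared by all chunks
j = np.arange(B)
lag = j[:, None] - j[None, :]
D = np.where(lag >= 0, alpha ** np.clip(lag, 0, None), 0.0)

# Sequential: one small retention per chunk, in a Python loop
loop = np.stack(
    [
        (q[:, c * B:(c + 1) * B] @ k[:, c * B:(c + 1) * B].swapaxes(-1, -2) / np.sqrt(d) * D)
        @ v[:, c * B:(c + 1) * B]
        for c in range(n_chunks)
    ],
    axis=1,
)  # (bs, n_chunks, B, d)

# Batched: reshape once, then one batched matmul covers every chunk
qc, kc, vc = (t.reshape(bs, n_chunks, B, d) for t in (q, k, v))
batched = (qc @ kc.swapaxes(-1, -2) / np.sqrt(d) * D) @ vc

print(np.allclose(loop, batched))  # True
```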
Also notice that we use a different alpha for each head: heads with a higher alpha look further back into the past, while heads with a lower alpha mostly focus on the most recent tokens.
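The default decays in `MultiHeadRetention` are $\alpha_h = 1 - 2^{-5-h}$. One rough way to read them: the total weight a head gives to the past, $\sum_{\ell \ge 0} \alpha^\ell = 1/(1-\alpha)$, doubles from one head to the next. A quick check for the 3 heads of the `__main__` example:

```python
heads = 3
alphas = [1 - 2 ** (-5 - i) for i in range(heads)]  # same default as MultiHeadRetention
print(alphas)  # [0.96875, 0.984375, 0.9921875]

# A key-value product seen `lag` steps ago is weighted by alpha ** lag.
horizons = [1 / (1 - a) for a in alphas]       # sum of alpha ** lag over all lags: 32, 64, 128
weights_after_32 = [a ** 32 for a in alphas]   # higher alpha -> older tokens keep more weight
```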
Thank you for reading this far! If you found this helpful or interesting, or have suggestions for improvement, please do not hesitate to contact me at me@brianpulfer.ch