-
Notifications
You must be signed in to change notification settings - Fork 1
/
sashimi.py
359 lines (311 loc) · 12.5 KB
/
sashimi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
import torch
from torch import nn
from S4 import *
from torchaudio.functional import mu_law_encoding
from tqdm.auto import tqdm
def S4Block(signal_dim: int, state_dim: int, sequence_length: int, expansion_factor: int = 2):
    """
    Build the full S4 block from the SaShiMi paper.

    Arguments:
    - signal_dim: Number of dimensions in the signal.
    - state_dim: Number of dimensions in the inner state.
    - sequence_length: Sequence length this model operates on. It can be changed
      later, but a model trained on one length performs poorly on another.
    - expansion_factor: Multiplier applied to the signal dimension between the
      two linear layers of the second pass.

    Architecture (Appendix A.2 of "It's Raw! Audio Generation with State-Space
    Models"): two residual passes. The first applies LayerNorm -> S4 -> GELU ->
    Linear; the second applies LayerNorm -> Linear (expand) -> GELU -> Linear
    (contract). All linear layers are position-wise: they act on the signal
    dimension, never the time dimension.
    """
    first_pass = Residual(
        nn.LayerNorm(signal_dim),
        S4Base(signal_dim, state_dim, sequence_length),
        nn.GELU(),
        nn.Linear(signal_dim, signal_dim),
    )
    widened = signal_dim * expansion_factor
    second_pass = Residual(
        nn.LayerNorm(signal_dim),
        nn.Linear(signal_dim, widened),
        nn.GELU(),
        nn.Linear(widened, signal_dim),
    )
    return Sequential(first_pass, second_pass)
def S4BlockGLU(signal_dim: int, state_dim: int, sequence_length: int, expansion_factor: int = 2):
    """
    S4Block variant in which every GELU activation is replaced by a GLU layer.

    GLU halves its input along the last dimension, so each linear layer that
    feeds a GLU doubles its output width to compensate.

    See Appendix C.2.1 in "It's Raw! Audio Generation with State-Space Models":
    > On SC09, we found that swapping in a gated linear unit (GLU) in the S4 block improved
    > NLL as well as sample quality.
    """
    widened = signal_dim * expansion_factor
    first_pass = Residual(
        nn.LayerNorm(signal_dim),
        S4Base(signal_dim, state_dim, sequence_length),
        nn.Linear(signal_dim, signal_dim * 2),  # doubled: GLU halves it back
        nn.GLU(),
    )
    second_pass = Residual(
        nn.LayerNorm(signal_dim),
        nn.Linear(signal_dim, widened * 2),  # doubled for the GLU
        nn.GLU(),
        nn.Linear(widened, signal_dim),
    )
    return Sequential(first_pass, second_pass)
class DownPool(nn.Module):
    """
    Trade sequence length for channel width.

    With pooling factor p and expansion factor q:

               reshape               linear
        (T, H) -------> (T/p, H*p) -------> (T/p, H*q)

    Dimensions are restored when combined with an UpPool layer built with the
    same settings.
    """

    def __init__(self, signal_dim: int, pooling_factor: int = 4, expansion_factor: int = 2):
        """
        - signal_dim: Input signal dimensions.
        - pooling_factor: Time is divided and hidden dimension is multiplied by this.
        - expansion_factor: Ratio between the hidden dimension of output and input.
        """
        super().__init__()
        self.pooling_factor = pooling_factor
        # Position-wise projection applied after folding pooled steps into channels.
        self.linear = nn.Linear(
            signal_dim * pooling_factor,
            signal_dim * expansion_factor,
        )

    def forward(self, x):
        length = x.size(dim=-2)
        width = x.size(dim=-1)
        # Fold blocks of `pooling_factor` consecutive time steps into the
        # channel dimension; any leading batch dimensions collapse into the -1.
        folded = x.reshape(-1, length // self.pooling_factor, width * self.pooling_factor)
        return self.linear(folded)

    def get_recurrent_runner(self):
        # Pooling changes the time axis, so a one-step-at-a-time runner is
        # ill-defined for this layer alone; recurrence is orchestrated by
        # CausalPooledResidual instead.
        raise TypeError("DownPool cannot be used in recurrent mode by itself. " +
                        "See CausalPooledResidual.")
class UpPool(nn.Module):
    """
    Trade channel width for sequence length — the inverse of DownPool.

    With pooling factor p and expansion factor q:

                  linear                reshape
        (T/p, H*q) ------> (T/p, H*p) --------> (T, H)

    Dimensions are restored when combined with a DownPool layer built with the
    same settings.
    """

    def __init__(self, signal_dim: int, pooling_factor: int = 4, expansion_factor: int = 2):
        """
        - signal_dim: Output signal dimensions.
        - pooling_factor: Time is multiplied and hidden dimension is divided by this.
        - expansion_factor: Ratio between the hidden dimension of input and output.
        """
        super().__init__()
        self.pooling_factor = pooling_factor
        self.linear = nn.Linear(
            signal_dim * expansion_factor,
            signal_dim * pooling_factor,
        )

    def no_shift(self, x):
        """
        Up-pool without the causal shift.

        Identical to calling the layer directly on UpPool itself; subclasses
        (CausalUpPool) shift before delegating here.
        """
        projected = self.linear(x)
        length = projected.size(dim=-2)
        width = projected.size(dim=-1)
        # Unfold the channel dimension back into `pooling_factor` time steps.
        return projected.reshape(-1, length * self.pooling_factor, width // self.pooling_factor)

    def forward(self, x):
        return self.no_shift(x)

    def get_recurrent_runner(self):
        raise TypeError("UpPool cannot be used in recurrent mode by itself. " +
                        "See CausalPooledResidual.")
class CausalUpPool(UpPool):
    """
    Up-pooling with a one-block right shift, zero-padded at the front, so that
    causality is preserved.

    A plain UpPool breaks causality: within one pooled block, output samples
    could otherwise depend on input from their own (current/future) block.
    """

    def forward(self, x):
        # Drop the last pooled step and pad one zero step at the front,
        # shifting the whole sequence one block to the right.
        shifted = torch.nn.functional.pad(x[:, :-1, :], pad=(0, 0, 1, 0))
        return self.no_shift(shifted)
class CausalPooledResidual(nn.Module):
    """
    A sequential block wrapped between DownPool and UpPool layers with a residual connection
    from its beginning to the end.
    Equivalent to this for convolution:
    Residual(
        DownPool(hidden_dim),
        *sequential_blocks,
        UpPool(hidden_dim),
    )
    But is capable of running recurrently unlike the block above.
    """
    def __init__(self,
                 layers,
                 signal_dim: int,
                 pooling_factor: int = 4,
                 expansion_factor: int = 2,
                 ):
        """
        - layers: List of layers in sequential block.
        - signal_dim, pooling_factor, expansion_factor: Parameters for pooling layers.
        """
        super().__init__()
        self.sequential = Sequential(*layers)
        self.down_pool = DownPool(signal_dim, pooling_factor, expansion_factor)
        # CausalUpPool shifts its output right by one pooled block, so the
        # residual path never exposes future samples.
        self.up_pool = CausalUpPool(signal_dim, pooling_factor, expansion_factor)
        self.signal_dim = signal_dim
        self.pooling_factor = pooling_factor
    def forward(self, x):
        # Convolutional (whole-sequence) mode: down-pool -> inner layers ->
        # causal up-pool, plus the identity skip connection.
        return self.up_pool(self.sequential(self.down_pool(x))) + x
    def get_recurrent_runner(self):
        """
        Discretize the model with given L and return a function that maps state and input to
        the new state and input.
        """
        # Inputs received so far within the current pooled block (length < pooling_factor).
        input_cache = []
        # First block will be up-pooled from zeros due to shifting.
        device = next(iter(self.parameters())).device
        first = torch.zeros(1, self.up_pool.linear.in_features, device=device)
        first = self.up_pool.no_shift(first)
        # first shape: (1, self.pooling_factor, hidden)
        # NOTE(review): these initial entries have shape (hidden,) after the
        # squeeze, while refills below keep a time dim of 1 from torch.split —
        # the `output_cache.pop() + u` relies on broadcasting for both; confirm.
        output_cache = [i for i in first.squeeze(0)]
        # Instead of removing items from the beginning, we reverse the list and pop from
        # the end. This is slightly faster.
        output_cache.reverse()
        # Step function of the inner sequential block, advanced once per pooled block.
        sequential = self.sequential.get_recurrent_runner()
        def f(u):
            nonlocal sequential, input_cache, output_cache
            # Residual output for this step: pre-computed up-pooled value + skip.
            output = output_cache.pop() + u
            input_cache.append(u)
            if len(input_cache) == self.pooling_factor:
                # A full pooled block has accumulated: run it through the inner
                # layers once and pre-compute the next pooling_factor outputs.
                x = torch.cat(input_cache, dim=-2)
                input_cache.clear()
                output_cache.clear()
                y = self.down_pool(x)
                y = self.up_pool.no_shift(sequential(y))
                output_cache = [i for i in torch.split(y, 1, dim=-2)]
                output_cache.reverse()
            return output
        return f
class Embedding(torch.nn.Embedding):
    """
    Thin subclass of torch.nn.Embedding with no added behavior.

    NOTE(review): presumably exists so other code (e.g. a recurrent runner)
    can recognize embedding layers by type — confirm against callers.
    """
    pass
def SaShiMi(input_dim: int,
            hidden_dim: int,
            output_dim: int,
            state_dim: int,
            sequence_length: int,
            block_count: int,
            block_class=S4Block,
            encoder=None,
            decoder=None,
            ):
    """
    Build the SaShiMi architecture from Figure 1 of "It's Raw! Audio Generation
    with State-Space Models".

    - input_dim: Input signal dimension.
    - hidden_dim: Signal dimension in the S4 blocks.
    - output_dim: Output signal dimension.
    - state_dim, sequence_length: Parameters for S4 blocks.
    - block_count: Number of S4 blocks in each series of S4 Blocks.
    - block_class: S4 block class. Can be S4Block or S4BlockGLU.
    - encoder: Optional encoder layer. A linear layer is constructed if not provided.
    - decoder: Optional decoder layer. A linear layer is constructed if not provided.

    Three tiers are nested via CausalPooledResidual: the innermost runs at
    1/16 the sequence length with 4x width, the middle at 1/4 with 2x width,
    and the outermost at full resolution.
    """
    if encoder is None:
        encoder = nn.Linear(input_dim, hidden_dim)
    if decoder is None:
        decoder = nn.Linear(hidden_dim, output_dim)

    def stack(width, length):
        # One series of S4 blocks at the given width/length.
        return [block_class(width, state_dim, length) for _ in range(block_count)]

    innermost = CausalPooledResidual(
        signal_dim=hidden_dim * 2,
        layers=[Residual(*stack(4 * hidden_dim, sequence_length // 16))],
    )
    middle = CausalPooledResidual(
        signal_dim=hidden_dim,
        layers=[innermost, *stack(2 * hidden_dim, sequence_length // 4)],
    )
    return Sequential(
        encoder,
        middle,
        *stack(hidden_dim, sequence_length),
        decoder,
    )
def generate_audio_sample(
        model,
        sample_count: int,
        batch_size: int = 1,
        priming_signal=None,
        starting_input=None,
        maxp=False,
        use_tqdm=True,
):
    """
    Generate an audio sample autoregressively from the model using 8-bit mu-law encoding.
    - model: Autoregressive audio model exposing get_recurrent_runner().
    - sample_count: Number of total samples in the output.
    - batch_size: Number of generated audio files.
    - priming_signal: Model will complete this signal if given.
      The model will generate sample_count - priming_signal.size(0) samples.
      The priming signal will be included in the output if provided.
    - starting_input: Normally, one sample of silence will be given to the model to start
      the generation. If this argument is given, it will be used instead.
    - maxp: If true, the option with the highest probability will be selected instead of
      random sampling.
    - use_tqdm: Use tqdm library to display a progress bar.
    Returns:
    - A tensor of shape (batch_size, sample_count), containing samples in mu-law encoding.
    """
    f = model.get_recurrent_runner()
    device = next(model.parameters()).device
    # Pad the input with a silence sample (mu-law encoding of 0) to get started.
    if starting_input is None:
        starting_input = mu_law_encoding(torch.zeros(batch_size, 1, device=device), 256)
    u = f(starting_input)
    # Feed the priming signal through the model so its state reflects it;
    # each sample is broadcast to every item of the batch.
    if priming_signal is not None:
        for s in priming_signal:
            u = f(s.reshape(1, -1).expand(batch_size, -1))
        primed_size = priming_signal.size(0)
    else:
        primed_size = 0
    # Generate the new part autoregressively.
    Y = []
    iterator = range(sample_count - primed_size)
    # Don't use tqdm while testing
    if use_tqdm:
        iterator = tqdm(iterator, leave=False)
    for _ in iterator:
        if maxp:
            # Greedy decoding: take the most likely class.
            p = torch.argmax(u, dim=-1)
        else:
            # Sample from the categorical distribution given by the logits.
            dist = torch.distributions.categorical.Categorical(
                probs=torch.nn.functional.softmax(u, dim=-1),
            )
            p = dist.sample()
        Y.append(p)
        u = f(p)
    if Y:
        generated = torch.cat(Y, dim=1)
    else:
        # Nothing new requested (priming covers sample_count); torch.cat on an
        # empty list would raise.
        generated = starting_input.new_zeros(batch_size, 0)
    if priming_signal is None:
        return generated
    # BUG FIX: the priming signal must be repeated along the batch dimension.
    # The previous reshape(1, -1) alone produced a dim-0 mismatch in torch.cat
    # whenever batch_size > 1.
    primed = priming_signal.flatten().reshape(1, -1).expand(batch_size, -1)
    return torch.cat([primed, generated], dim=1)