@@ -25,8 +25,8 @@
 #
 # *****************************************************************************
 import torch
-from torch.autograd import Variable
 import torch.nn.functional as F
+from torch.autograd import Variable
 
 
 @torch.jit.script
@@ -48,11 +48,12 @@ class Invertible1x1Conv(torch.nn.Module):
 
     def __init__(self, c):
         super(Invertible1x1Conv, self).__init__()
-        self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0,
-                                    bias=False)
+        self.conv = torch.nn.Conv1d(
+            c, c, kernel_size=1, stride=1, padding=0, bias=False
+        )
 
         # Sample a random orthonormal matrix to initialize weights
-        W = torch.qr(torch.FloatTensor(c, c).normal_())[0]
+        W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
 
         # Ensure determinant is 1.0 not -1.0
         if torch.det(W) < 0:
@@ -67,18 +68,25 @@ def forward(self, z, reverse=False):
         W = self.conv.weight.squeeze()
 
         if reverse:
-            if not hasattr(self, 'W_inverse'):
+            if not hasattr(self, "W_inverse"):
                 # Reverse computation
                 W_inverse = W.float().inverse()
                 W_inverse = Variable(W_inverse[..., None])
-                if z.type() == 'torch.cuda.HalfTensor' or z.type() == 'torch.HalfTensor':
+                if (
+                    z.type() == "torch.cuda.HalfTensor"
+                    or z.type() == "torch.HalfTensor"
+                ):
                     W_inverse = W_inverse.half()
                 self.W_inverse = W_inverse
             z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0)
             return z
         else:
             # Forward computation
-            log_det_W = batch_size * n_of_groups * torch.logdet(W.unsqueeze(0).float()).squeeze()
+            log_det_W = (
+                batch_size
+                * n_of_groups
+                * torch.logdet(W.unsqueeze(0).float()).squeeze()
+            )
             z = self.conv(z)
             return z, log_det_W
 
@@ -90,19 +98,20 @@ class WN(torch.nn.Module):
     also no dilation size reset. The dilation only doubles on each layer
     """
 
-    def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
-                 kernel_size):
+    def __init__(
+        self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size
+    ):
         super(WN, self).__init__()
-        assert(kernel_size % 2 == 1)
-        assert(n_channels % 2 == 0)
+        assert kernel_size % 2 == 1
+        assert n_channels % 2 == 0
         self.n_layers = n_layers
         self.n_channels = n_channels
         self.in_layers = torch.nn.ModuleList()
         self.res_skip_layers = torch.nn.ModuleList()
         self.cond_layers = torch.nn.ModuleList()
 
         start = torch.nn.Conv1d(n_in_channels, n_channels, 1)
-        start = torch.nn.utils.weight_norm(start, name='weight')
+        start = torch.nn.utils.weight_norm(start, name="weight")
         self.start = start
 
         # Initializing last layer to 0 makes the affine coupling layers
@@ -113,15 +122,20 @@ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
         self.end = end
 
         for i in range(n_layers):
-            dilation = 2 ** i
+            dilation = 2**i
             padding = int((kernel_size * dilation - dilation) / 2)
-            in_layer = torch.nn.Conv1d(n_channels, 2 * n_channels, kernel_size,
-                                       dilation=dilation, padding=padding)
-            in_layer = torch.nn.utils.weight_norm(in_layer, name='weight')
+            in_layer = torch.nn.Conv1d(
+                n_channels,
+                2 * n_channels,
+                kernel_size,
+                dilation=dilation,
+                padding=padding,
+            )
+            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
             self.in_layers.append(in_layer)
 
             cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels, 1)
-            cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight')
+            cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
             self.cond_layers.append(cond_layer)
 
             # last one is not necessary
@@ -130,8 +144,7 @@ def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels,
             else:
                 res_skip_channels = n_channels
             res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1)
-            res_skip_layer = torch.nn.utils.weight_norm(
-                res_skip_layer, name='weight')
+            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
             self.res_skip_layers.append(res_skip_layer)
 
     def forward(self, forward_input):
@@ -142,12 +155,13 @@ def forward(self, forward_input):
             acts = fused_add_tanh_sigmoid_multiply(
                 self.in_layers[i](audio),
                 self.cond_layers[i](spect),
-                torch.IntTensor([self.n_channels]))
+                torch.IntTensor([self.n_channels]),
+            )
 
             res_skip_acts = self.res_skip_layers[i](acts)
             if i < self.n_layers - 1:
-                audio = res_skip_acts[:, :self.n_channels, :] + audio
-                skip_acts = res_skip_acts[:, self.n_channels:, :]
+                audio = res_skip_acts[:, : self.n_channels, :] + audio
+                skip_acts = res_skip_acts[:, self.n_channels :, :]
             else:
                 skip_acts = res_skip_acts
 
@@ -159,14 +173,15 @@ def forward(self, forward_input):
 
 
 class WaveGlow(torch.nn.Module):
-    def __init__(self, n_mel_channels, n_flows, n_group, n_early_every,
-                 n_early_size, WN_config):
+    def __init__(
+        self, n_mel_channels, n_flows, n_group, n_early_every, n_early_size, WN_config
+    ):
         super(WaveGlow, self).__init__()
 
-        self.upsample = torch.nn.ConvTranspose1d(n_mel_channels,
-                                                 n_mel_channels,
-                                                 1024, stride=256)
-        assert(n_group % 2 == 0)
+        self.upsample = torch.nn.ConvTranspose1d(
+            n_mel_channels, n_mel_channels, 1024, stride=256
+        )
+        assert n_group % 2 == 0
         self.n_flows = n_flows
         self.n_group = n_group
         self.n_early_every = n_early_every
@@ -196,9 +211,9 @@ def forward(self, forward_input):
 
         # Upsample spectrogram to size of audio
         spect = self.upsample(spect)
-        assert(spect.size(2) >= audio.size(1))
+        assert spect.size(2) >= audio.size(1)
         if spect.size(2) > audio.size(1):
-            spect = spect[:, :, :audio.size(1)]
+            spect = spect[:, :, : audio.size(1)]
 
         spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3)
         spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
@@ -211,8 +226,8 @@ def forward(self, forward_input):
 
         for k in range(self.n_flows):
             if k % self.n_early_every == 0 and k > 0:
-                output_audio.append(audio[:, :self.n_early_size, :])
-                audio = audio[:, self.n_early_size:, :]
+                output_audio.append(audio[:, : self.n_early_size, :])
+                audio = audio[:, self.n_early_size :, :]
 
             audio, log_det_W = self.convinv[k](audio)
             log_det_W_list.append(log_det_W)
@@ -233,7 +248,6 @@ def forward(self, forward_input):
         return torch.cat(output_audio, 1), log_s_list, log_det_W_list
 
     def infer(self, spect, sigma=1.0):
-
         spect = self.upsample(spect)
         # trim conv artifacts. maybe pad spec to kernel multiple
         time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0]
@@ -243,9 +257,9 @@ def infer(self, spect, sigma=1.0):
         spect = spect.contiguous().view(spect.size(0), spect.size(1), -1)
         spect = spect.permute(0, 2, 1)
 
-        audio = torch.randn(spect.size(0),
-                            self.n_remaining_channels,
-                            spect.size(2), device=spect.device).to(spect.dtype)
+        audio = torch.randn(
+            spect.size(0), self.n_remaining_channels, spect.size(2), device=spect.device
+        ).to(spect.dtype)
 
         audio = torch.autograd.Variable(sigma * audio)
 
@@ -263,16 +277,14 @@ def infer(self, spect, sigma=1.0):
             audio = self.convinv[k](audio, reverse=True)
 
             if k % self.n_early_every == 0 and k > 0:
-                z = torch.randn(spect.size(0), self.n_early_size, spect.size(
-                    2), device=spect.device).to(spect.dtype)
+                z = torch.randn(
+                    spect.size(0), self.n_early_size, spect.size(2), device=spect.device
+                ).to(spect.dtype)
                 audio = torch.cat((sigma * z, audio), 1)
 
-        audio = audio.permute(
-            0, 2, 1).contiguous().view(
-            audio.size(0), -1).data
+        audio = audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
 
         return audio
 
-
     @staticmethod
    def remove_weightnorm(model):
        waveglow = model
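
Note on the one API change in this diff (the rest is black-style reformatting): `torch.qr` is deprecated in favor of `torch.linalg.qr` as of PyTorch 1.8. `torch.linalg.qr` returns a `(Q, R)` named tuple, so indexing `[0]` still yields the orthonormal factor and the weight initialization is unchanged. A minimal standalone sketch of the new call (assumes PyTorch >= 1.8; `c = 8` is a placeholder channel count, not from the diff):

import torch

c = 8  # placeholder for the Conv1d channel count used in __init__
# torch.linalg.qr returns a (Q, R) named tuple; [0] selects Q, as in the diff
W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0]
# Q is orthonormal, the property the initialization relies on: Q @ Q.T == I
assert torch.allclose(W @ W.t(), torch.eye(c), atol=1e-5)
# and as in __init__, flipping one column forces det(W) == +1 when it is -1
if torch.det(W) < 0:
    W[:, 0] = -1 * W[:, 0]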