This repository was archived by the owner on Aug 21, 2020. It is now read-only.

Commit abdf643

Merge pull request #46 from mlomnitz/master
Personalized transformation and functionality
2 parents be1abdc + 418b0c3 commit abdf643

File tree

2 files changed: +145 −4 lines changed

Utils/models.py (97 additions, 4 deletions)
@@ -2,7 +2,7 @@
 from torch import nn
 import torch.nn.functional as F
 import numpy as np
-
+import os.path
 
 def new_size_conv(size, kernel, stride=1, padding=0):
     return np.floor((size + 2*padding - (kernel - 1) - 1)/stride + 1)
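For context, new_size_conv above implements the standard convolution output-size formula, floor((size + 2*padding - (kernel - 1) - 1)/stride + 1). Two quick checks with illustrative values (not part of the commit):

    new_size_conv(64, 3, stride=1, padding=1)   # padding preserves the size: 64.0
    new_size_conv(64, 3, stride=2, padding=1)   # stride 2 halves it: 32.0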
@@ -272,7 +272,85 @@ def forward(self, x):
 
         return out
 
-
+
+
+class audio_cnn_block(nn.Module):
+    '''
+    1D convolution block used to build audio CNN classifiers
+    Args:
+        n_input: input channels
+        n_out: output channels
+        kernel_size: convolution kernel size
+    '''
+    def __init__(self, n_input, n_out, kernel_size):
+        super(audio_cnn_block, self).__init__()
+        self.cnn_block = nn.Sequential(
+            nn.Conv1d(n_input, n_out, kernel_size, padding=1),
+            nn.BatchNorm1d(n_out),
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=4, stride=4)
+        )
+
+    def forward(self, x):
+        return self.cnn_block(x)
+
+
+class audio_tiny_cnn(nn.Module):
+    '''
+    Template for convolutional audio classifiers.
+    '''
+    def __init__(self, cnn_sizes, n_hidden, kernel_size, n_classes):
+        '''
+        Init
+        Args:
+            cnn_sizes: list of channel sizes for the convolution blocks
+            n_hidden: number of hidden units in the first fully connected layer
+            kernel_size: convolution kernel size
+            n_classes: number of speakers to classify
+        '''
+        super(audio_tiny_cnn, self).__init__()
+        self.down_path = nn.ModuleList()
+        self.down_path.append(audio_cnn_block(cnn_sizes[0], cnn_sizes[1],
+                                              kernel_size))
+        self.down_path.append(audio_cnn_block(cnn_sizes[1], cnn_sizes[2],
+                                              kernel_size))
+        self.down_path.append(audio_cnn_block(cnn_sizes[2], cnn_sizes[3],
+                                              kernel_size))
+        self.fc = nn.Sequential(
+            nn.Linear(cnn_sizes[4], n_hidden),
+            nn.ReLU()
+        )
+        self.out = nn.Linear(n_hidden, n_classes)
+
+    def forward(self, x):
+        for down in self.down_path:
+            x = down(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return self.out(x)
+
+
+def MFCC_cnn_classifier(n_classes):
+    '''
+    Builds a speaker classifier that ingests MFCCs.
+    '''
+    in_size = 20
+    n_hidden = 512
+    sizes_list = [in_size, 2*in_size, 4*in_size, 8*in_size, 8*in_size]
+    return audio_tiny_cnn(cnn_sizes=sizes_list, n_hidden=n_hidden,
+                          kernel_size=3, n_classes=n_classes)
+
+
+def ft_cnn_classifier(n_classes):
+    '''
+    Builds a speaker classifier that ingests the absolute value of Fourier
+    transforms.
+    '''
+    in_size = 94
+    n_hidden = 512
+    sizes_list = [in_size, in_size, 2*in_size, 4*in_size, 14*4*in_size]
+    return audio_tiny_cnn(cnn_sizes=sizes_list, n_hidden=n_hidden,
+                          kernel_size=7, n_classes=n_classes)
+
+
 def weights_init(m):
     if isinstance(m, nn.Conv2d):
         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
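A minimal smoke test for the new classifier factory, with shapes that are illustrative assumptions rather than part of the commit: the three MaxPool1d(kernel_size=4, stride=4) stages reduce a 64-frame input 64 -> 16 -> 4 -> 1, so the flattened features match cnn_sizes[4] = 8*20 = 160 entering the fully connected layer.

    import torch

    model = MFCC_cnn_classifier(n_classes=125)
    x = torch.randn(8, 20, 64)   # (batch, n_mfcc, frames); 64 frames pools down to exactly 1
    logits = model(x)            # -> torch.Size([8, 125])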
@@ -285,8 +363,10 @@ def weights_init(m):
         nn.init.xavier_normal_(m.weight.data)
         nn.init.constant_(m.bias, 0)
 
-def save_checkpoint(model = None, optimizer = None, epoch = None, data_descriptor = None, loss = None,
-                    accuracy = None, path = './', filename='checkpoint', ext = '.pth.tar'):
+
+def save_checkpoint(model=None, optimizer=None, epoch=None,
+                    data_descriptor=None, loss=None, accuracy=None, path='./',
+                    filename='checkpoint', ext='.pth.tar'):
     state = {
         'epoch': epoch,
         'arch': str(model.type),
@@ -297,3 +377,16 @@ def save_checkpoint(model=None, optimizer=None, epoch=None,
         'dataset': data_descriptor
     }
     torch.save(state, path+filename+ext)
+
+
+def load_checkpoint(model=None, optimizer=None, checkpoint=None):
+    assert os.path.isfile(checkpoint), 'Checkpoint not found, aborting load'
+    chpt = torch.load(checkpoint)
+    assert str(model.type) == chpt['arch'], \
+        'Model architecture mismatch, aborting load'
+    model.load_state_dict(chpt['state_dict'])
+    if optimizer is not None:
+        optimizer.load_state_dict(chpt['optimizer'])
+    print('Successfully loaded checkpoint \nDataset: %s \nEpoch: %s \nLoss: %s'
+          '\nAccuracy: %s' % (chpt['dataset'], chpt['epoch'], chpt['loss'],
+                              chpt['accuracy']))
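A hedged sketch of how the two checkpoint helpers pair up; the file name and training metadata below are illustrative assumptions:

    model = MFCC_cnn_classifier(n_classes=125)
    optimizer = torch.optim.Adam(model.parameters())
    save_checkpoint(model=model, optimizer=optimizer, epoch=10,
                    data_descriptor='LibriSpeech', loss=0.5, accuracy=0.9,
                    path='./', filename='speaker_cnn')
    load_checkpoint(model=model, optimizer=optimizer,
                    checkpoint='./speaker_cnn.pth.tar')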

Utils/transformations.py (48 additions, 0 deletions)
@@ -0,0 +1,48 @@
+import torch
+import librosa as libr
+import numpy as np
+
+
+class ToMFCC:
+    '''
+    Transformation to convert a soundfile loaded via LibriSpeechDataset to
+    Mel-frequency cepstral coefficients (MFCCs)
+    Args:
+        number_of_mels: number of Mel bands used for the cepstral coefficients
+    Returns:
+        torch.float tensor
+    '''
+    def __init__(self, number_of_mels=128):
+        self.number_of_mels = number_of_mels
+
+    def __call__(self, y):
+        dims = y.shape
+        y = libr.feature.melspectrogram(np.reshape(y, (dims[1],)), 16000,
+                                        n_mels=self.number_of_mels, fmax=8000)
+        y = libr.feature.mfcc(S=libr.power_to_db(y))
+        y = torch.from_numpy(y)
+        return y.float()
+
+
+class STFT:
+    '''
+    Short-time Fourier transform (STFT) for the librosa dataset
+    Args:
+        phase: if True, returns the magnitude and phase of the transformation;
+        if False, returns only the magnitude
+    Returns:
+        torch.float tensor
+    '''
+    def __init__(self, phase=False):
+        self.phase = phase
+
+    def __call__(self, y):
+        dims = y.shape
+        y = libr.core.stft(np.reshape(y, (dims[1],)))
+        y, phase = np.abs(y), np.angle(y)
+        y = torch.from_numpy(y).permute(1, 0)
+        phase = torch.from_numpy(phase).permute(1, 0)
+        if self.phase:
+            return torch.cat((y, phase), dim=0).float()
+        else:
+            return y.float()
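A quick way to exercise the new transforms; the 2-second 16 kHz buffer is an assumption, and the (1, n_samples) shape matches the reshape both __call__ methods perform:

    import numpy as np

    y = np.random.randn(1, 32000).astype(np.float32)   # ~2 s of 16 kHz audio
    mfcc = ToMFCC(number_of_mels=128)(y)   # (20, frames) with librosa's default 20 MFCCs
    spec = STFT(phase=True)(y)             # (2*frames, 1025): magnitude stacked on phase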
