 from torch import nn
 import torch.nn.functional as F
 import numpy as np
-
+import os.path

 def new_size_conv(size, kernel, stride=1, padding=0):
     return np.floor((size + 2 * padding - (kernel - 1) - 1) / stride + 1)
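
# A quick sanity check of new_size_conv (the numbers below are illustrative,
# not taken from the models in this diff): a 3x3 convolution with padding 1
# keeps the spatial size at stride 1 and roughly halves it at stride 2.
#
#   new_size_conv(64, 3, stride=1, padding=1)  # -> 64.0
#   new_size_conv(64, 3, stride=2, padding=1)  # -> 32.0
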
@@ -272,7 +272,85 @@ def forward(self, x):

         return out

-
+
+class audio_cnn_block(nn.Module):
+    '''
+    1D convolution block used to build audio CNN classifiers.
+    Args:
+        n_input: input channels
+        n_out: output channels
+        kernel_size: convolution kernel size
+    '''
+    def __init__(self, n_input, n_out, kernel_size):
+        super(audio_cnn_block, self).__init__()
+        self.cnn_block = nn.Sequential(
+            nn.Conv1d(n_input, n_out, kernel_size, padding=1),
+            nn.BatchNorm1d(n_out),
+            nn.ReLU(),
+            nn.MaxPool1d(kernel_size=4, stride=4)
+        )
+
+    def forward(self, x):
+        return self.cnn_block(x)
+
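# Shape sketch for audio_cnn_block (hypothetical sizes, assuming kernel_size=3
# so the padded convolution preserves the time axis and only the pooling
# shrinks it by 4x):
#
#   block = audio_cnn_block(n_input=20, n_out=40, kernel_size=3)
#   block(torch.randn(8, 20, 64)).shape   # -> torch.Size([8, 40, 16])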
+
+class audio_tiny_cnn(nn.Module):
+    '''
+    Template for convolutional audio classifiers.
+    '''
+    def __init__(self, cnn_sizes, n_hidden, kernel_size, n_classes):
+        '''
+        Init
+        Args:
+            cnn_sizes: list of channel sizes for the convolution blocks;
+                the last entry is the flattened feature size fed to the
+                fully connected layer
+            n_hidden: number of hidden units in the first fully connected layer
+            kernel_size: convolution kernel size
+            n_classes: number of speakers to classify
+        '''
+        super(audio_tiny_cnn, self).__init__()
+        self.down_path = nn.ModuleList()
+        self.down_path.append(audio_cnn_block(cnn_sizes[0], cnn_sizes[1],
+                                              kernel_size))
+        self.down_path.append(audio_cnn_block(cnn_sizes[1], cnn_sizes[2],
+                                              kernel_size))
+        self.down_path.append(audio_cnn_block(cnn_sizes[2], cnn_sizes[3],
+                                              kernel_size))
+        self.fc = nn.Sequential(
+            nn.Linear(cnn_sizes[4], n_hidden),
+            nn.ReLU()
+        )
+        self.out = nn.Linear(n_hidden, n_classes)
+
+    def forward(self, x):
+        for down in self.down_path:
+            x = down(x)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return self.out(x)
+
+
+def MFCC_cnn_classifier(n_classes):
+    '''
+    Builds a speaker classifier that ingests MFCCs.
+    '''
+    in_size = 20
+    n_hidden = 512
+    sizes_list = [in_size, 2 * in_size, 4 * in_size, 8 * in_size, 8 * in_size]
+    return audio_tiny_cnn(cnn_sizes=sizes_list, n_hidden=n_hidden,
+                          kernel_size=3, n_classes=n_classes)
+
+
+def ft_cnn_classifer(n_classes):
+    '''
+    Builds a speaker classifier that ingests the absolute value of Fourier
+    transforms.
+    '''
+    in_size = 94
+    n_hidden = 512
+    sizes_list = [in_size, in_size, 2 * in_size, 4 * in_size, 14 * 4 * in_size]
+    return audio_tiny_cnn(cnn_sizes=sizes_list, n_hidden=n_hidden,
+                          kernel_size=7, n_classes=n_classes)
+
+
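# A minimal usage sketch for the two factory functions above; the helper name,
# batch size, and time lengths are hypothetical, chosen so the flattened
# features match cnn_sizes[4] (8 * 20 = 160 for MFCCs, 14 * 4 * 94 for FTs).
def _example_audio_classifiers():
    mfcc_model = MFCC_cnn_classifier(n_classes=125)
    ft_model = ft_cnn_classifer(n_classes=125)
    mfcc_logits = mfcc_model(torch.randn(8, 20, 64))   # -> shape (8, 125)
    ft_logits = ft_model(torch.randn(8, 94, 980))      # -> shape (8, 125)
    return mfcc_logits.shape, ft_logits.shape
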
 def weights_init(m):
     if isinstance(m, nn.Conv2d):
         nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
@@ -285,8 +363,10 @@ def weights_init(m):
         nn.init.xavier_normal_(m.weight.data)
         nn.init.constant_(m.bias, 0)

-def save_checkpoint(model=None, optimizer=None, epoch=None, data_descriptor=None, loss=None,
-                    accuracy=None, path='./', filename='checkpoint', ext='.pth.tar'):
+
+def save_checkpoint(model=None, optimizer=None, epoch=None,
+                    data_descriptor=None, loss=None, accuracy=None, path='./',
+                    filename='checkpoint', ext='.pth.tar'):
     state = {
         'epoch': epoch,
         'arch': str(model.type),
@@ -297,3 +377,16 @@ def save_checkpoint(model = None, optimizer = None, epoch = None, data_descripto
         'dataset': data_descriptor
     }
     torch.save(state, path + filename + ext)
+
+
+def load_checkpoint(model=None, optimizer=None, checkpoint=None):
+    assert os.path.isfile(checkpoint), 'Checkpoint not found, aborting load'
+    chpt = torch.load(checkpoint)
+    assert str(model.type) == chpt['arch'], \
+        'Model architecture mismatch, aborting load'
+    model.load_state_dict(chpt['state_dict'])
+    if optimizer is not None:
+        optimizer.load_state_dict(chpt['optimizer'])
+    print('Successfully loaded checkpoint\nDataset: %s\nEpoch: %s\nLoss: %s'
+          '\nAccuracy: %s' % (chpt['dataset'], chpt['epoch'], chpt['loss'],
+                              chpt['accuracy']))
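
# A minimal checkpoint round-trip sketch; the helper name, optimizer choice,
# file name, and metric values are illustrative placeholders.
def _example_checkpoint_roundtrip():
    model = MFCC_cnn_classifier(n_classes=125)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    save_checkpoint(model=model, optimizer=optimizer, epoch=10,
                    data_descriptor='toy-dataset', loss=0.42, accuracy=0.9,
                    path='./', filename='example_checkpoint')
    load_checkpoint(model=model, optimizer=optimizer,
                    checkpoint='./example_checkpoint.pth.tar')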