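VGGish neural network
Since R2020b
Syntax
net = vggish
Description
net = vggish returns a pretrained VGGish neural network as a SeriesNetwork (Deep Learning Toolbox) object.
Examples
Download VGGish Network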
Download and unzip the Audio Toolbox™ model for VGGish.
Type vggish at the Command Window. If the Audio Toolbox model for VGGish is not installed, then the function provides a link to the location of the network weights. To download the model, click the link. Unzip the file to a location on the MATLAB path.
Alternatively, execute these commands to download and unzip the VGGish model to your temporary directory.
downloadfolder = fullfile(tempdir,'vggishdownload');
loc = websave(downloadfolder,'https://ssd.mathworks.com/supportfiles/audio/vggish.zip');
vggishlocation = tempdir;
unzip(loc,vggishlocation)
addpath(fullfile(vggishlocation,'vggish'))
Check that the installation is successful by typing vggish at the Command Window. If the network is installed, then the function returns a SeriesNetwork (Deep Learning Toolbox) object.
vggish
ans = 
  SeriesNetwork with properties:

         Layers: [24×1 nnet.cnn.layer.Layer]
     InputNames: {'InputBatch'}
    OutputNames: {'regressionOutput'}
Load Pretrained VGGish Network
Load a pretrained VGGish convolutional neural network and examine its layers.
Use vggish to load the pretrained VGGish network. The output net is a SeriesNetwork (Deep Learning Toolbox) object.
net = vggish
net = 
  SeriesNetwork with properties:

         Layers: [24×1 nnet.cnn.layer.Layer]
     InputNames: {'InputBatch'}
    OutputNames: {'regressionOutput'}
View the network architecture using the Layers property. The network has 24 layers. There are nine layers with learnable weights, of which six are convolutional layers and three are fully connected layers.
net.Layers
ans = 
  24×1 Layer array with layers:

     1   'InputBatch'         Image Input         96×64×1 images
     2   'conv1'              Convolution         64 3×3×1 convolutions with stride [1 1] and padding 'same'
     3   'relu'               ReLU                ReLU
     4   'pool1'              Max Pooling         2×2 max pooling with stride [2 2] and padding 'same'
     5   'conv2'              Convolution         128 3×3×64 convolutions with stride [1 1] and padding 'same'
     6   'relu2'              ReLU                ReLU
     7   'pool2'              Max Pooling         2×2 max pooling with stride [2 2] and padding 'same'
     8   'conv3_1'            Convolution         256 3×3×128 convolutions with stride [1 1] and padding 'same'
     9   'relu3_1'            ReLU                ReLU
    10   'conv3_2'            Convolution         256 3×3×256 convolutions with stride [1 1] and padding 'same'
    11   'relu3_2'            ReLU                ReLU
    12   'pool3'              Max Pooling         2×2 max pooling with stride [2 2] and padding 'same'
    13   'conv4_1'            Convolution         512 3×3×256 convolutions with stride [1 1] and padding 'same'
    14   'relu4_1'            ReLU                ReLU
    15   'conv4_2'            Convolution         512 3×3×512 convolutions with stride [1 1] and padding 'same'
    16   'relu4_2'            ReLU                ReLU
    17   'pool4'              Max Pooling         2×2 max pooling with stride [2 2] and padding 'same'
    18   'fc1_1'              Fully Connected     4096 fully connected layer
    19   'relu5_1'            ReLU                ReLU
    20   'fc1_2'              Fully Connected     4096 fully connected layer
    21   'relu5_2'            ReLU                ReLU
    22   'fc2'                Fully Connected     128 fully connected layer
    23   'EmbeddingBatch'     ReLU                ReLU
    24   'regressionOutput'   Regression Output   mean-squared-error
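The layer listing also implies how far the 96×64 input shrinks before it reaches the first fully connected layer. The arithmetic below is a sketch under two assumptions: each 'same'-padded 2×2 max pooling layer with stride 2 maps a spatial size n to ceil(n/2), and the 'same'-padded convolutions leave the spatial size unchanged. The variable names are illustrative and the code does not require the pretrained model.

```matlab
% Trace the spatial size of the feature maps through the four pooling layers.
% Assumption: 'same' padding with stride 2 gives ceil(n/2) per dimension.
inputSize = [96 64];      % height-by-width of one mel spectrogram input
numPoolLayers = 4;        % pool1, pool2, pool3, pool4
numChannels = 512;        % number of filters in conv4_2, the last convolution

mapSize = inputSize;
for k = 1:numPoolLayers
    mapSize = ceil(mapSize/2);
    fprintf("After pool%d: %d-by-%d\n",k,mapSize(1),mapSize(2));
end

% Number of activations flattened into the first fully connected layer.
numFcInputs = prod(mapSize)*numChannels;
fprintf("Inputs to fc1_1: %d\n",numFcInputs);
```

Under these assumptions, the 6-by-4-by-512 activation volume after pool4 is what fc1_1 maps to 4096 outputs.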
Use analyzeNetwork (Deep Learning Toolbox) to visually explore the network.
analyzeNetwork(net)
Extract Features Using VGGish
Read in an audio signal to extract feature embeddings from it.
[audioin,fs] = audioread("Ambiance-16-44p1-mono-12secs.wav");
Plot and listen to the audio signal.
t = (0:numel(audioin)-1)/fs;
plot(t,audioin)
xlabel("Time (s)")
ylabel("Amplitude")
axis tight
% To play the sound, call soundsc(audioin,fs)
VGGish requires you to preprocess the audio signal to match the input format used to train the network. The preprocessing steps include resampling the audio signal and computing an array of mel spectrograms. Use vggishPreprocess to preprocess the signal and extract the mel spectrograms to be passed to VGGish. Visualize one of these spectrograms chosen at random.
spectrograms = vggishPreprocess(audioin,fs);
arbitraryspect = spectrograms(:,:,1,randi(size(spectrograms,4)));
surf(arbitraryspect,EdgeColor="none")
view(90,-90)
xlabel("Mel Band")
ylabel("Frame")
title("Mel Spectrogram for VGGish")
axis tight
Create a VGGish neural network. Using the vggish function requires installing the pretrained VGGish network. If the network is not installed, the function provides a link to download the pretrained model.
net = vggish;
Call predict with the network on the preprocessed mel spectrogram images to extract feature embeddings. The feature embeddings are returned as a numframes-by-128 matrix, where numframes is the number of individual spectrograms and 128 is the number of elements in each feature vector.
features = predict(net,spectrograms);
[numframes,numfeatures] = size(features)
numframes = 24
numfeatures = 128
Visualize the VGGish feature embeddings.
surf(features,EdgeColor="none")
view([90 -90])
xlabel("Feature")
ylabel("Frame")
title("VGGish Feature Embeddings")
axis tight
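Because predict returns one 128-element embedding per spectrogram, a downstream task that needs a single vector per recording has to combine the rows in some way. The sketch below averages across frames, which is only one possible choice and is not something this page prescribes. The embeddings are random placeholders standing in for the output of predict, so the code runs without the pretrained model, and clipEmbedding is an illustrative name.

```matlab
% Placeholder embeddings standing in for the output of predict(net,spectrograms).
% Assumption: 24 frames by 128 features, matching the sizes shown in the example.
rng(0)                                % make the placeholder data reproducible
features = rand(24,128,"single");

% Average across frames (dimension 1) to get one 128-element vector per clip.
clipEmbedding = mean(features,1);

% Confirm the size of the pooled embedding.
disp(size(clipEmbedding))
```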
Transfer Learning Using VGGish
In this example, you transfer the learning in the VGGish regression model to an audio classification task.
Download and unzip the environmental sound classification data set. This data set consists of recordings labeled as one of 10 different audio sound classes (ESC-10).
downloadfolder = matlab.internal.examples.downloadSupportFile("audio","ESC-10.zip");
unzip(downloadfolder,tempdir)
datalocation = fullfile(tempdir,"ESC-10");
Create an audioDatastore object to manage the data and split it into train and validation sets. Call countEachLabel to display the distribution of sound classes and the number of unique labels.
ads = audioDatastore(datalocation,IncludeSubfolders=true,LabelSource="foldernames");
labeltable = countEachLabel(ads)
labeltable=10×2 table
        Label         Count
    ______________    _____
    chainsaw           40
    clock_tick         40
    crackling_fire     40
    crying_baby        40
    dog                40
    helicopter         40
    rain               40
    rooster            38
    sea_waves          40
    sneezing           40
Determine the total number of classes.
numclasses = height(labeltable);
Call splitEachLabel to split the data set into train and validation sets. Inspect the distribution of labels in the training and validation sets.
[adstrain, adsvalidation] = splitEachLabel(ads,0.8);
countEachLabel(adstrain)
ans=10×2 table
        Label         Count
    ______________    _____
    chainsaw           32
    clock_tick         32
    crackling_fire     32
    crying_baby        32
    dog                32
    helicopter         32
    rain               32
    rooster            30
    sea_waves          32
    sneezing           32
countEachLabel(adsvalidation)
ans=10×2 table
        Label         Count
    ______________    _____
    chainsaw            8
    clock_tick          8
    crackling_fire      8
    crying_baby         8
    dog                 8
    helicopter          8
    rain                8
    rooster             8
    sea_waves           8
    sneezing            8
The VGGish network expects audio to be preprocessed into log mel spectrograms. Use vggishPreprocess to extract the spectrograms from the train set. There are multiple spectrograms for each audio signal. Replicate the labels so that they are in one-to-one correspondence with the spectrograms.
overlappercentage = 75;
trainfeatures = [];
trainlabels = [];
while hasdata(adstrain)
    [audioin,fileinfo] = read(adstrain);
    features = vggishPreprocess(audioin,fileinfo.SampleRate,OverlapPercentage=overlappercentage);
    numspectrograms = size(features,4);
    trainfeatures = cat(4,trainfeatures,features);
    trainlabels = cat(2,trainlabels,repelem(fileinfo.Label,numspectrograms));
end
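The bookkeeping in this loop can be checked without audio files or the pretrained model. The sketch below uses zero-filled placeholder arrays in place of the output of vggishPreprocess, assuming that output is a 96-by-64-by-1-by-N array per file. The names fileFeatures, fileLabels, allFeatures, and allLabels are illustrative.

```matlab
% Placeholder spectrogram batches for two files (not real audio features).
% Assumption: each file yields a 96-by-64-by-1-by-N array of spectrograms.
fileFeatures = {zeros(96,64,1,3,"single"), zeros(96,64,1,5,"single")};
fileLabels = categorical(["dog" "rain"]);

allFeatures = [];
allLabels = [];
for k = 1:numel(fileFeatures)
    numSpectrograms = size(fileFeatures{k},4);
    % Stack spectrograms along the fourth (observation) dimension.
    allFeatures = cat(4,allFeatures,fileFeatures{k});
    % Repeat the file label once per spectrogram so the two stay aligned.
    allLabels = cat(2,allLabels,repelem(fileLabels(k),numSpectrograms));
end

% One label per spectrogram.
fprintf("%d spectrograms, %d labels\n",size(allFeatures,4),numel(allLabels));
```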
Extract spectrograms from the validation set and replicate the labels.
validationfeatures = [];
validationlabels = [];
segmentsperfile = zeros(numel(adsvalidation.Files), 1);
idx = 1;
while hasdata(adsvalidation)
    [audioin,fileinfo] = read(adsvalidation);
    features = vggishPreprocess(audioin,fileinfo.SampleRate,OverlapPercentage=overlappercentage);
    numspectrograms = size(features,4);
    validationfeatures = cat(4,validationfeatures,features);
    validationlabels = cat(2,validationlabels,repelem(fileinfo.Label,numspectrograms));
    segmentsperfile(idx) = numspectrograms;
    idx = idx + 1;
end
Load the VGGish model and convert it to a layerGraph (Deep Learning Toolbox) object.
net = vggish;
lgraph = layerGraph(net.Layers);
Use removeLayers (Deep Learning Toolbox) to remove the final regression output layer from the graph. After you remove the regression layer, the new final layer of the graph is a ReLU layer named 'EmbeddingBatch'.
lgraph = removeLayers(lgraph,"regressionOutput");
lgraph.Layers(end)
ans = 
  ReLULayer with properties:

    Name: 'EmbeddingBatch'
Use addLayers (Deep Learning Toolbox) to add a fullyConnectedLayer (Deep Learning Toolbox), a softmaxLayer (Deep Learning Toolbox), and a classificationLayer (Deep Learning Toolbox) to the graph. Set the WeightLearnRateFactor and BiasLearnRateFactor of the new fully connected layer to 10 so that learning is faster in the new layer than in the transferred layers.
lgraph = addLayers(lgraph,[ ...
    fullyConnectedLayer(numclasses,Name="fcfinal",WeightLearnRateFactor=10,BiasLearnRateFactor=10)
    softmaxLayer(Name="softmax")
    classificationLayer(Name="classout")]);
Use connectLayers (Deep Learning Toolbox) to append the fully connected, softmax, and classification layers to the layer graph.
lgraph = connectLayers(lgraph,"EmbeddingBatch","fcfinal");
To define training options, use trainingOptions (Deep Learning Toolbox).
minibatchsize = 128;
options = trainingOptions("adam", ...
    MaxEpochs=5, ...
    MiniBatchSize=minibatchsize, ...
    Shuffle="every-epoch", ...
    ValidationData={validationfeatures,validationlabels}, ...
    ValidationFrequency=50, ...
    LearnRateSchedule="piecewise", ...
    LearnRateDropFactor=0.5, ...
    LearnRateDropPeriod=2, ...
    OutputNetwork="best-validation-loss", ...
    Verbose=false, ...
    Plots="training-progress");
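To see what the piecewise schedule and the learn rate factor of 10 imply together, the sketch below tabulates the learning rate per epoch. It assumes the initial learning rate is 0.001, which is the documented default for the "adam" solver when InitialLearnRate is not set, and that the rate is multiplied by the drop factor once every LearnRateDropPeriod epochs. The variable names are illustrative and the code needs no toolbox.

```matlab
% Assumption: InitialLearnRate defaults to 0.001 for the "adam" solver.
initialLearnRate = 1e-3;
learnRateDropFactor = 0.5;    % LearnRateDropFactor from the options
learnRateDropPeriod = 2;      % LearnRateDropPeriod from the options
maxEpochs = 5;                % MaxEpochs from the options
newLayerFactor = 10;          % learn rate factor set on the new fully connected layer

% The global rate is multiplied by the drop factor after every drop period.
epochs = 1:maxEpochs;
globalRate = initialLearnRate*learnRateDropFactor.^floor((epochs-1)/learnRateDropPeriod);

% The new layer trains at the global rate scaled by its learn rate factor.
newLayerRate = newLayerFactor*globalRate;

disp(table(epochs.',globalRate.',newLayerRate.', ...
    VariableNames=["Epoch" "GlobalLearnRate" "NewLayerLearnRate"]))
```

Under these assumptions the transferred layers finish training at a rate of 0.00025 while the new layer finishes at 0.0025.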
To train the network, use trainNetwork (Deep Learning Toolbox).
[trainednet, netinfo] = trainNetwork(trainfeatures,trainlabels,lgraph,options);
Each audio file was split into several segments to feed into the VGGish network. Combine the predictions for each file in the validation set using a majority-rule decision.
validationpredictions = classify(trainednet,validationfeatures);
idx = 1;
validationpredictionsperfile = categorical;
for ii = 1:numel(adsvalidation.Files)
    validationpredictionsperfile(ii,1) = mode(validationpredictions(idx:idx+segmentsperfile(ii)-1));
    idx = idx + segmentsperfile(ii);
end
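The majority-rule step can be tried on its own with a handful of made-up predictions. The sketch below follows the same indexing pattern as the loop above but uses placeholder categorical values rather than network output. The names segmentPredictions, segmentsPerFile, and filePredictions are illustrative.

```matlab
% Placeholder per-segment predictions for two files (not real network output).
segmentPredictions = categorical(["dog" "dog" "rain" "rain" "rain" "dog" "rain"]);
segmentsPerFile = [3 4];      % file 1 has 3 segments, file 2 has 4

idx = 1;
filePredictions = categorical;
for ii = 1:numel(segmentsPerFile)
    % Segments belonging to file ii occupy a contiguous range.
    segmentRange = idx:idx+segmentsPerFile(ii)-1;
    % The most frequent label among the segments becomes the file label.
    filePredictions(ii,1) = mode(segmentPredictions(segmentRange));
    idx = idx + segmentsPerFile(ii);
end

disp(filePredictions)
```

Here the first file is labeled dog (two of three segments) and the second is labeled rain (three of four segments).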
Use confusionchart (Deep Learning Toolbox) to evaluate the performance of the network on the validation set.
figure(Units="normalized",Position=[0.2 0.2 0.5 0.5]);
confusionchart(adsvalidation.Labels,validationpredictionsperfile, ...
    Title=sprintf("Confusion Matrix for Validation Data \nAccuracy = %0.2f %%",mean(validationpredictionsperfile==adsvalidation.Labels)*100), ...
    ColumnSummary="column-normalized", ...
    RowSummary="row-normalized")
Output Arguments
net - Pretrained VGGish neural network
SeriesNetwork object
Pretrained VGGish neural network, returned as a SeriesNetwork (Deep Learning Toolbox) object.
Extended Capabilities
C/C++ Code Generation
Generate C and C++ code using MATLAB® Coder™.
Usage notes and limitations:
Only the activations and predict object functions are supported.
To create a SeriesNetwork object for code generation, see the MATLAB Coder documentation.
GPU Code Generation
Generate CUDA® code for NVIDIA® GPUs using GPU Coder™.
Usage notes and limitations:
Only the activations, classify, predict, predictAndUpdateState, and resetState object functions are supported.
To create a SeriesNetwork object for code generation, see the GPU Coder documentation.
Version History
Introduced in R2020b