Digit Classification with Wavelet Scattering
This example shows how to use wavelet scattering for image classification. This example requires Wavelet Toolbox™, Deep Learning Toolbox™, and Parallel Computing Toolbox™.
For classification problems, it is often useful to map the data into an alternative representation that discards irrelevant information while retaining the discriminative properties of each class. Wavelet image scattering constructs low-variance representations of images that are insensitive to translations and small deformations. Because translations and small deformations in the image do not affect class membership, scattering transform coefficients provide features from which you can build robust classification models.
Wavelet scattering works by cascading the image through a series of wavelet transforms, nonlinearities, and averaging [1][3][4]. The result of this deep feature extraction is that images in the same class are moved closer to each other in the scattering transform representation, while images belonging to different classes are moved farther apart. Although the wavelet scattering transform shares a number of architectural features with deep convolutional neural networks, including convolution operators, nonlinearities, and averaging, the filters in a scattering transform are predefined and fixed rather than learned.
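The cascade described above can be written compactly. The following is a sketch in the notation commonly used in [1]; the symbols are assumptions introduced here and do not appear elsewhere in this example: $x$ is the image, $\psi_{\lambda}$ is a wavelet indexed by scale and rotation $\lambda$, $\phi_J$ is the low-pass averaging filter at the invariance scale, and $*$ is convolution.

```latex
\begin{aligned}
S_0 x &= x * \phi_J \\
S_1 x(\lambda_1) &= \lvert x * \psi_{\lambda_1} \rvert * \phi_J \\
S_2 x(\lambda_1,\lambda_2) &= \bigl\lvert \lvert x * \psi_{\lambda_1} \rvert * \psi_{\lambda_2} \bigr\rvert * \phi_J
\end{aligned}
```

Each order applies one more wavelet transform and modulus before the final averaging, which is where the insensitivity to translation comes from.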
Digit Images
The dataset used in this example contains 10,000 synthetic images of digits from 0 to 9. The images are generated by applying random transformations to images of those digits created with different fonts. Each digit image is 28-by-28 pixels. The dataset contains an equal number of images per category. Use the imageDatastore to read the images.
digitdatasetpath = fullfile(matlabroot,'toolbox','nnet','nndemos','nndatasets','DigitDataset');
imds = imageDatastore(digitdatasetpath,'IncludeSubfolders',true, ...
    'LabelSource','foldernames');
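As a quick check of the dataset description, the following sketch counts the images per label and reads one image to confirm its size. It assumes the Deep Learning Toolbox sample data is installed at the path used in this example; the variable names labelcount and img are introduced here for illustration only.

```matlab
% Sanity check of the digit datastore (assumes the sample data is installed).
digitdatasetpath = fullfile(matlabroot,'toolbox','nnet','nndemos', ...
    'nndatasets','DigitDataset');
imds = imageDatastore(digitdatasetpath,'IncludeSubfolders',true, ...
    'LabelSource','foldernames');

% countEachLabel returns a table with the variables Label and Count.
labelcount = countEachLabel(imds);
disp(labelcount)

% Read a single image and display its dimensions.
img = readimage(imds,1);
disp(size(img))
```

If the dataset matches the description above, the table has ten rows with equal counts and the image is 28-by-28.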
Randomly select and plot 20 images from the dataset.
figure
numimages = 10000;
rng(100);
perm = randperm(numimages,20);
for np = 1:20
    subplot(4,5,np);
    imshow(imds.Files{perm(np)});
end
You can see that the 8s exhibit considerable variability while all being identifiable as an 8. The same is true of the other repeated digits in the sample. This is consistent with natural handwriting, where any digit differs non-trivially between individuals and even within the same individual's handwriting with respect to translation, rotation, and other small deformations. Using wavelet scattering, we hope to build representations of these digits that obscure this irrelevant variability.
Wavelet Image Scattering Feature Extraction
The synthetic images are 28-by-28. Create a wavelet image scattering framework and set the invariance scale to equal the size of the image. Set the number of rotations to 8 in each of the two wavelet scattering filter banks. The construction of the wavelet scattering framework requires that we set only two hyperparameters: the InvarianceScale and NumRotations.
sf = waveletScattering2('ImageSize',[28 28],'InvarianceScale',28, ...
    'NumRotations',[8 8]);
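To see why averaging over the whole image yields translation-insensitive features, here is a minimal 1-D toy sketch in base MATLAB. It is not what waveletScattering2 computes: the Gabor-like filter psi, the circular shift, and the global mean are all simplifying assumptions chosen so that the invariance is exact.

```matlab
% Toy 1-D analogue: band-pass filter -> modulus -> global average.
N = 64;
n = (0:N-1)';
x = double(n >= 20 & n < 30);     % a simple "stroke"
xShift = circshift(x,7);          % the same stroke, translated by 7 samples

% Hypothetical Gabor-like band-pass filter (illustrative parameters only).
psi = exp(-((n-N/2).^2)/(2*4^2)) .* exp(1i*2*pi*0.2*(n-N/2));

circConv = @(a,b) ifft(fft(a).*fft(b));   % circular convolution via the FFT
feat = @(s) mean(abs(circConv(s,psi)));   % modulus, then global averaging

f1 = feat(x);
f2 = feat(xShift);
fprintf('feature: %.6f, shifted: %.6f\n',f1,f2);
```

The two printed values agree to rounding error even though x and xShift differ sample by sample. Real images are not circularly shifted, so the insensitivity obtained from the toolbox is approximate rather than exact.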
This example uses MATLAB™'s parallel processing capability through the tall array interface. If a parallel pool is not currently running, you can start one by executing the following code. Alternatively, the first time you create a tall array, the parallel pool is created.
% Query the current pool without creating one, then start a pool if needed.
if isempty(gcp('nocreate'))
    parpool;
end
Starting parallel pool (parpool) using the 'local' profile ...
Connected to the parallel pool (number of workers: 6).
For reproducibility, set the random number generator. Shuffle the files of the imageDatastore and split the 10,000 images into two sets, one for training and one held-out set for testing. Allocate 80% of the data, or 8,000 images, to the training set and hold out the remaining 2,000 images for testing. Create tall arrays from the training and test datasets. Use the helper function helperscatimages to create feature vectors from the scattering transform coefficients. helperscatimages obtains the log of the scattering transform feature matrix and then takes the mean along both the row and column dimensions of each image. The code for helperscatimages is at the end of this example. For each image in this example, the helper function creates a 217-by-1 feature vector.
rng(10);
imds = shuffle(imds);
[trainimds,testimds] = splitEachLabel(imds,0.8);
ttrain = tall(trainimds);
ttest = tall(testimds);
trainfeatures = cellfun(@(x)helperscatimages(sf,x),ttrain,'UniformOutput',false);
testfeatures = cellfun(@(x)helperscatimages(sf,x),ttest,'UniformOutput',false);
Use tall's gather capability to concatenate all the training and test features.
trainf = gather(trainfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 3 min 51 sec
Evaluation completed in 3 min 51 sec
trainfeatures = cat(2,trainf{:});
testf = gather(testfeatures);
Evaluating tall expression using the Parallel Pool 'local':
- Pass 1 of 1: Completed in 49 sec
Evaluation completed in 49 sec
testfeatures = cat(2,testf{:});
The previous code results in two matrices with row dimension 217 and column dimension equal to the number of images in the training and test sets, respectively. Accordingly, each column is a feature vector for its corresponding image. The original images contain 784 elements each, so the scattering coefficients represent an approximately 4-fold (784/217 ≈ 3.6) reduction in the size of each image.
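The figure of 217 features can be reproduced with a short count. The sketch below uses the path-count formula for frequency-decreasing scattering paths from [1] and assumes three wavelet scales per filter bank. That number of scales is an assumption made here because it reproduces 217; this example does not state how waveletScattering2 selects its paths.

```matlab
% Count scattering paths under the stated assumption of three scales.
numRotations = 8;   % from 'NumRotations',[8 8]
numScales = 3;      % assumption: not stated in this example

order0 = 1;                                             % low-pass average only
order1 = numRotations*numScales;                        % one wavelet per path
order2 = numRotations^2 * numScales*(numScales-1)/2;    % two wavelets per path
numPaths = order0 + order1 + order2;

% Size reduction relative to the 28-by-28 input image.
reduction = (28*28)/numPaths;
fprintf('paths: %d, reduction: %.2f-fold\n',numPaths,reduction);
```

Under these assumptions the count splits into 1 zeroth-order, 24 first-order, and 192 second-order paths.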
PCA Model and Prediction
This example constructs a simple classifier based on the principal components of the scattering feature vectors for each class. The classifier is implemented in the functions helperpcamodel and helperpcaclassifier. helperpcamodel determines the principal components for each digit class based on the scattering features. helperpcaclassifier classifies the held-out test data by finding the closest match (best projection) between each test feature vector and the principal components of each class in the training set, and assigning the class accordingly. The code for both functions is at the end of this example.
model = helperpcamodel(trainfeatures,30,trainimds.Labels);
predlabels = helperpcaclassifier(testfeatures,model);
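The decision rule can be stated in one line. This is a reading of the helperpcaclassifier code at the end of this example, not a formula given in the text; the symbols are introduced here for illustration: $x$ is a test feature vector, $\mu_c$ is the mean feature vector of class $c$, and $U_c$ holds the retained principal directions of class $c$ (here, the first 30) as columns.

```latex
\hat{c}(x) = \arg\min_{c} \left( \lVert x - \mu_c \rVert^2 - \lVert U_c^{\top} (x - \mu_c) \rVert^2 \right)
```

The quantity in parentheses is the squared distance from $x$ to the affine subspace through $\mu_c$ spanned by the columns of $U_c$, so each test image is assigned to the class whose subspace lies closest to it.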
After constructing the model and classifying the test set, determine the accuracy of the test set classification.
accuracy = sum(testimds.Labels == predlabels)./numel(testimds.Labels)*100
accuracy = 99.6000
We have achieved 99.6% correct classification of the test data. To see how the 2,000 test images have been classified, plot the confusion matrix. There are 200 examples in the test set for each of the 10 classes.
figure;
confusionchart(testimds.Labels,predlabels)
title('Test-Set Confusion Matrix -- Wavelet Scattering')
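Overall accuracy can hide classes that perform poorly, so it can help to compute accuracy per class as well. The following is a minimal sketch in base MATLAB on a small made-up label set; truelabels and guesslabels are hypothetical stand-ins for testimds.Labels and predlabels.

```matlab
% Hypothetical labels for illustration only (not results from this example).
truelabels  = categorical([0 0 1 1 2 2]');
guesslabels = categorical([0 0 1 2 2 2]');

classes = categories(truelabels);
perclassacc = zeros(numel(classes),1);
for k = 1:numel(classes)
    % Fraction of the images in class k that were predicted as class k.
    idx = truelabels == classes{k};
    perclassacc(k) = mean(guesslabels(idx) == classes{k});
end

overallacc = mean(guesslabels == truelabels)*100;
disp(table(classes,perclassacc))
fprintf('overall accuracy: %.2f%%\n',overallacc);
```

Each entry of perclassacc corresponds to a diagonal entry of the confusion matrix divided by its row total.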
CNN
In this section, we train a simple convolutional neural network (CNN) to recognize digits. Construct the CNN to consist of a convolution layer with 20 5-by-5 filters and 1-by-1 strides. Follow the convolution layer with a ReLU activation and a max pooling layer. Use a fully connected layer, followed by a softmax layer to normalize the output of the fully connected layer to probabilities. Use a cross-entropy loss function for learning.
imagesize = [28 28 1];
layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(5,20)
    reluLayer
    maxPooling2dLayer(2,'Stride',2)
    fullyConnectedLayer(10)
    softmaxLayer
    classificationLayer];
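The activation sizes and the number of learnable parameters in this network follow from simple arithmetic. The sketch below works them out in base MATLAB, assuming the convolution layer uses no padding and a stride of 1; the variable names are introduced here for illustration.

```matlab
% Layer-size arithmetic for the network above (no toolbox calls).
inputsize  = [28 28 1];
filtersize = 5;  numfilters = 20;   % convolution2dLayer(5,20)
poolsize   = 2;  poolstride = 2;    % maxPooling2dLayer(2,'Stride',2)
numclasses = 10;                    % fullyConnectedLayer(10)

convout  = (inputsize(1) - filtersize) + 1;               % spatial size after convolution
poolout  = floor((convout - poolsize)/poolstride) + 1;    % spatial size after pooling
fcinputs = poolout^2 * numfilters;                        % inputs to the fully connected layer

convparams = filtersize^2 * inputsize(3) * numfilters + numfilters;  % weights + biases
fcparams   = fcinputs * numclasses + numclasses;                     % weights + biases
totalparams = convparams + fcparams;

fprintf('conv: %dx%dx%d, pool: %dx%dx%d, learnables: %d\n', ...
    convout,convout,numfilters,poolout,poolout,numfilters,totalparams);
```

Almost all of the learnable parameters sit in the fully connected layer, whereas the scattering framework has no learned filter weights at all.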
Use stochastic gradient descent with momentum and a learning rate of 0.0001 for training. Set the maximum number of epochs to 20. For reproducibility, set the ExecutionEnvironment to 'cpu'.
options = trainingOptions('sgdm', ...
    'MaxEpochs',20, ...
    'InitialLearnRate',1e-4, ...
    'Verbose',false, ...
    'Plots','training-progress','ExecutionEnvironment','cpu');
Train the network. For training and testing, we use the same data sets used for the scattering transform.
reset(trainimds);
reset(testimds);
net = trainNetwork(trainimds,layers,options);
By the end of training, the CNN is performing near 100% on the training set. Use the trained network to make predictions on the held-out test set.
ypred = classify(net,testimds,'ExecutionEnvironment','cpu');
dcnnaccuracy = sum(ypred == testimds.Labels)/numel(ypred)*100
dcnnaccuracy = 95.5000
The simple CNN has achieved 95.5% correct classification on the held-out test set. Plot the confusion chart for the CNN.
figure;
confusionchart(testimds.Labels,ypred)
title('Test-Set Confusion Chart -- CNN')
Summary
This example used wavelet image scattering to create low-variance representations of digit images for classification. Using the scattering transform with fixed filter weights and a simple principal components classifier, we achieved 99.6% correct classification on a held-out test set. With a simple CNN in which the filters are learned, we achieved 95.5% correct. This example is not intended as a direct comparison of the scattering transform and CNNs. There are multiple hyperparameter and architectural changes you can make in each case, and these can significantly affect the results. The goal of this example was simply to demonstrate the potential of deep feature extractors like the wavelet scattering transform to produce robust representations of data for learning.
References
[1] Bruna, J., and S. Mallat. "Invariant Scattering Convolution Networks." IEEE Transactions on Pattern Analysis and Machine Intelligence. Vol. 35, Number 8, 2013, pp. 1872–1886.
[2] Mallat, S. "Group Invariant Scattering." Communications on Pure and Applied Mathematics. Vol. 65, Number 10, 2012, pp. 1331–1398.
[3] Sifre, L., and S. Mallat. "Rotation, Scaling and Deformation Invariant Scattering for Texture Discrimination." 2013 IEEE Conference on Computer Vision and Pattern Recognition. 2013, pp. 1233–1240. doi:10.1109/CVPR.2013.163.
Appendix: Supporting Functions
helperscatimages
function features = helperscatimages(sf,x)
% This function is only to support examples in the Wavelet Toolbox.
% It may change or be removed in a future release.

% Copyright 2018 MathWorks

% Log of the scattering coefficients: paths-by-rows-by-columns
smat = featureMatrix(sf,x,'Transform','log');
% Average over both spatial dimensions to get one value per path
features = mean(mean(smat,2),3);
end
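The averaging step in helperscatimages collapses the two spatial dimensions and leaves one value per scattering path. The following sketch shows the shapes involved using a random placeholder array in place of the output of featureMatrix; the 4-by-4 spatial size is hypothetical and chosen only for illustration.

```matlab
% Placeholder for a paths-by-rows-by-columns feature matrix.
npaths = 217;
smat = rand(npaths,4,4);    % hypothetical spatial size, random values

% Same reduction as in helperscatimages.
features = mean(mean(smat,2),3);

% Equivalent view: flatten the spatial dimensions, then average each row.
flatmean = mean(reshape(smat,npaths,[]),2);

disp(size(features))
disp(max(abs(features - flatmean)))
```

Because every spatial slice has the same number of elements, the nested means equal the mean over all spatial positions, up to rounding.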
helperpcamodel
function model = helperpcamodel(features,m,labels)
% This function is only to support wavelet image scattering examples in
% Wavelet Toolbox. It may change or be removed in a future release.
% model = helperpcamodel(features,m,labels)

% Copyright 2018 MathWorks

% Initialize structure array to hold the affine model
model = struct('dim',[],'mu',[],'u',[],'labels',categorical([]),'s',[]);
model.dim = m;
% Obtain the number of classes
labelcategories = categories(labels);
nclasses = numel(labelcategories);
for kk = 1:nclasses
    thisclass = labelcategories{kk};
    % Find indices corresponding to each class
    idxclass = labels == thisclass;
    % Extract feature vectors for each class
    tmpfeatures = features(:,idxclass);
    % Determine the mean for each class
    model.mu{kk} = mean(tmpfeatures,2);
    [model.u{kk},model.s{kk}] = scatpca(tmpfeatures);
    % Keep at most the first m principal components
    if size(model.u{kk},2) > m
        model.u{kk} = model.u{kk}(:,1:m);
        model.s{kk} = model.s{kk}(1:m);
    end
    model.labels(kk) = thisclass;
end

    function [u,s,v] = scatpca(x,m)
        % Calculate the principal components of x along the second dimension.
        if nargin > 1 && m > 0
            % If m is non-zero, calculate the first m principal components.
            % (This branch is not used by helperpcamodel, which calls
            % scatpca with a single input.)
            [u,s,v] = svds(x-mean(x,2),m);
            s = abs(diag(s)/sqrt(size(x,2)-1)).^2;
        else
            % Otherwise, calculate all the principal components.
            % Each row of x is a scattering path (variable) and each
            % column is an observation from the class.
            [u,d] = eig(cov(x'));
            [s,ind] = sort(diag(d),'descend');
            u = u(:,ind);
        end
    end
end
helperpcaclassifier
function labels = helperpcaclassifier(features,model)
% This function is only to support wavelet image scattering examples in
% Wavelet Toolbox. It may change or be removed in a future release.
% model is a structure array with fields dim, mu, u, labels, and s
% features is the matrix of test data which is Ns-by-L, Ns is the number of
% scattering paths and L is the number of test examples. Each column of
% features is a test example.

% Copyright 2018 MathWorks

labelidx = determineclass(features,model);
labels = model.labels(labelidx);
% Returns as column vector to agree with imageDatastore Labels
labels = labels(:);

%--------------------------------------------------------------------------
    function labelidx = determineclass(features,model)
        % Determine number of classes
        nclasses = numel(model.labels);
        % Initialize error matrix
        errmatrix = Inf(nclasses,size(features,2));
        for nc = 1:nclasses
            % Class centroid
            mu = model.mu{nc};
            u = model.u{nc};
            % 1-by-L
            errmatrix(nc,:) = projectionerror(features,mu,u);
        end
        % Determine minimum along class dimension
        [~,labelidx] = min(errmatrix,[],1);

%--------------------------------------------------------------------------
        function totalerr = projectionerror(features,mu,u)
            npc = size(u,2);
            l = size(features,2);
            % Subtract class mean: Ns-by-L minus Ns-by-1
            s = features-mu;
            % Squared norm of each centered example (L-by-1 after transpose)
            normsqx = sum(abs(s).^2,1)';
            err = Inf(npc+1,l);
            err(1,:) = normsqx;
            % Subtract the squared projection onto each principal direction
            err(2:end,:) = -abs(u'*s).^2;
            % 1-by-L
            totalerr = sqrt(sum(err,1));
        end
    end
end
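To make the projection-error arithmetic concrete, here is a small worked sketch in base MATLAB with two made-up classes in three dimensions. The class means, principal directions, and test vectors are hypothetical values chosen so the results are easy to check by hand; the loop mirrors the computation in projectionerror and determineclass.

```matlab
% Two hypothetical classes, each with one principal direction.
mu = {[1;1;1],[0;0;0]};        % class means
u  = {[1;0;0],[0;0;1]};        % orthonormal principal directions
features = [4 1; 5 1; 1 3];    % two test vectors, one per column

err = zeros(numel(mu),size(features,2));
for c = 1:numel(mu)
    s = features - mu{c};                     % center on the class mean
    normsq = sum(abs(s).^2,1);                % squared distance to the mean
    projsq = sum(abs(u{c}'*s).^2,1);          % part explained by the subspace
    err(c,:) = sqrt(normsq - projsq);         % residual distance to the subspace
end

% Assign each test vector to the class with the smallest residual.
[~,labelidx] = min(err,[],1);
disp(err)
disp(labelidx)
```

The first test vector has residuals 4 and sqrt(41), so it goes to class 1. The second has residuals 2 and sqrt(2), so it goes to class 2.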