Example: Formal Verification of MNIST Images
The following example shows how we can formally verify a neural network given uncertainty in the input.
Let us first load the dataset and the neural network into MATLAB:
% load MNIST dataset (requires the Deep Learning Toolbox)
[XTest,YTest] = digitTest4DArrayData;
modelfile = "mnist_sigmoid_6_200.onnx";
nn = neuralNetwork.readONNXNetwork(modelfile, false, 'BCSS');
Next, we select an image and test if the network is classifying it correctly:
idx = 1; % index of the selected test image (arbitrary choice)
x = XTest(:,:,:,idx); % image
target = double(YTest(idx))-1; % label (dim=1 is label=0)
% reshape image to vector to be consistent with set definitions
x = reshape(x,[],1);
% propagate through network
y_pred = nn.evaluate(x);
[~,label_pred] = max(y_pred);
label_pred = label_pred - 1;
fprintf("Correct classification: %d\n", target == label_pred);
Correct classification: 1
% plot the prediction of each label
figure; hold on;
labels = 0:9;
handleVis = ["on","off"];
for label = labels
    scatter(y_pred(label+1,:),label,'.k','DisplayName','Sample','HandleVisibility',handleVis((label > 0) + 1))
end
plot(ones(2,1) .* y_pred(target+1,:),[-1,10],'--','Color',CORAcolor('CORA:safe'),'DisplayName','Classification');
xlim([-15,5]); ylim([-1,10]); yticks(labels);
ylabel('Label'); xlabel('Prediction');
legend();
As the prediction of the correct label is larger than all other predictions, the image is correctly classified by the neural network.
However, this approach no longer works if the input is uncertain, i.e., if each pixel can be perturbed by up to a perturbation radius ϵ. Let us demonstrate this with the following example:
% define the uncertain input set: each pixel perturbed by up to epsilon
epsilon = 0.005; % perturbation radius (chosen for illustration)
c = x; % center of the input set
X = interval(c-epsilon,c+epsilon);
% visualize the center image
figure;
subplot(2,2,1); imshow(reshape(c,28,28));
% draw random samples from the input set; use the sensitivity of the
% network to additionally construct an FGSM-like attack as the last sample
N = 16;
S = nn.calcSensitivity(c);
fgsm = sign(mean(sign(S)))' * epsilon;
xs = [randPoint(X,N-1), c + fgsm];
% visualize the perturbed samples
figure;
idx = 1:N;
for i = 1:N
    subplot(4,4,idx(i)); imshow(reshape(xs(:,i),28,28));
end
% propagate samples through neural network
ys_pred = nn.evaluate(xs);
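Before plotting, we can quickly check whether any of the drawn samples is already misclassified (a short sketch using the variables defined above):

```matlab
% count misclassified samples among the drawn perturbations
[~,labels_pred] = max(ys_pred);
labels_pred = labels_pred - 1;
fprintf("Misclassified samples: %d\n", sum(labels_pred ~= target));
```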
% plot the predictions of all samples
figure; hold on;
handleVis = ["on","off"];
for label = labels
    scatter(ys_pred(label+1,:),label .* ones(N,1),'.k','DisplayName','Samples','HandleVisibility',handleVis((label > 0) + 1))
end
plot(min(ys_pred(target+1,:)) .* ones(2,1),[-1,10],'--','Color',CORAcolor('CORA:safe'),'DisplayName','Verified?');
xlim([-15,5]); ylim([-1,10]); yticks(labels);
ylabel('Label'); xlabel('Prediction');
legend();
Unfortunately, we cannot reason about the correct classification of the entire input set by just looking at samples as we might miss outliers.
Thus, we can conservatively propagate the set itself through the network using CORA. If the lower bound of the target label is larger than the upper bound of all other labels, we have formally proven that all images are classified correctly given the noise radius ϵ.
% propagate uncertain input set through the network
Y_pred = nn.evaluate(zonotope(X));
% plot the samples and the output bounds of each label
figure; hold on;
handleVis = ["on","off"];
for label = labels
    scatter(ys_pred(label+1,:),label .* ones(N,1),'.k','DisplayName','Samples','HandleVisibility',handleVis((label > 0) + 1))
    plot(Y_pred,label+1,'-|','YPos',label,'DisplayName','Bounds','HandleVisibility',handleVis((label > 0) + 1),'Color',CORAcolor('CORA:reachSet'))
end
plot(interval(Y_pred).inf(target+1,:) .* ones(2,1),[-1,10],'--','Color',CORAcolor('CORA:safe'),'DisplayName','Verified!');
xlim([-15,5]); ylim([-1,10]); yticks(labels);
ylabel('Label'); xlabel('Prediction');
legend();
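The verification condition can also be checked directly on the interval enclosure of the output set, a sketch using CORA's `interval`, `infimum`, and `supremum` functions:

```matlab
% direct check: lower bound of the target output must exceed
% the upper bound of every other output
Y_int = interval(Y_pred);         % interval enclosure of the output set
lb = infimum(Y_int);              % lower bounds of all outputs
ub = supremum(Y_int);             % upper bounds of all outputs
others = setdiff(1:10, target+1); % indices of all non-target labels
verified = lb(target+1) > max(ub(others));
fprintf("Verified: %d\n", verified);
```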
Finally, since checking that the lower bound of the target label exceeds the upper bound of all other labels is somewhat tedious, we can transform the output space so that the target output is always 0 and all other outputs must lie below zero. This enables a simple verification check, both algorithmically and visually, as given below:
% init transformation matrix: subtract the target output from all outputs
T = eye(10);
T(:,target+1) = T(:,target+1)-1;
% transform the samples and the output set
ys_pred = T * ys_pred;
Y_pred = T * Y_pred;
% define unsafe set as specification
spec = specification(polytope(-eye(10),zeros(10,1)),'unsafeSet');
fprintf("Check samples: %d\n", check(spec,ys_pred));
fprintf("Verification: %d\n", check(spec,Y_pred));
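To see the effect of this transformation, consider a minimal 3-class sketch with made-up output values (`t`, `T3`, `y3` are purely illustrative names):

```matlab
% small 3-class illustration of the transformation (hypothetical values)
t = 1;                     % assume class 1 is the target
T3 = eye(3);
T3(:,t+1) = T3(:,t+1) - 1; % subtract the target output from every output
y3 = [0.2; 3.1; -1.5];     % raw network outputs (made up)
disp((T3*y3)')             % target entry is 0, all others are y_i - y_target
```

All non-target entries of the transformed vector are negative exactly when the target output is the largest, which is the classification condition from above.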
% plot the transformed outputs together with the unsafe set
figure; hold on;
plot(specification(polytope([-1,0],0),'unsafeSet'),1:2,'DisplayName','Unsafe set')
handleVis = ["on","off"];
for label = labels
    scatter(ys_pred(label+1,:),label .* ones(N,1),'.k','DisplayName','Samples','HandleVisibility',handleVis((label > 0) + 1))
    plot(Y_pred,label+1,'-|','YPos',label,'DisplayName','Bounds','HandleVisibility',handleVis((label > 0) + 1),'Color',CORAcolor('CORA:reachSet'))
end
xlim([-15,5]); ylim([-1,10]); yticks(labels);
ylabel('Label'); xlabel('Prediction');
legend();