# Naive Bayes Number Classification

Bayesian classification seems like a very complicated (mathematical) concept, and many of us have avoided using it simply because of that.

In reality, the method relies on one of the simplest theorems in statistics, Bayes' theorem, which computes the conditional probability of an event A given another event B:

$$\Pr(A \mid B) = \frac{\Pr(B \mid A)\,\Pr(A)}{\Pr(B)}$$

The proof is direct. The probability of events A and B happening simultaneously (notation $\Pr(A,B)$) equals the probability of B happening when we know A happened (notation $\Pr(B \mid A)$) times the probability of A on its own (notation $\Pr(A)$); in other words, $\Pr(A,B)=\Pr(B \mid A)\Pr(A)$. By symmetry, also $\Pr(A,B)=\Pr(A \mid B)\Pr(B)$. Equating the two expressions and dividing by $\Pr(B)$ gives Bayes' theorem above, q.e.d.
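The theorem is easy to check numerically. The sketch below uses made-up probabilities (not from the post), and the helper name `bayes` is hypothetical; $\Pr(B)$ is obtained from the law of total probability.

```python
def bayes(p_b_given_a: float, p_a: float, p_b: float) -> float:
    """Return Pr(A|B) = Pr(B|A) * Pr(A) / Pr(B)."""
    return p_b_given_a * p_a / p_b

# Illustrative numbers only: Pr(A) = 0.1, Pr(B|A) = 0.8, Pr(B|not A) = 0.2
p_a = 0.1
p_b_given_a = 0.8
p_b_given_not_a = 0.2

# Pr(B) by total probability: Pr(B|A)Pr(A) + Pr(B|not A)Pr(not A) = 0.08 + 0.18
p_b = p_b_given_a * p_a + p_b_given_not_a * (1 - p_a)
p_a_given_b = bayes(p_b_given_a, p_a, p_b)   # 0.08 / 0.26

# Both factorizations of the joint Pr(A, B) agree
assert abs(p_a_given_b * p_b - p_b_given_a * p_a) < 1e-12
print(round(p_a_given_b, 4))
```

Observing B raises the probability of A from 0.1 to about 0.31, because B is four times more likely when A holds.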

In this post we show how to use this theorem to implement a digit classifier on the MNIST database (60,000 labeled training images of hand-written digits, plus a 10,000-image test set).

Each image in the MNIST database is a 28 × 28 bitmap. The input x for each image is its vector of 28 × 28 = 784 pixels, and the output y is one class (0 to 9) giving the digit shown in the image.

The main idea is to estimate, for each pixel, the probability $p(x_i \mid y)$ that it is on in images of each of the 10 classes. Assuming (this is the "naive" part) that pixels are independent of one another given the class, the probability of image x given class y is the product of the per-pixel probabilities for that class:

$$p(x \mid y) = \prod_i p(x_i \mid y)$$

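Using Bayes' theorem, this leads to the classifier

$$p(y \mid x) = \frac{\prod_i p(x_i \mid y)\, p(y)}{p(x)}$$

Since $p(x)$ is the same for every class, it does not change which label has the highest probability.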
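To make the formula concrete, here is a minimal sketch on a made-up toy model with 2 classes and 3 binary pixels (the numbers and the names `theta`, `prior` and `posterior` are hypothetical, not estimated from MNIST). For a binary pixel, $p(x_i \mid y)$ is $\theta$ when the pixel is on and $1-\theta$ when it is off, and $p(x)$ is obtained by summing the numerator over all classes.

```python
import numpy as np

# Hypothetical toy model: theta[y, i] = p(x_i = 1 | y), prior[y] = p(y)
theta = np.array([[0.9, 0.2, 0.1],
                  [0.3, 0.7, 0.8]])
prior = np.array([0.5, 0.5])

def posterior(x: np.ndarray) -> np.ndarray:
    """p(y | x) for a binary image x under the naive independence assumption."""
    # p(x_i | y) is theta where the pixel is on and 1 - theta where it is off
    pxi_given_y = np.where(x == 1, theta, 1 - theta)   # shape (2, 3)
    likelihood = pxi_given_y.prod(axis=1)               # p(x | y) = prod_i p(x_i | y)
    joint = likelihood * prior                          # p(x | y) p(y)
    return joint / joint.sum()                          # divide by p(x) = sum_y p(x | y) p(y)

x = np.array([1, 0, 0])
post = posterior(x)
print(post, post.argmax())
```

Here class 0 gets likelihood 0.9 × 0.8 × 0.9 = 0.648 against 0.3 × 0.3 × 0.2 = 0.018 for class 1, so the image is assigned to class 0.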

The steps to implement this classifier are then:

a) Import the image/label pairs (x, y).

b) For all images in the training set, calculate for each pixel the probability of it being on in each of the 10 output classes. Below, we plot these estimated probabilities as heat maps, one per class.

c) For a new image x, the probability that it has label y is proportional to the product of the per-pixel probabilities for that class, times the class prior p(y). We simply choose the label with the highest probability given x.
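The counting in step b can also be written with matrix operations instead of a per-image loop. This is a sketch assuming the binarized images are stacked into a NumPy matrix `X` of shape (n_images, n_pixels) with integer labels; the function name `fit` and the toy data are hypothetical. It uses Laplace smoothing, (on_count + 1) / (class_count + 2), so no estimate is exactly 0 or 1.

```python
import numpy as np

def fit(X: np.ndarray, labels: np.ndarray, n_classes: int = 10):
    """Estimate p(x_i = 1 | y) and p(y) from binary images (rows of X)."""
    onehot = np.eye(n_classes)[labels]              # (n_images, n_classes)
    class_count = onehot.sum(axis=0)                # images per class
    on_count = X.T @ onehot                         # (n_pixels, n_classes): times each pixel is on
    theta = (on_count + 1) / (class_count + 2)      # smoothed p(x_i = 1 | y)
    prior = (class_count + 1) / (class_count.sum() + n_classes)  # smoothed p(y)
    return theta, prior

# Tiny made-up data set: 4 images of 2 pixels, 2 classes
X = np.array([[1, 0], [1, 1], [0, 1], [0, 1]])
labels = np.array([0, 0, 1, 1])
theta, prior = fit(X, labels, n_classes=2)
print(theta)
print(prior)
```

On MNIST, `X` would be the 60,000 × 784 matrix of binarized training images and `theta` the 784 × 10 table of per-pixel probabilities.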

In the prediction results for 10 hand-written test digits, one of the 4s and the 5 are wrongly classified.

The accuracy of this method is rather low by today's standards (84.26% on the MNIST test set and 83.57% on the training set), but it is still a nice example of the power of purely statistical methods.
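The accuracy figures above are the fraction of images whose predicted label equals the true label. A minimal sketch of computing it for a whole set at once, assuming per-pixel probabilities `theta` (n_pixels × n_classes) and priors `prior` are already estimated; `predict`, `accuracy` and the toy values are hypothetical. The prediction works in log space, and $p(x)$ is skipped because it is the same for every class.

```python
import numpy as np

def predict(X: np.ndarray, theta: np.ndarray, prior: np.ndarray) -> np.ndarray:
    """Most probable class for each binary image (row of X), computed in log space."""
    log_joint = (np.log(prior)
                 + X @ np.log(theta)              # sum of log p(x_i = 1 | y) over pixels that are on
                 + (1 - X) @ np.log(1 - theta))   # sum of log p(x_i = 0 | y) over pixels that are off
    return log_joint.argmax(axis=1)

def accuracy(pred: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of predictions that match the true labels."""
    return float(np.mean(pred == labels))

# Made-up parameters and images, just to exercise the functions
theta = np.array([[0.75, 0.25], [0.5, 0.75]])
prior = np.array([0.5, 0.5])
X = np.array([[1, 0], [1, 1], [0, 1], [0, 1]])
labels = np.array([0, 0, 1, 1])
pred = predict(X, theta, prior)
print(pred, accuracy(pred, labels))
```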

Below is the code for each of the steps described above, using NumPy for the computation and MXNet only to import the data (any other way of loading the MNIST database works just as well).

a) Import the image/label pairs (x, y).

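```python
import mxnet as mx
from mxnet import nd
import numpy as np
import matplotlib.pyplot as plt

# Binarize each image: intensities 128-255 become 1, darker pixels become 0
def transform(data, label):
    return nd.floor(data / 128).astype(np.float32), label.astype(np.float32)

mnist_train = mx.gluon.data.vision.MNIST(train=True, transform=transform)
mnist_test = mx.gluon.data.vision.MNIST(train=False, transform=transform)
# Counters start at 1 (Laplace smoothing) so no estimated probability is 0
ycount = np.ones(shape=(10,))      # number of training images per class
xcount = np.ones(shape=(784, 10))  # per-class count of each pixel being on
```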
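The `transform` function binarizes the images: `floor(data/128)` maps intensities 0 to 127 to 0 and 128 to 255 to 1. If you load MNIST another way, the same binarization can be done as a plain threshold; this small NumPy check (independent of MXNet) confirms the two agree for every possible 8-bit intensity.

```python
import numpy as np

# Every possible uint8 pixel intensity
v = np.arange(256, dtype=np.uint8)
via_floor = np.floor(v / 128).astype(np.float32)   # same formula as transform()
via_threshold = (v >= 128).astype(np.float32)      # explicit threshold at 128
assert np.array_equal(via_floor, via_threshold)
print(int(via_threshold.sum()))  # how many intensities map to 1
```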

b) For all images in the training set, count for each pixel how often it is on in each of the 10 output classes and normalize the counts into probabilities. The code then plots the estimated probabilities as heat maps, one per class.

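```python
# Aggregate counts of how frequently each pixel is on, per digit class
for data, label in mnist_train:
    x = data.reshape((784,)).asnumpy()
    y = int(label)
    ycount[y] += 1
    xcount[:, y] += x

# Normalize the counts into p(x_i = 1 | y). ycount[i] already includes the
# initial 1, so dividing by ycount[i] + 1 gives (on_count + 1) / (n_images + 2),
# which keeps every estimate strictly between 0 and 1
for i in range(10):
    xcount[:, i] = xcount[:, i] / (ycount[i] + 1)
py = ycount / np.sum(ycount)  # smoothed class priors p(y)

# One heat map per class showing p(x_i = 1 | y) for every pixel
fig, figarr = plt.subplots(1, 10)
for i in range(10):
    figarr[i].imshow(xcount[:, i].reshape((28, 28)))
plt.show()
```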

c) For a new image x, the probability that it has label y is proportional to the class prior p(y) times the product of the per-pixel probabilities for that class. We choose the label with the highest probability given x.
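Multiplying 784 probabilities directly is numerically fragile: in double precision the product of many small factors underflows to exactly zero, so every class would tie. Summing logarithms instead avoids this, and because the logarithm is monotonic the most probable label is unchanged. A small illustration, using an arbitrary value of 0.1 for every pixel:

```python
import numpy as np

# 784 per-pixel probabilities; 0.1 is an arbitrary illustrative value
p = np.full(784, 0.1)
direct = np.prod(p)           # 0.1**784 is far below the smallest float64, so this is 0.0
log_sum = np.sum(np.log(p))   # log of the same product, roughly -1805.2
print(direct, log_sum)
```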

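```python
# Precompute log-probabilities: the product over 784 pixels becomes a sum,
# which avoids floating-point underflow
logxcount = np.log(xcount)          # log p(x_i = 1 | y)
logxcountneg = np.log(1 - xcount)   # log p(x_i = 0 | y)
logpy = np.log(py)                  # log p(y)

fig, figarr = plt.subplots(2, 10, figsize=(15, 3))

ctr = 0
for data, label in mnist_test:
    x = data.reshape((784,)).asnumpy()
    y = int(label)

    # log p(y) + sum_i log p(x_i | y) for each class y
    logpx = logpy.copy()
    for i in range(10):
        logpx[i] += np.dot(logxcount[:, i], x) + np.dot(logxcountneg[:, i], 1 - x)
    # Softmax over classes gives p(y | x); subtracting the largest value first
    # keeps np.exp from underflowing or overflowing
    logpx -= np.max(logpx)
    px = np.exp(logpx)
    px /= np.sum(px)

    # Top row: the test digit; bottom row: its predicted class probabilities
    figarr[1, ctr].bar(range(10), px)
    figarr[1, ctr].axes.get_yaxis().set_visible(False)
    figarr[0, ctr].imshow(x.reshape((28, 28)))

    ctr += 1
    if ctr == 10:
        break
plt.show()
```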