
Training a Model on the MNIST Dataset with TensorFlow

  • Glen
  • Apr 11, 2017

This post uses TensorFlow to train and test a model on the MNIST dataset.

1. MNIST Data

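MNIST is a large database of handwritten digits that is commonly used to train image processing systems. It contains 60,000 training images and 10,000 test images. TensorFlow's tutorial loader (input_data.read_data_sets) sets aside 5,000 of the training images as a validation set by default, so mnist.train holds 55,000 images and mnist.test holds 10,000.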

Each image is 28 × 28 pixels, so its pixel matrix can be flattened into a 784-dimensional vector (28 × 28 = 784).
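As a minimal plain-Python sketch (no TensorFlow involved), flattening can walk the image row by row, so the pixel at row r, column c lands at index 28·r + c. The image below is a made-up example, not MNIST data, and the helper name flatten is hypothetical.

```python
def flatten(image):
    """Flatten a 28x28 image (28 rows of 28 pixel values) into a 784-element list, row by row."""
    assert len(image) == 28 and all(len(row) == 28 for row in image)
    return [pixel for row in image for pixel in row]

# A hypothetical blank image with one bright pixel at row 2, column 5.
image = [[0.0] * 28 for _ in range(28)]
image[2][5] = 1.0

vector = flatten(image)
print(len(vector))        # 784
print(vector.index(1.0))  # 2 * 28 + 5 = 61
```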

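mnist.train.images holds the training images (the "xs"), one flattened 784-dimensional vector per row.

mnist.train.labels holds the corresponding labels (the "ys"). There are 10 classes, the digits 0 to 9, and each label is the digit shown in its image. When the data is loaded with one_hot=True, each label is stored as a 10-dimensional one-hot vector; the argmax-based evaluation in Step 3 relies on this.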
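For illustration, a one-hot label is a vector of ten zeros with a single 1 at the position of the digit, and the digit is recovered by taking the index of the largest entry. The helpers one_hot and from_one_hot below are hypothetical plain-Python stand-ins, not part of TensorFlow's loader.

```python
def one_hot(label, num_classes=10):
    """Return a one-hot vector: 1.0 at index `label`, 0.0 everywhere else."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

def from_one_hot(vec):
    """Recover the digit as the index of the largest entry (an argmax)."""
    return max(range(len(vec)), key=lambda i: vec[i])

print(one_hot(3))                # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
print(from_one_hot(one_hot(3)))  # 3
```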

2. Softmax Regression

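The model is softmax regression:

$$y = \mathrm{softmax}(xW + b), \qquad \mathrm{softmax}(z)_i = \frac{e^{z_i}}{\sum_j e^{z_j}}$$

where $x$ is a flattened 784-dimensional image (a row vector), $W$ is the 784 × 10 weight matrix, $b$ is the 10-dimensional bias, and $y$ is the predicted probability of each of the 10 labels.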
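A plain-Python sketch of this forward pass for a single example might look like the following (TensorFlow computes the same thing for a whole batch with tf.matmul). Subtracting max(z) inside softmax is a standard numerical-stability trick that leaves the result unchanged; the function names are illustrative, not TensorFlow APIs.

```python
import math

def softmax(z):
    """Numerically stable softmax: subtract max(z) before exponentiating."""
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]

def predict(x, W, b):
    """Compute y = softmax(xW + b) for one example.
    x: list of n inputs; W: n rows of k weights; b: list of k biases."""
    k = len(b)
    logits = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(k)]
    return softmax(logits)

# With all-zero weights and biases (the initialization used in Step 1),
# every digit gets the same probability, 1/10.
W = [[0.0] * 10 for _ in range(784)]
b = [0.0] * 10
x = [0.5] * 784
print(predict(x, W, b))  # ten values of 0.1
```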

Figure: an example learned model, showing the weight of each pixel for each label (blue is 1, red is 0).

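Cost function (loss function): cross-entropy, $H_{y'}(y) = -\sum_i y'_i \log(y_i)$, where $y$ is the predicted probability distribution and $y'$ is the one-hot true label.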
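To see what the loss rewards, here is a small plain-Python sketch of cross-entropy for one example. Because the true label is one-hot, only the predicted probability of the correct digit matters. The two prediction vectors are made-up values chosen for illustration.

```python
import math

def cross_entropy(y_pred, y_true):
    """H_{y'}(y) = -sum_i y'_i * log(y_i) for one example.
    Terms where y'_i == 0 contribute nothing, so they are skipped (this also avoids log(0))."""
    return -sum(t * math.log(p) for p, t in zip(y_pred, y_true) if t > 0)

true_label = [0.0] * 10
true_label[3] = 1.0              # one-hot label for the digit 3

uniform = [0.1] * 10             # an untrained model: every digit equally likely
confident = [0.01] * 10
confident[3] = 0.91              # most of the probability on the correct digit

print(cross_entropy(uniform, true_label))    # log(10), about 2.3026
print(cross_entropy(confident, true_label))  # -log(0.91), about 0.0943
```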

Training Algorithm:

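Use the gradient descent algorithm. Gradient descent and backpropagation are not the same thing: backpropagation is how the gradients of the loss with respect to W and b are computed, and gradient descent uses those gradients to update the weights. TensorFlow performs backpropagation automatically.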

The backpropagation algorithm looks for the minimum of the error function in weight space using the method of gradient descent.
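A plain-Python sketch of one gradient-descent step on a single example may make the division of labour concrete. The gradient with respect to the logits, y − y′, is the standard result of backpropagating through softmax followed by cross-entropy; TensorFlow derives equivalent gradients automatically. The learning rate, input values and function name loss_and_step are arbitrary choices for illustration.

```python
import math

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def loss_and_step(x, label, W, b, lr):
    """One gradient-descent step for softmax regression on a single example.
    Returns the loss before the update; W and b are updated in place."""
    k = len(b)
    logits = [sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(k)]
    y = softmax(logits)
    loss = -math.log(y[label])
    # Backpropagation through softmax + cross-entropy: dLoss/dlogit_j = y_j - y'_j
    delta = [y[j] - (1.0 if j == label else 0.0) for j in range(k)]
    for i in range(len(x)):
        for j in range(k):
            W[i][j] -= lr * x[i] * delta[j]  # dLoss/dW_ij = x_i * delta_j
    for j in range(k):
        b[j] -= lr * delta[j]                # dLoss/db_j = delta_j
    return loss

x = [1.0, 0.0, 0.5, 0.2]
W = [[0.0] * 3 for _ in range(4)]
b = [0.0] * 3
losses = [loss_and_step(x, 2, W, b, lr=0.5) for _ in range(5)]
print(losses)  # starts at log(3), about 1.0986, and decreases at every step
```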

3. Use TensorFlow to train the model

Step 1: Initialize and start the model

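import tensorflow as tf  # TensorFlow 1.x API (placeholders and sessions)
x = tf.placeholder(tf.float32, [None, 784])  # a batch of flattened images
y_ = tf.placeholder(tf.float32, [None, 10])  # the one-hot correct labels
W = tf.Variable(tf.zeros([784, 10]))         # weights, initialized to zero
b = tf.Variable(tf.zeros([10]))              # biases, initialized to zero
y = tf.nn.softmax(tf.matmul(x, W) + b)       # predicted label probabilities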

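Variables must be initialized through the session before they can be used in it. Initialization assigns each variable its specified initial value (all zeros in this example), and it can be done for all variables at once.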

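init = tf.global_variables_initializer()  # tf.initialize_all_variables() is deprecated
sess = tf.Session()
sess.run(init)  # assign the initial values to all variables at once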

Step 2: Train the model (optimize the weights)

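from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)  # one-hot labels, needed by Step 3

for i in range(1000):
    # Stochastic training: each step uses a random batch of 100 training examples
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # feed_dict replaces the x and y_ placeholders with the training data
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})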

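The train_step used in the loop is defined as follows (in a runnable script, define it before the loop):

cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), axis=1))  # mean cross-entropy over the batch
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)  # learning rate 0.01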
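The training loop repeats a gradient-descent update on random mini-batches. Below is a plain-Python sketch of the same procedure (zero initialization, mini-batch gradient descent on the batch-mean cross-entropy) on a small made-up dataset, since MNIST itself is not bundled here. make_toy_data and next_batch are hypothetical helpers standing in for the real data and mnist.train.next_batch; this next_batch draws an independent random batch each step, which is a simplification of how the TensorFlow helper may walk through the data. The step count, batch size and learning rate are arbitrary, not the values used in the post.

```python
import math
import random

def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]

def predict(x, W, b):
    k = len(b)
    return softmax([sum(x[i] * W[i][j] for i in range(len(x))) + b[j] for j in range(k)])

def make_toy_data(n, k, rng):
    """Hypothetical toy data: class c has feature c near 1 and the other features near 0."""
    data = []
    for _ in range(n):
        c = rng.randrange(k)
        x = [rng.uniform(-0.2, 0.2) for _ in range(k)]
        x[c] += 1.0
        data.append((x, c))
    return data

def next_batch(data, batch_size, rng):
    """Stand-in for mnist.train.next_batch: a random mini-batch drawn without replacement."""
    return rng.sample(data, batch_size)

def mean_loss(data, W, b):
    """Mean cross-entropy over a dataset of (x, label) pairs."""
    return sum(-math.log(predict(x, W, b)[c]) for x, c in data) / len(data)

def train(data, k, steps=200, batch_size=10, lr=0.5, seed=0):
    """Mini-batch gradient descent on the mean cross-entropy of each batch."""
    rng = random.Random(seed)
    n_in = len(data[0][0])
    W = [[0.0] * k for _ in range(n_in)]  # zero initialization, as in Step 1
    b = [0.0] * k
    for _ in range(steps):
        batch = next_batch(data, batch_size, rng)
        gW = [[0.0] * k for _ in range(n_in)]
        gb = [0.0] * k
        for x, c in batch:
            y = predict(x, W, b)
            for j in range(k):
                # dLoss/dlogit_j = y_j - y'_j, divided by the batch size for the mean loss
                d = (y[j] - (1.0 if j == c else 0.0)) / batch_size
                gb[j] += d
                for i in range(n_in):
                    gW[i][j] += x[i] * d
        for i in range(n_in):
            for j in range(k):
                W[i][j] -= lr * gW[i][j]
        for j in range(k):
            b[j] -= lr * gb[j]
    return W, b

data = make_toy_data(300, 3, random.Random(42))
before = mean_loss(data, [[0.0] * 3 for _ in range(3)], [0.0] * 3)
W, b = train(data, k=3)
after = mean_loss(data, W, b)
print(before)  # about 1.0986, i.e. log(3), for the all-zero model
print(after)   # much smaller after training
```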

Step 3: Evaluate the model

correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))  # y_ holds the correct labels
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))  # fraction of correct predictions
print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))  # test-set accuracy
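The evaluation compares, row by row, the index of the largest predicted probability with the index of the 1 in the one-hot label, then averages the matches. A plain-Python sketch of that computation, using made-up predictions for four examples, could be:

```python
def argmax(values):
    """Index of the largest value (the first one on ties), like tf.argmax along one row."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(predictions, labels_one_hot):
    """Fraction of rows whose predicted argmax matches the one-hot label's argmax."""
    correct = [argmax(p) == argmax(t) for p, t in zip(predictions, labels_one_hot)]
    return sum(correct) / len(correct)  # mean of the 0/1 matches, like reduce_mean(cast(...))

predictions = [
    [0.1, 0.7, 0.2],  # predicts class 1
    [0.6, 0.3, 0.1],  # predicts class 0
    [0.2, 0.2, 0.6],  # predicts class 2
    [0.5, 0.4, 0.1],  # predicts class 0
]
labels = [
    [0, 1, 0],
    [1, 0, 0],
    [0, 1, 0],  # true class 1, so the third prediction is wrong
    [1, 0, 0],
]
print(accuracy(predictions, labels))  # 0.75
```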




San Diego State University

Computer Vision Lab

© 2023 by Scientist Personal.
