TensorFlow - Recognize specific number from MNIST (such as a 7)
I've just started learning about ML and TensorFlow, and the first example I picked up was the MNIST number recognition tutorial ( https://www.tensorflow.org/versions/r1.0/get_started/mnist/beginners ). I understand the concept, but I'd like to make a simple modification to it: instead of recognizing every digit, it should only check whether or not the image is the number 7. Here's my current code, which recognizes all numbers (Python 3). The accuracy is 90%:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
import tensorflow as tf
#Inputs
x = tf.placeholder(tf.float32,[None,784])
#layer
w = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,w)+b)
#expected result
y_ = tf.placeholder(tf.float32, [None, 10])
#cost function
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
#gradient descent
train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy)
#begin session
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
for _ in range(1000):
    # use mini-batches to save computation (stochastic training)
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
# Test
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
print(sess.run(accuracy, feed_dict={x: mnist.test.images,y_: mnist.test.labels}))
My idea is to reduce the output layer, so it looks like this:
x = tf.placeholder(tf.float32,[None,784])
w = tf.Variable(tf.zeros([784,1]))
b = tf.Variable(tf.zeros([1]))
y = tf.nn.softmax(tf.matmul(x,w)+b)
y_ = tf.placeholder(tf.float32, [None, 1])
so that there's only one output neuron. My problem is the training labels: I have no idea how to train using only the field for the digit 7 in mnist.train.labels. I've done some tests like
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_[7] * tf.log(y), reduction_indices=[1]))
but it didn't work.
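A possible reason, offered as a minimal sketch rather than a definitive answer: `y_[7]` indexes the first (batch) dimension, so it picks the eighth example in the batch, not the label column for the digit 7. Slicing the column instead (`labels[:, 7:8]`) keeps the shape `[N, 1]` that the single output neuron expects. Separately, a softmax over a single output is always 1, so a sigmoid is the usual choice for one neuron. The NumPy snippet below illustrates both points with made-up data; `digits`, `logits`, and the helper names are assumptions for illustration and are not part of the tutorial.

```python
import numpy as np

def one_hot_to_binary(labels, digit=7):
    """Keep only the column for `digit`: shape [N, 1], 1.0 where the label is that digit."""
    return labels[:, digit:digit + 1]

def softmax(z):
    """Row-wise softmax (shifted by the row maximum for numerical stability)."""
    e = np.exp(z - z.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def sigmoid(z):
    """Element-wise logistic function."""
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical stand-in for a batch of one-hot labels such as mnist.train.labels.
digits = np.array([7, 2, 7, 0])
batch_ys = np.eye(10, dtype=np.float32)[digits]   # shape (4, 10)

# Column slice: one "is it a 7?" target per example.
batch_y7 = one_hot_to_binary(batch_ys)            # shape (4, 1)

# Hypothetical outputs of a single-neuron layer before the activation.
logits = np.array([[2.0], [-1.0], [0.5], [-3.0]])

# Softmax normalizes across the output units; with one unit it is always 1.
softmax_out = softmax(logits)

# Sigmoid maps each logit to a probability in (0, 1) independently.
sigmoid_out = sigmoid(logits)

# Binary cross-entropy between the sigmoid outputs and the 0/1 targets.
eps = 1e-7
p = np.clip(sigmoid_out, eps, 1.0 - eps)
bce = -np.mean(batch_y7 * np.log(p) + (1.0 - batch_y7) * np.log(1.0 - p))

print(batch_y7.ravel(), softmax_out.ravel(), sigmoid_out.ravel(), bce)
```

In the TensorFlow 1.x code above, the same idea would presumably mean feeding `batch_ys[:, 7:8]` for `y_`, replacing `tf.nn.softmax` with `tf.nn.sigmoid` (or computing the loss with `tf.nn.sigmoid_cross_entropy_with_logits`), and checking accuracy by comparing `y > 0.5` with the label instead of using `tf.argmax`. This is untested against the tutorial data.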
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
