How to use tf.contrib.model_pruning on MNIST?

I'm struggling to use TensorFlow's pruning library and haven't found many helpful examples, so I'm looking for help pruning a simple model trained on the MNIST dataset. If anyone can either help fix my attempt or provide an example of how to use the library on MNIST, I'd greatly appreciate it.

The first half of my code is fairly standard, except that my model has two hidden layers of 300 units each, built with layers.masked_fully_connected so they can be pruned.

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning
    from tensorflow.contrib.model_pruning.python.layers import layers
    from tensorflow.examples.tutorials.mnist import input_data

    # Import dataset
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    # Define Placeholders
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])

    # Define the model
    layer1 = layers.masked_fully_connected(image, 300)
    layer2 = layers.masked_fully_connected(layer1, 300)
    logits = tf.contrib.layers.fully_connected(layer2, 10, tf.nn.relu)

    # Loss function
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

    # Training op
    train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss)

    # Accuracy ops
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
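
As far as I understand, layers.masked_fully_connected behaves like an ordinary fully connected layer but also creates a non-trainable 0/1 mask alongside the weight matrix, and pruning zeroes entries of the mask rather than the weights themselves. A conceptual sketch of that idea (my own illustration, not the library's actual code):

    import tensorflow as tf

    # Conceptual sketch of a "masked" dense layer: an elementwise 0/1 mask is
    # applied to the weight matrix before the matmul, so pruning only has to
    # zero mask entries while the underlying weights stay untouched.
    inputs = tf.placeholder(tf.float32, [None, 784])
    weights = tf.get_variable("w", [784, 300], initializer=tf.truncated_normal_initializer(stddev=0.1))
    biases = tf.get_variable("b", [300], initializer=tf.zeros_initializer())
    mask = tf.get_variable("mask", [784, 300], initializer=tf.ones_initializer(), trainable=False)

    masked_weights = weights * mask
    outputs = tf.nn.relu(tf.matmul(inputs, masked_weights) + biases)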

I then try to define the necessary pruning operations, but I get an error.

    ############ Pruning Operations ##############
    # Create global step variable
    global_step = tf.contrib.framework.get_or_create_global_step()

    # Create a pruning object using the pruning specification
    pruning_hparams = pruning.get_pruning_hparams()
    p = pruning.Pruning(pruning_hparams, global_step=global_step)

    # Mask Update op
    mask_update_op = p.conditional_mask_update_op()

    # Set up the specification for model pruning
    prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

The error occurs on this line:

    prune_train = tf.contrib.model_pruning.train(train_op=train_op, logdir=None, mask_update_op=mask_update_op)

    InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'Placeholder_1' with dtype float and shape [?,10]
      [[Node: Placeholder_1 = Placeholder[dtype=DT_FLOAT, shape=[?,10], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]]
      [[Node: global_step/_57 = _Recv[_start_time=0, client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_71_global_step", tensor_type=DT_INT64, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]

I assume it wants a different kind of operation in place of train_op, but I haven't found any adjustment that works.
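
From the traceback, my guess is that the train helper runs train_op internally without a feed_dict, so the label placeholder ('Placeholder_1') never receives a value. One possible workaround would be to feed the graph from a tf.data pipeline instead of placeholders, so no feed_dict is needed at run time. A minimal sketch of that idea (the names dataset, images_batch and labels_batch are my own, not from the library):

    import tensorflow as tf
    from tensorflow.examples.tutorials.mnist import input_data

    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

    # Build an input pipeline that replaces the two placeholders
    dataset = tf.data.Dataset.from_tensor_slices(
        (mnist.train.images, mnist.train.labels))
    dataset = dataset.shuffle(10000).batch(128).repeat()
    images_batch, labels_batch = dataset.make_one_shot_iterator().get_next()

    # The model and loss would then be built on images_batch/labels_batch
    # instead of `image` and `label`, and train_op could run with no feed_dict.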

Again, if you have a different working example that prunes a model trained on MNIST, I would consider that an answer.

This is the simplest example of the pruning library I could get working. I figured I'd post it here in case it helps some other newbie who has a hard time with the documentation.

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning
    from tensorflow.contrib.model_pruning.python.layers import layers
    from tensorflow.examples.tutorials.mnist import input_data

    epochs = 250
    batch_size = 55000  # Entire training set

    # Import dataset
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    batches = int(len(mnist.train.images) / batch_size)

    # Define Placeholders
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])

    # Define the model
    layer1 = layers.masked_fully_connected(image, 300)
    layer2 = layers.masked_fully_connected(layer1, 300)
    logits = layers.masked_fully_connected(layer2, 10)

    # Create global step variable (needed for pruning)
    global_step = tf.train.get_or_create_global_step()
    reset_global_step_op = tf.assign(global_step, 0)

    # Loss function
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

    # Training op, the global step is critical here, make sure it matches the one used in pruning later
    # running this operation increments the global_step
    train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)

    # Accuracy ops
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Get, Print, and Edit Pruning Hyperparameters
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = .9

    # Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam
    p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
    prune_op = p.conditional_mask_update_op()

    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())

        # Train the model before pruning (optional)
        for epoch in range(epochs):
            for batch in range(batches):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

            # Calculate Test Accuracy every 10 epochs
            if epoch % 10 == 0:
                acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
                print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("Pre-Pruning accuracy:", acc_print)
        print("Sparsity of layers (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

        # Reset the global step counter and begin pruning
        sess.run(reset_global_step_op)
        for epoch in range(epochs):
            for batch in range(batches):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # Prune and retrain
                sess.run(prune_op)
                sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

            # Calculate Test Accuracy every 10 epochs
            if epoch % 10 == 0:
                acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
                print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
                print("Weight sparsities:", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

        # Print final accuracy
        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("Final accuracy:", acc_print)
        print("Final sparsity by layer (should be 0.9)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
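
For anyone adjusting the hyperparameters above: as far as I can tell, the library ramps the mask sparsity gradually from an initial value up to target_sparsity between sparsity_function_begin_step and sparsity_function_end_step (following the gradual pruning schedule of Zhu & Gupta, "To prune, or not to prune"), updating the masks every pruning_frequency steps between begin_pruning_step and end_pruning_step. A small standalone sketch of that schedule; the function name and the cubic exponent are my assumptions, not the library's API:

    # Sketch of a gradual sparsity schedule (assumes the default cubic exponent);
    # the helper is illustrative, not part of tf.contrib.model_pruning.
    def gradual_sparsity(step, initial_sparsity=0.0, target_sparsity=0.9,
                         begin_step=0, end_step=250, exponent=3):
        # Clamp the step into the ramp interval, then interpolate
        progress = min(max(step - begin_step, 0), end_step - begin_step) / float(end_step - begin_step)
        return target_sparsity + (initial_sparsity - target_sparsity) * (1.0 - progress) ** exponent

    # With the settings above, sparsity rises quickly at first and flattens out:
    for step in (0, 50, 125, 250):
        print(step, gradual_sparsity(step))
    # approximately 0.0, 0.44, 0.79, 0.9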

Roman Nikishin asked for code that could save the model; this is a slight extension of my original answer.

    import tensorflow as tf
    from tensorflow.contrib.model_pruning.python import pruning
    from tensorflow.contrib.model_pruning.python.layers import layers
    from tensorflow.examples.tutorials.mnist import input_data

    epochs = 250
    batch_size = 55000  # Entire training set

    model_path_unpruned = "Model_Saves/Unpruned.ckpt"
    model_path_pruned = "Model_Saves/Pruned.ckpt"

    # Import dataset
    mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
    batches = int(len(mnist.train.images) / batch_size)

    # Define Placeholders
    image = tf.placeholder(tf.float32, [None, 784])
    label = tf.placeholder(tf.float32, [None, 10])

    # Define the model
    layer1 = layers.masked_fully_connected(image, 300)
    layer2 = layers.masked_fully_connected(layer1, 300)
    logits = layers.masked_fully_connected(layer2, 10)

    # Create global step variable (needed for pruning)
    global_step = tf.train.get_or_create_global_step()
    reset_global_step_op = tf.assign(global_step, 0)

    # Loss function
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

    # Training op, the global step is critical here, make sure it matches the one used in pruning later
    # running this operation increments the global_step
    train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)

    # Accuracy ops
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    # Get, Print, and Edit Pruning Hyperparameters
    pruning_hparams = pruning.get_pruning_hparams()
    print("Pruning Hyperparameters:", pruning_hparams)

    # Change hyperparameters to meet our needs
    pruning_hparams.begin_pruning_step = 0
    pruning_hparams.end_pruning_step = 250
    pruning_hparams.pruning_frequency = 1
    pruning_hparams.sparsity_function_end_step = 250
    pruning_hparams.target_sparsity = .9

    # Create a pruning object using the pruning specification, sparsity seems to have priority over the hparam
    p = pruning.Pruning(pruning_hparams, global_step=global_step, sparsity=.9)
    prune_op = p.conditional_mask_update_op()

    # Create a saver for writing training checkpoints.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        # Uncomment the following if you don't have a trained model yet
        sess.run(tf.initialize_all_variables())

        # Train the model before pruning (optional)
        for epoch in range(epochs):
            for batch in range(batches):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

            # Calculate Test Accuracy every 10 epochs
            if epoch % 10 == 0:
                acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
                print("Un-pruned model step %d test accuracy %g" % (epoch, acc_print))

        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("Pre-Pruning accuracy:", acc_print)
        print("Sparsity of layers (should be 0)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

        # Saves the model before pruning
        saver.save(sess, model_path_unpruned)

        # Resets the session and restores the saved model
        sess.run(tf.initialize_all_variables())
        saver.restore(sess, model_path_unpruned)

        # Reset the global step counter and begin pruning
        sess.run(reset_global_step_op)
        for epoch in range(epochs):
            for batch in range(batches):
                batch_xs, batch_ys = mnist.train.next_batch(batch_size)
                # Prune and retrain
                sess.run(prune_op)
                sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

            # Calculate Test Accuracy every 10 epochs
            if epoch % 10 == 0:
                acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
                print("Pruned model step %d test accuracy %g" % (epoch, acc_print))
                print("Weight sparsities:", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))

        # Saves the model after pruning
        saver.save(sess, model_path_pruned)

        # Print final accuracy
        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
        print("Final accuracy:", acc_print)
        print("Final sparsity by layer (should be 0.9)", sess.run(tf.contrib.model_pruning.get_weight_sparsity()))
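
In case someone wants to reuse the saved checkpoints later, restoring the pruned model for evaluation only needs the same graph definition plus saver.restore. A minimal sketch, assuming the graph-construction code above has already been run in a fresh process (so image, label, accuracy and saver exist, and the paths are the ones defined above):

    # Minimal sketch: evaluate the pruned checkpoint in a fresh session.
    with tf.Session() as sess:
        saver.restore(sess, model_path_pruned)  # restores weights and pruning masks
        acc = sess.run(accuracy, feed_dict={image: mnist.test.images,
                                            label: mnist.test.labels})
        print("Restored pruned model accuracy:", acc)
        print("Restored sparsity by layer:",
              sess.run(tf.contrib.model_pruning.get_weight_sparsity()))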