No se puede asignar valor a la variable tensor cargada desde el gráfico

He entrenado un modelo y lo he guardado. Ahora, estoy tratando de ver cómo las perturbaciones de los pesos pueden afectar su precisión, así que necesito modificar los valores guardados en mis variables de pesos, esencialmente agregándole algo de ruido. El problema es que no puedo asignarles un valor después de haberlos cargado. Estoy usando la versión 1.2.1 de tensorflow, para entrenar y cargar el modelo. Aquí está mi código:

import tensorflow as tf tf.reset_default_graph() sess = tf.InteractiveSession() saver = tf.train.import_meta_graph('/scratch/pedro/TFModels/Checks_and_Logs/20170803_215828/beta_model-1.meta') print("Graph restred") saver.restre(sess, tf.train.latest_checkpoint('/scratch/pedro/TFModels/Checks_and_Logs/20170803_215828/')) print("Model restred") tf.global_variables() #prints the list of variables in the graph 

Esto produce el siguiente resultado:

 [, , , , , , , , , , , , , , , , , , , , , , , , , , ] 

Entonces, he estado intentando modificar el primero (FF_NN / Model / hidden_layer_1 / weight / Variable: 0) pero eso me da un error:

 x = data_train[:batch_size] y = data_train_labels[:batch_size] graph = tf.get_default_graph() data_train_tensor = graph.get_tensor_by_name("Train_Dataset:0") data_train_labels_onehot = graph.get_tensor_by_name("Train_Labels:0") acc_te = graph.get_tensor_by_name("Test_Data_Accuracy/Mean:0") acc_tr = graph.get_tensor_by_name("Train_Data_Accuracy/Mean:0") w1 = graph.get_tensor_by_name("FF_NN/Model/hidden_layer_1/weights/Variable:0") print('w1:\n', w1.eval()) training_acc, test_acc = sess.run([acc_tr, acc_te], feed_dict={data_train_tensor: x, data_train_labels_onehot: y}) print(test_acc) w1 = w1 + 50 print('w1:\n', w1.eval()) sess.run(w1.assign(w1)) training_acc, test_acc, _ = sess.run([acc_tr, acc_te, w1], feed_dict={data_train_tensor: x, data_train_labels_onehot: y}) print(test_acc) 

Esto me da un error en la operación de asignación:

 w1: [[-0.0531723 0.73768502 0.14098917 ..., 1.67111528 0.2495033 0.20415793] [ 1.20964873 -0.99254322 -3.01407313 ..., 0.40427083 0.33289135 0.2326804 ] [ 0.70157909 -1.61257529 -0.59762233 ..., 0.20860809 -0.02733657 1.57942903] ..., [ 1.23854971 -2.28062844 -1.01647282 ..., 1.18426156 0.65342903 -0.45519635] [ 1.02164841 -0.11143603 1.71673298 ..., -0.85511237 1.15535712 0.50917912] [-2.52524352 -0.04488864 0.66239733 ..., -0.45516238 -0.76003599 -1.2073245 ]] 0.242335 w1: [[ 49.94682693 50.73768616 50.1409874 ..., 51.67111588 50.24950409 50.20415878] [ 51.20964813 49.00745773 46.98592758 ..., 50.40427017 50.33288956 50.23268127] [ 50.70158005 48.38742447 49.40237808 ..., 50.20860672 49.97266388 51.57942963] ..., [ 51.23854828 47.7193718 48.98352814 ..., 51.18426132 50.65342712 49.54480362] [ 51.02164841 49.88856506 51.71673203 ..., 49.14488602 51.15535736 50.50917816] [ 47.47475815 49.95511246 50.66239548 ..., 49.54483795 49.23996353 48.79267502]] --------------------------------------------------------------------------- AttributeError Traceback (most recent call last)  in () 16 w1 = w1 +50 17 print('w1:\n', w1.eval()) ---> 18 sess.run(w1.assign(w1)) 19 #print('w1:\n', w1.eval()) 20 training_acc, test_acc, _ = sess.run([acc_tr, acc_te, w1], feed_dict={data_train_tensor: x, data_train_labels_onehot: y}) AttributeError: 'Tensor' object has no attribute 'assign' 

Todas las preguntas similares apuntan al hecho de que w1 debería ser un tipo tf.Variable y ese parece ser el caso aquí, de acuerdo con la salida de tf.global_variables() .

El siguiente código funcionaría. mejor manera de usar get_variable

 w1 = tf.get_variable("FF_NN/Model/hidden_layer_1/weights/Variable:0") sess.run(tf.assign(w1, w1+50)) 

Por ahora, este paso no funcionará, este es un error en tensorflow https://github.com/tensorflow/tensorflow/issues/1325

Una solución de trabajo:

 w1 = [v for v in tf.global_variables() if v.name=="FF_NN/Model/hidden_layer_1/weights/Variable:0"][0] sess.run(tf.assign(w1, w1+50)) 

tf.get_variable obtener el objeto variable subyacente utilizando tf.get_variable o tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)[0]

Ishant muestra que get_variable tiene actualmente un error para las variables de ámbito, por lo que, hasta que se solucione, debe usar get_collection