Tensorflow crea nuevas variables a pesar de que se reutiliza y se establece en true

Estoy tratando de construir un RNN básico, pero recibo errores al usar la red después del entrenamiento. Tengo architecture de red en una inference función

 def inference(inp): with tf.name_scope("inference"): layer = SimpleRNN(1, activation='sigmoid', return_sequences=False)(inp) layer = Dense(1)(layer) return layer 

pero cada vez que lo llamo, se crea otro conjunto de variables a pesar de usar el mismo scope en la capacitación:

 def train(sess, seq_len=2, epochs=100): x_input, y_input = generate_data(seq_len) with tf.name_scope('train_input'): x = tf.placeholder(tf.float32, (None, seq_len, 1)) y = tf.placeholder(tf.float32, (None, 1)) with tf.variable_scope('RNN'): output = inference(x) with tf.name_scope('training'): loss = tf.losses.mean_squared_error(labels=y, predictions=output) train_op = tf.train.RMSPropOptimizer(learning_rate=0.1).minimize(loss=loss, global_step=tf.train.get_global_step()) with sess.as_default(): sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) for i in tqdm.trange(epochs): ls, res, _ = sess.run([loss, output, train_op], feed_dict={x:x_input, y:y_input}) if i%100==0: print(f'{ls}: {res[10]} - {y_input[10]}') x_input, y_input = generate_data(seq_len) 

y predicción:

 def predict_signal(sess, x, seq_len): # Preparing signal (omitted) # Predict inp = tf.convert_to_tensor(prepared_signal, tf.float32) with sess.as_default(): with tf.variable_scope('RNN', reuse=True) as scope: output = inference(inp) result = output.eval() return result 

Ya llevo un par de horas leyendo sobre los ámbitos de las variables, pero en la ejecución de la predicción aún recibo un error. Attempting to use uninitialized value RNN_1/inference/simple_rnn_2/kernel , con el número de RNN_1 aumentando con cada llamada

Esto es solo una especulación hasta que nos SimpleRNN implementación de SimpleRNN . Sin embargo, sospecho que SimpleRNN está muy mal implementado. Hay una diferencia entre tf.get_variable y tf.Variable . Espero que su SimpleRNN use tf.Variable .

Para reproducir este comportamiento:

 import tensorflow as tf def inference(x): w = tf.Variable(1., name='w') layer = x + w return layer x = tf.placeholder(tf.float32) with tf.variable_scope('RNN'): output = inference(x) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) print(sess.run(output, {x: 10})) with sess.as_default(): with tf.variable_scope('RNN', reuse=True): output2 = inference(x) print(sess.run(output2, {x: 10})) 

Esto da exactamente el mismo error:

Intentando usar un valor no inicializado RNN_1 / w

Sin embargo, la versión con w = tf.get_variable('w', initializer=1.) lugar de w = tf.Variable(1., name='w') hace que funcione.

¿Por qué? Ver los documentos:

tf.get_variable:

Obtiene una variable existente con estos parámetros o crea uno nuevo. Esta función prefija el nombre con el scope de la variable actual y realiza reutilizaciones .

Gracias por la pregunta (agregué la bandera de keras a tu pregunta). Esta se está convirtiendo en mi razón favorita para mostrar a las personas por qué usar Keras es la peor decisión que tomaron.

SimpleRNN crea sus variables aquí:

 self.kernel = self.add_weight(shape=(input_shape[-1], self.units), name='kernel',...) 

Esto ejecuta la línea.

 weight = K.variable(initializer(shape), dtype=dtype, name=name, constraint=constraint) 

que termina aquí

 v = tf.Variable(value, dtype=tf.as_dtype(dtype), name=name) 

Y esta es una falla obvia en la implementación. Hasta que Keras use TensorFlow de la manera correcta (respetando al menos los scopes y variable-collections ), debe buscar alternativas. El mejor consejo que alguien puede darte es cambiar a algo mejor como los tf.layers oficiales.

@Patwie realizó el diagnóstico correcto con respecto al error, un posible error en la implementación de Keras de referencia.

Sin embargo, en mi opinión, la conclusión lógica no es descartar a Keras, sino usar la implementación de Keras que viene con tensorflow, que se puede encontrar en tf.keras . Encontrará que las variables se generan correctamente en esta implementación. tf.keras se implementa específicamente para tensorflow y debe minimizar este tipo de error de interfaz.

De hecho, si ya está tensorflow, no veo ningún beneficio en particular al usar las Keras de referencia en lugar de tf.keras , a menos que esté usando sus funciones más recientes, tf.keras es un poco más atrasado en términos de versiones ( por ejemplo, actualmente en 2.1.5 en TF 1.8 wheras Keras 2.2.0 ha estado fuera durante aproximadamente un mes).