Tensorflow: _variable_with_weight_decay (…) explicación

En este momento estoy mirando el ejemplo de cifar10 y noté la función _variable_with_weight_decay (…) en el archivo cifar10.py . El código es el siguiente:

def _variable_with_weight_decay(name, shape, stddev, wd): """Helper to create an initialized Variable with weight decay. Note that the Variable is initialized with a truncated normal distribution. A weight decay is added only if one is specified. Args: name: name of the variable shape: list of ints stddev: standard deviation of a truncated Gaussian wd: add L2Loss weight decay multiplied by this float. If None, weight decay is not added for this Variable. Returns: Variable Tensor """ dtype = tf.float16 if FLAGS.use_fp16 else tf.float32 var = _variable_on_cpu( name, shape, tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) if wd is not None: weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss') tf.add_to_collection('losses', weight_decay) return var 

Me pregunto si esta función hace lo que dice. Está claro que cuando se da un factor de caída de peso (no d ninguno), se calcula el valor de la variación (valor de peso). ¿Pero es todo aplicado? Al final, la variable no modificada (var) se devuelve, ¿o me falta algo?

¿La segunda pregunta sería cómo arreglar esto? Como entiendo, el valor del peso escalar_decio debe ser restado de cada elemento en la matriz de peso, pero no puedo encontrar una operación de tensor de flujo que pueda hacer eso (sumr / restar un solo valor de cada elemento de un tensor). ¿Hay alguna operación como esta? Como solución alternativa, pensé que podría ser posible crear un nuevo tensor inicializado con el valor de weight_decay y usar tf.subtract (…) para lograr el mismo resultado. ¿O es esta la manera correcta de ir de todos modos?

Gracias por adelantado.

El código hace lo que dice. Se supone que debe sumr todo en la colección de 'losses' (a la que se agrega el término de caída de peso en la segunda a la última línea) por la pérdida que pasa al optimizador. En la función loss() en ese ejemplo:

 tf.add_to_collection('losses', cross_entropy_mean) [...] return tf.add_n(tf.get_collection('losses'), name='total_loss') 

así que lo que devuelve la función loss() es la pérdida de clasificación más todo lo que estaba antes en la colección de 'losses' .

Como nota al margen, la caída de peso no significa que reste el valor de wd de cada valor del tensor como parte del paso de actualización, sino que multiplica el valor por (1-learning_rate*wd) (en SGD simple). Para ver por qué esto es así, recuerde que l2_loss calcula

 output = sum(t_i ** 2) / 2 

siendo t_i los elementos del tensor. Esto significa que la derivada de l2_loss con respecto a cada elemento tensorial es el valor de ese elemento tensor en sí, y dado que ha escalado l2_loss con wd la derivada también se escala.

Dado que el paso de actualización (nuevamente, en formato simple) es (perdóname por omitir los índices de paso de tiempo)

 w := w - learning_rate * dL/dw 

obtienes, si solo tuvieras el término de caída de peso

 w := w - learning_rate * wd * w 

o

 w := w * (1 - learning_rate * wd)