Personalice la función de pérdida de Keras de manera que y_true dependa de y_pred

Estoy trabajando en un clasificador multi-etiqueta. Tengo muchas tags de salida [1, 0, 0, 1 …] donde 1 indica que la entrada pertenece a esa etiqueta y 0 significa lo contrario.

En mi caso, la función de pérdida que uso está basada en MSE. Quiero cambiar la función de pérdida de forma que cuando la etiqueta de salida sea -1, cambie la probabilidad prevista de esta etiqueta.

Verifique las imágenes adjuntas para comprender mejor lo que quiero decir: el escenario es: cuando la etiqueta de salida es -1, quiero que el MSE sea igual a cero:

Este es el escenario: introduzca la descripción de la imagen aquí

Y en tal caso quiero que cambie a:

introduzca la descripción de la imagen aquí

En tal caso, el MSE de la segunda etiqueta (la salida central) será cero (este es un caso especial en el que no quiero que el clasificador aprenda sobre esta etiqueta).

Parece que este es un método necesario y realmente no creo que sea el primero en pensarlo, por lo que en primer lugar quería saber si existe un nombre para este tipo de entrenamiento en Neural Net y, en segundo lugar, me gustaría saber cómo puedo hacerlo.

Entiendo que necesito cambiar algunas cosas en la función de pérdida, pero realmente soy un novato en Theano y no estoy seguro de cómo buscar un valor específico y cómo cambiar el contenido del tensor.

Creo que esto es lo que buscas.

import theano from keras import backend as K from keras.layers import Dense from keras.models import Sequential def customized_loss(y_true, y_pred): loss = K.switch(K.equal(y_true, -1), 0, K.square(y_true-y_pred)) return K.sum(loss) if __name__ == '__main__': model = Sequential([ Dense(3, input_shape=(4,)) ]) model.compile(loss=customized_loss, optimizer='sgd') import numpy as np x = np.random.random((1, 4)) y = np.array([[1,-1,0]]) output = model.predict(x) print output # [[ 0.47242549 -0.45106074 0.13912249]] print model.evaluate(x, y) # keras's loss # 0.297689884901 print (output[0, 0]-1)**2 + 0 +(output[0, 2]-0)**2 # double-check # 0.297689929093