¿Cómo convertir una capa densa en una capa convolucional equivalente en Keras?

Me gustaría hacer algo similar al documento de Redes Totalmente Convolucionales ( https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf ) usando Keras. Tengo una red que termina aplanando los mapas de características y los ejecuta a través de varias capas densas. Me gustaría cargar los pesos de una red como esta en una en la que las capas densas se reemplacen con convoluciones equivalentes.

La red VGG16 que viene con Keras podría usarse como ejemplo, donde la salida 7x7x512 del último MaxPooling2D () se aplana y luego entra en una capa densa (4096). En este caso, el Denso (4096) sería reemplazado por una convolución de 7x7x4096.

Mi red real es ligeramente diferente, hay una capa GlobalAveragePooling2D () en lugar de MaxPooling2D () y Flatten (). La salida de GlobalAveragePooling2D () es un tensor 2D, y no es necesario aplanarlo adicionalmente, por lo que todas las capas densas, incluida la primera, se reemplazarán por convoluciones 1×1.

He visto esta pregunta: Python sabe cómo transformar una capa densa en una capa convolucional que parece muy similar, si no idéntica. El problema es que la solución sugerida no funciona, porque (a) estoy usando TensorFlow como backend, por lo que la “rotación” de reorganización de pesos / filtro no es correcta, y (b) no puedo entender cómo cargar los pesos. Cargar el archivo de pesos antiguo en la nueva red con model.load_weights(by_name=True) no funciona, porque los nombres no coinciden (e incluso si difirieron las dimensiones).

¿Qué debe ser la reorganización cuando se utiliza TensorFlow?

¿Cómo se cargan los pesos? ¿Debo crear uno de cada modelo, llamar a model.load_weights () en ambos para cargar los pesos idénticos y luego copiar algunos de los pesos adicionales que necesitan reorganización?

Basándome en la respuesta de hars, creé esta función para transformar un CNN arbitrario en un fcn:

 from keras.models import Sequential from keras.layers.convolutional import Convolution2D from keras.engine import InputLayer import keras def to_fully_conv(model): new_model = Sequential() input_layer = InputLayer(input_shape=(None, None, 3), name="input_new") new_model.add(input_layer) for layer in model.layers: if "Flatten" in str(layer): flattened_ipt = True f_dim = layer.input_shape elif "Dense" in str(layer): input_shape = layer.input_shape output_dim = layer.get_weights()[1].shape[0] W,b = layer.get_weights() if flattened_ipt: shape = (f_dim[1],f_dim[2],f_dim[3],output_dim) new_W = W.reshape(shape) new_layer = Convolution2D(output_dim, (f_dim[1],f_dim[2]), strides=(1,1), activation=layer.activation, padding='valid', weights=[new_W,b]) flattened_ipt = False else: shape = (1,1,input_shape[1],output_dim) new_W = W.reshape(shape) new_layer = Convolution2D(output_dim, (1,1), strides=(1,1), activation=layer.activation, padding='valid', weights=[new_W,b]) else: new_layer = layer new_model.add(new_layer) return new_model 

Puedes probar la función así:

 model = keras.applications.vgg16.VGG16() new_model = to_fully_conv(model) 

a. No hay necesidad de hacer una rotación complicada. Solo remodela esta funcionando

segundo. Usa get_weights () y init new layer

Iterice a través de model.layers, cree la misma capa con configuración y cargue los pesos con set_weights o como se muestra a continuación.

El siguiente pedazo de pseudo código funciona para mí. (Keras 2.0)

Pseudo Código:

 # find input dimensions of Flatten layer f_dim = flatten_layer.input_shape # Creating new Conv layer and putting dense layers weights m_layer = model.get_layer(layer.name) input_shape = m_layer.input_shape output_dim = m_layer.get_weights()[1].shape[0] W,b = layer.get_weights() if first dense layer : shape = (f_dim[1],f_dim[2],f_dim[3],output_dim) new_W = W.reshape(shape) new_layer = Convolution2D(output_dim,(f_dim[1],f_dim[2]),strides=(1,1),activation='relu',padding='valid',weights=[new_W,b]) else: (not first dense layer) shape = (1,1,input_shape[1],output_dim) new_W = W.reshape(shape) new_layer = Convolution2D(output_dim,(1,1),strides=(1,1),activation='relu',padding='valid',weights=[new_W,b])