Entendiendo tf.extract_image_patches para extraer parches de una imagen

Encontré el siguiente método tf.extract_image_patches en la API de tensorflow, pero no tengo claro su funcionalidad.

Diga batch_size = 1 , y una imagen es de tamaño 225x225x3 , y queremos extraer parches de tamaño 32x32 .

¿Cómo se comporta exactamente esta función? Específicamente, la documentación menciona la dimensión del tensor de salida como [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth] , pero no se menciona qué out_rows y out_cols son.

Idealmente, dado un tensor de imagen de entrada de tamaño 1x225x225x3 (donde 1 es el tamaño del lote), quiero poder obtener Kx32x32x3 como salida, donde K es el número total de parches y 32x32x3 es la dimensión de cada parche. ¿Hay algo en tensorflow que ya logra esto?

Así es como funciona el método:

  • ksizes se usa para decidir las dimensiones de cada parche, o en otras palabras, cuántos píxeles debe contener cada parche.
  • strides denota la longitud del espacio entre el inicio de un parche y el comienzo del siguiente parche consecutivo dentro de la imagen original.
  • rates es un número que esencialmente significa que nuestro parche debe saltar en píxeles de la imagen original para cada píxel consecutivo que termina en nuestro parche. (El siguiente ejemplo ayuda a ilustrar esto).
  • padding es “VÁLIDO”, lo que significa que cada parche debe estar completamente contenido en la imagen, o “MISMO”, lo que significa que se permite que los parches estén incompletos (los píxeles restantes se rellenarán con ceros).

Aquí hay un código de ejemplo con salida para ayudar a demostrar cómo funciona:

 import tensorflow as tf n = 10 # images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]] # We generate four outputs as follows: # 1. 3x3 patches with stride length 5 # 2. Same as above, but the rate is increased to 2 # 3. 4x4 patches with stride length 7; only one patch should be generated # 4. Same as above, but with padding set to 'SAME' with tf.Session() as sess: print tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n' print tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n' print tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n' print tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='SAME').eval() 

Salida:

 [[[[ 1 2 3 11 12 13 21 22 23] [ 6 7 8 16 17 18 26 27 28]] [[51 52 53 61 62 63 71 72 73] [56 57 58 66 67 68 76 77 78]]]] [[[[ 1 3 5 21 23 25 41 43 45] [ 6 8 10 26 28 30 46 48 50]] [[ 51 53 55 71 73 75 91 93 95] [ 56 58 60 76 78 80 96 98 100]]]] [[[[ 1 2 3 4 11 12 13 14 21 22 23 24 31 32 33 34]]]] [[[[ 1 2 3 4 11 12 13 14 21 22 23 24 31 32 33 34] [ 8 9 10 0 18 19 20 0 28 29 30 0 38 39 40 0]] [[ 71 72 73 74 81 82 83 84 91 92 93 94 0 0 0 0] [ 78 79 80 0 88 89 90 0 98 99 100 0 0 0 0 0]]]] 

Entonces, por ejemplo, nuestro primer resultado se parece a lo siguiente:

  * * * 4 5 * * * 9 10 * * * 14 15 * * * 19 20 * * * 24 25 * * * 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 * * * 54 55 * * * 59 60 * * * 64 65 * * * 69 70 * * * 74 75 * * * 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 

Como puede ver, tenemos 2 filas y 2 columnas de parches, que son lo que son out_rows y out_cols .

Para ampliar la respuesta detallada de Neal, hay muchas sutilezas con cero relleno cuando se usa “SAME”, ya que extract_image_patches intenta centrar los parches en la imagen si es posible. Dependiendo de la zancada, puede haber relleno en la parte superior e izquierda, o no, y el primer parche no necesariamente comienza en la parte superior izquierda.

Por ejemplo, extendiendo el ejemplo anterior:

 print tf.extract_image_patches(images, [1, 3, 3, 1], [1, n, n, 1], [1, 1, 1, 1], 'SAME').eval()[0] 

Con un paso de n = 1, la imagen se rellena con ceros alrededor y el primer parche comienza con el relleno. Otras zancadas rellenan la imagen solo en la parte derecha e inferior, o no en absoluto. Con un paso de n = 10, el parche único comienza en el elemento 34 (en el centro de la imagen).

tf.extract_image_patches es implementado por la biblioteca eigen como se describe en esta respuesta . Puede estudiar ese código para ver exactamente cómo se calculan las posiciones de parches y el relleno.