Agrupe por columna y filtre filas con valor máximo en Pyspark

Estoy casi seguro de que esto se ha preguntado anteriormente, pero una búsqueda a través de stackoverflow no respondió a mi pregunta. No es un duplicado de [2] ya que quiero el valor máximo, no el elemento más frecuente. Soy nuevo en pyspark e bash hacer algo realmente simple: quiero agrupar por columna “A” y luego mantener la fila de cada grupo que tiene el valor máximo en la columna “B”. Me gusta esto:

df_cleaned = df.groupBy("A").agg(F.max("B")) 

Desafortunadamente, esto elimina todas las demás columnas: df_cleaned solo contiene las columnas “A” y el valor máximo de B. ¿Cómo mantengo las filas? (“A B C”…)

Puedes hacer esto sin un udf usando una Window .

Considere el siguiente ejemplo:

 import pyspark.sql.functions as f data = [ ('a', 5), ('a', 8), ('a', 7), ('b', 1), ('b', 3) ] df = sqlCtx.createDataFrame(data, ["A", "B"]) df.show() #+---+---+ #| A| B| #+---+---+ #| a| 5| #| a| 8| #| a| 7| #| b| 1| #| b| 3| #+---+---+ 

Cree una Window para particionar por la columna A y use esto para calcular el máximo de cada grupo. Luego filtre las filas de manera que el valor en la columna B sea ​​igual al máximo.

 from pyspark.sql import Window w = Window.partitionBy('A') df.withColumn('maxB', f.max('B').over(w))\ .where(f.col('B') == f.col('maxB'))\ .drop('maxB')\ .show() #+---+---+ #| A| B| #+---+---+ #| a| 8| #| b| 3| #+---+---+ 

O equivalentemente utilizando pyspark-sql :

 df.registerTempTable('table') q = "SELECT A, B FROM (SELECT *, MAX(B) OVER (PARTITION BY A) AS maxB FROM table) M WHERE B = maxB" sqlCtx.sql(q).show() #+---+---+ #| A| B| #+---+---+ #| b| 3| #| a| 8| #+---+---+ 

Otro posible enfoque es aplicar unir el dataframe especificando “leftsemi”. Este tipo de unión incluye todas las columnas del dataframe en el lado izquierdo y ninguna columna en el lado derecho.

Por ejemplo:

 import pyspark.sql.functions as f data = [ ('a', 5, 'c'), ('a', 8, 'd'), ('a', 7, 'e'), ('b', 1, 'f'), ('b', 3, 'g') ] df = sqlContext.createDataFrame(data, ["A", "B", "C"]) df.show() +---+---+---+ | A| B| C| +---+---+---+ | a| 5| c| | a| 8| d| | a| 7| e| | b| 1| f| | b| 3| g| +---+---+---+ 

El valor máximo de la columna B por la columna A se puede seleccionar haciendo:

 df.groupBy('A').agg(f.max('B') +---+---+ | A| B| +---+---+ | a| 8| | b| 3| +---+---+ 

Usando esta expresión como un lado derecho en una semi unión izquierda, y cambiando el nombre de la columna max(B) a su nombre original B , podemos obtener el resultado necesario:

 df.join(df.groupBy('A').agg(f.max('B').alias('B')),on='B',how='leftsemi').show() +---+---+---+ | B| A| C| +---+---+---+ | 3| b| g| | 8| a| d| +---+---+---+ 

El plan físico detrás de esta solución y el de la respuesta aceptada son diferentes y todavía no tengo claro cuál funcionará mejor en grandes marcos de datos.

El mismo resultado se puede obtener usando la syntax de chispa SQL haciendo:

 df.registerTempTable('table') q = '''SELECT * FROM table a LEFT SEMI JOIN ( SELECT A, max(B) as max_B FROM table GROUP BY A ) t ON aA=tA AND aB=t.max_B ''' sqlContext.sql(q).show() +---+---+---+ | A| B| C| +---+---+---+ | b| 3| g| | a| 8| d| +---+---+---+