collect_list conservando el orden basado en otra variable

Estoy tratando de crear una nueva columna de listas en Pyspark usando una agregación groupby en el conjunto de columnas existente. A continuación se proporciona un ejemplo de dataframe de entrada:

------------------------ id | date | value ------------------------ 1 |2014-01-03 | 10 1 |2014-01-04 | 5 1 |2014-01-05 | 15 1 |2014-01-06 | 20 2 |2014-02-10 | 100 2 |2014-03-11 | 500 2 |2014-04-15 | 1500 

La salida esperada es:

 id | value_list ------------------------ 1 | [10, 5, 15, 20] 2 | [100, 500, 1500] 

Los valores dentro de una lista están ordenados por fecha.

Intenté usar collect_list de la siguiente manera:

 from pyspark.sql import functions as F ordered_df = input_df.orderBy(['id','date'],ascending = True) grouped_df = ordered_df.groupby("id").agg(F.collect_list("value")) 

Pero collect_list no garantiza el orden incluso si ordeno el cuadro de datos de entrada por fecha antes de la agregación.

¿Podría alguien ayudar a hacer la agregación conservando el orden basado en una segunda variable (fecha)?

Si recostack las fechas y los valores como una lista, puede ordenar la columna resultante de acuerdo con la fecha usando y udf , y luego mantener solo los valores en el resultado.

 import operator import pyspark.sql.functions as F # create list column grouped_df = input_df.groupby("id") \ .agg(F.collect_list(F.struct("date", "value")) \ .alias("list_col")) # define udf def sorter(l): res = sorted(l, key=operator.itemgetter(0)) return [item[1] for item in res] sort_udf = F.udf(sorter) # test grouped_df.select("id", sort_udf("list_col") \ .alias("sorted_list")) \ .show(truncate = False) +---+----------------+ |id |sorted_list | +---+----------------+ |1 |[10, 5, 15, 20] | |2 |[100, 500, 1500]| +---+----------------+ 
 from pyspark.sql import functions as F from pyspark.sql import Window w = Window.partitionBy('id').orderBy('date') sorted_list_df = input_df.withColumn( 'sorted_list', F.collect_list('value').over(w) )\ .groupBy('id')\ .agg(F.max('sorted_list').alias('sorted_list')) 

Window ejemplos de Window proporcionados por los usuarios a menudo no explican realmente lo que está pasando, así que permítame analizarlo por usted.

Como sabe, el uso de collect_list junto con groupBy dará como resultado una lista de valores no ordenada . Esto se debe a que, dependiendo de cómo se particionen sus datos, Spark agregará valores a su lista tan pronto como encuentre una fila en el grupo. El orden depende de cómo Spark planea su agregación sobre los ejecutores.

Una función de Window permite controlar esa situación, agrupando filas por un determinado valor para que pueda realizar una operación over cada uno de los grupos resultantes:

 w = Window.partitionBy('id').orderBy('date') 
  • partitionBy – quieres grupos / particiones de filas con el mismo id
  • orderBy : desea que cada fila del grupo se ordene por date

Una vez que haya definido el scope de su ventana – “filas con el mismo id , ordenadas por date ” -, puede usarlo para realizar una operación sobre ella, en este caso, una lista de collect_list :

 F.collect_list('value').over(w) 

En este punto, creó una nueva columna ordenada de lista con una lista ordenada de valores, ordenada por fecha, pero todavía tiene filas duplicadas por id . Para recortar las filas duplicadas que desea groupBy id y mantener el valor max para cada grupo:

 .groupBy('id')\ .agg(F.max('sorted_list').alias('sorted_list')) 

La pregunta era para PySpark, pero podría ser útil tenerlo también para Scala Spark.

Preparemos el dataframe de prueba:

 import org.apache.spark.sql.functions._ import org.apache.spark.sql.{DataFrame, Row, SparkSession} import org.apache.spark.sql.expressions.{ Window, UserDefinedFunction} import java.sql.Date import java.time.LocalDate val spark: SparkSession = ... // Out test data set val data: Seq[(Int, Date, Int)] = Seq( (1, Date.valueOf(LocalDate.parse("2014-01-03")), 10), (1, Date.valueOf(LocalDate.parse("2014-01-04")), 5), (1, Date.valueOf(LocalDate.parse("2014-01-05")), 15), (1, Date.valueOf(LocalDate.parse("2014-01-06")), 20), (2, Date.valueOf(LocalDate.parse("2014-02-10")), 100), (2, Date.valueOf(LocalDate.parse("2014-02-11")), 500), (2, Date.valueOf(LocalDate.parse("2014-02-15")), 1500) ) // Create dataframe val df: DataFrame = spark.createDataFrame(data) .toDF("id", "date", "value") df.show() //+---+----------+-----+ //| id| date|value| //+---+----------+-----+ //| 1|2014-01-03| 10| //| 1|2014-01-04| 5| //| 1|2014-01-05| 15| //| 1|2014-01-06| 20| //| 2|2014-02-10| 100| //| 2|2014-02-11| 500| //| 2|2014-02-15| 1500| //+---+----------+-----+ 

Usar UDF

 // Group by id and aggregate date and value to new column date_value val grouped = df.groupBy(col("id")) .agg(collect_list(struct("date", "value")) as "date_value") grouped.show() grouped.printSchema() // +---+--------------------+ // | id| date_value| // +---+--------------------+ // | 1|[[2014-01-03,10],...| // | 2|[[2014-02-10,100]...| // +---+--------------------+ // udf to extract data from Row, sort by needed column (date) and return value val sortUdf: UserDefinedFunction = udf((rows: Seq[Row]) => { rows.map { case Row(date: Date, value: Int) => (date, value) } .sortBy { case (date, value) => date } .map { case (date, value) => value } }) // Select id and value_list val r1 = grouped.select(col("id"), sortUdf(col("date_value")).alias("value_list")) r1.show() // +---+----------------+ // | id| value_list| // +---+----------------+ // | 1| [10, 5, 15, 20]| // | 2|[100, 500, 1500]| // +---+----------------+ 

Ventana de uso

 val window = Window.partitionBy(col("id")).orderBy(col("date")) val sortedDf = df.withColumn("values_sorted_by_date", collect_list("value").over(window)) sortedDf.show() //+---+----------+-----+---------------------+ //| id| date|value|values_sorted_by_date| //+---+----------+-----+---------------------+ //| 1|2014-01-03| 10| [10]| //| 1|2014-01-04| 5| [10, 5]| //| 1|2014-01-05| 15| [10, 5, 15]| //| 1|2014-01-06| 20| [10, 5, 15, 20]| //| 2|2014-02-10| 100| [100]| //| 2|2014-02-11| 500| [100, 500]| //| 2|2014-02-15| 1500| [100, 500, 1500]| //+---+----------+-----+---------------------+ val r2 = sortedDf.groupBy(col("id")) .agg(max("values_sorted_by_date").as("value_list")) r2.show() //+---+----------------+ //| id| value_list| //+---+----------------+ //| 1| [10, 5, 15, 20]| //| 2|[100, 500, 1500]| //+---+----------------+ 

Para asegurarnos de que se realice la clasificación de cada ID, podemos usar sortWithinPartitions:

 from pyspark.sql import functions as F ordered_df = ( input_df .repartition(input_df.id) .sortWithinPartitions(['date']) ) grouped_df = ordered_df.groupby("id").agg(F.collect_list("value"))