Renaming the pivoted and aggregated columns in a PySpark DataFrame

Given a DataFrame as follows:

    from pyspark.sql.functions import avg, first

    rdd = sc.parallelize(
        [
            (0, "A", 223, "201603", "PORT"),
            (0, "A", 22, "201602", "PORT"),
            (0, "A", 422, "201601", "DOCK"),
            (1, "B", 3213, "201602", "DOCK"),
            (1, "B", 3213, "201601", "PORT"),
            (2, "C", 2321, "201601", "DOCK"),
        ]
    )
    df_data = sqlContext.createDataFrame(rdd, ["id", "type", "cost", "date", "ship"])
    df_data.show()

I pivot it:

    df_data.groupby(df_data.id, df_data.type).pivot("date").agg(avg("cost"), first("ship")).show()

    +---+----+----------------+--------------------+----------------+--------------------+----------------+--------------------+
    | id|type|201601_avg(cost)|201601_first(ship)()|201602_avg(cost)|201602_first(ship)()|201603_avg(cost)|201603_first(ship)()|
    +---+----+----------------+--------------------+----------------+--------------------+----------------+--------------------+
    |  2|   C|          2321.0|                DOCK|            null|                null|            null|                null|
    |  0|   A|           422.0|                DOCK|            22.0|                PORT|           223.0|                PORT|
    |  1|   B|          3213.0|                PORT|          3213.0|                DOCK|            null|                null|
    +---+----+----------------+--------------------+----------------+--------------------+----------------+--------------------+

But this gives me really complicated column names. Applying an alias in the aggregation usually works, but because of the pivot, in this case the names come out even worse:

    +---+----+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+
    | id|type|201601_(avg(cost),mode=Complete,isDistinct=false) AS cost#1619|201601_(first(ship)(),mode=Complete,isDistinct=false) AS ship#1620|201602_(avg(cost),mode=Complete,isDistinct=false) AS cost#1619|201602_(first(ship)(),mode=Complete,isDistinct=false) AS ship#1620|201603_(avg(cost),mode=Complete,isDistinct=false) AS cost#1619|201603_(first(ship)(),mode=Complete,isDistinct=false) AS ship#1620|
    +---+----+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+
    |  2|   C|                                                        2321.0|                                                              DOCK|                                                          null|                                                              null|                                                          null|                                                              null|
    |  0|   A|                                                         422.0|                                                              DOCK|                                                          22.0|                                                              PORT|                                                         223.0|                                                              PORT|
    |  1|   B|                                                        3213.0|                                                              PORT|                                                        3213.0|                                                              DOCK|                                                          null|                                                              null|
    +---+----+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+--------------------------------------------------------------+------------------------------------------------------------------+

Is there a way to rename the columns on the fly in the pivot and aggregation?

A simple regular expression should do the trick:

    import re

    def clean_names(df):
        p = re.compile(r"^(\w+?)_([a-z]+)\((\w+)\)(?:\(\))?")
        return df.toDF(*[p.sub(r"\1_\3", c) for c in df.columns])

    pivoted = df_data.groupby(...).pivot(...).agg(...)

    clean_names(pivoted).printSchema()
    ## root
    ##  |-- id: long (nullable = true)
    ##  |-- type: string (nullable = true)
    ##  |-- 201601_cost: double (nullable = true)
    ##  |-- 201601_ship: string (nullable = true)
    ##  |-- 201602_cost: double (nullable = true)
    ##  |-- 201602_ship: string (nullable = true)
    ##  |-- 201603_cost: double (nullable = true)
    ##  |-- 201603_ship: string (nullable = true)

If you want to keep the name of the aggregate function, change the substitution pattern, for example to \1_\2_\3.
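To illustrate that variant in isolation, here is a small standalone sketch (plain Python, no Spark needed) applying the same pattern to a few of the column names from the example above:

```python
import re

# Same pattern as in clean_names; the substitution now also keeps group 2
# (the aggregate function name), so `201601_avg(cost)` -> `201601_avg_cost`.
p = re.compile(r"^(\w+?)_([a-z]+)\((\w+)\)(?:\(\))?")

cols = ["id", "type", "201601_avg(cost)", "201601_first(ship)()"]
renamed = [p.sub(r"\1_\2_\3", c) for c in cols]
print(renamed)  # non-matching names like `id` pass through unchanged
```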

A simple approach is to use an alias after the aggregate function. I start with the df_data Spark DataFrame that you created.

    df_data.groupby(df_data.id, df_data.type).pivot("date").agg(avg("cost").alias("avg_cost"), first("ship").alias("first_ship")).show()

    +---+----+---------------+-----------------+---------------+-----------------+---------------+-----------------+
    | id|type|201601_avg_cost|201601_first_ship|201602_avg_cost|201602_first_ship|201603_avg_cost|201603_first_ship|
    +---+----+---------------+-----------------+---------------+-----------------+---------------+-----------------+
    |  1|   B|         3213.0|             PORT|         3213.0|             DOCK|           null|             null|
    |  2|   C|         2321.0|             DOCK|           null|             null|           null|             null|
    |  0|   A|          422.0|             DOCK|           22.0|             PORT|          223.0|             PORT|
    +---+----+---------------+-----------------+---------------+-----------------+---------------+-----------------+

The column names will have the form "original_column_name_aliased_column_name". In your case, original_column_name is 201601, aliased_column_name is avg_cost, and the resulting column name is 201601_avg_cost (joined by the underscore "_").
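Because every pivoted column follows that "pivot-value underscore alias" pattern, splitting each name on the first underscore recovers both parts; for example, to pick out all the avg_cost columns regardless of pivot value (a small sketch over the names from the example above):

```python
# Pivoted column names have the form "<pivot_value>_<alias>"; splitting on the
# first underscore recovers both halves. Grouping columns like `id` have no
# underscore, so split("_", 1) just returns the name itself.
cols = ["id", "type", "201601_avg_cost", "201601_first_ship", "201602_avg_cost"]

cost_cols = [c for c in cols if c.split("_", 1)[-1] == "avg_cost"]
print(cost_cols)
```

With a real DataFrame, `cols` would simply be `df.columns` and the filtered list could be passed to `df.select(...)`.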

You can alias the aggregations directly:

    pivoted = df_data \
        .groupby(df_data.id, df_data.type) \
        .pivot("date") \
        .agg(
            avg('cost').alias('cost'),
            first("ship").alias('ship')
        )

    pivoted.printSchema()
    ## root
    ##  |-- id: long (nullable = true)
    ##  |-- type: string (nullable = true)
    ##  |-- 201601_cost: double (nullable = true)
    ##  |-- 201601_ship: string (nullable = true)
    ##  |-- 201602_cost: double (nullable = true)
    ##  |-- 201602_ship: string (nullable = true)
    ##  |-- 201603_cost: double (nullable = true)
    ##  |-- 201603_ship: string (nullable = true)

I wrote a quick and easy function to do this. Enjoy! 🙂

    # This function renames a pivot table's ugly default column names
    def rename_pivot_cols(rename_df, remove_agg):
        """Change a Spark pivot table's default ugly column names with ease.
        Option 1: remove_agg = True: `2_sum(sum_amt)` --> `sum_amt_2`.
        Option 2: remove_agg = False: `2_sum(sum_amt)` --> `sum_sum_amt_2`
        """
        for column in rename_df.columns:
            if '(' not in column:
                continue  # grouping columns such as `id` keep their names
            prefix = column.split('_', 1)[0]  # the pivot value, e.g. `2` or `201601`
            if remove_agg:
                # Keep only what is inside the parentheses, then append the pivot value
                start_index = column.find('(')
                end_index = column.find(')')
                rename_df = rename_df.withColumnRenamed(
                    column, column[start_index + 1:end_index] + '_' + prefix)
            else:
                # Keep the aggregate function name as part of the new column name
                new_column = column.replace('(', '_').replace(')', '')
                rename_df = rename_df.withColumnRenamed(
                    column, new_column[len(prefix) + 1:] + '_' + prefix)
        return rename_df
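As a side note, chaining withColumnRenamed adds one projection to the query plan per column; computing all the new names first and applying them in a single toDF call does the same job in one projection. A sketch of the per-name logic (a single-name variant of the same idea, testable without Spark; `pivot_col_name` is a hypothetical helper, not part of any Spark API):

```python
def pivot_col_name(column, remove_agg=True):
    # Computes the new name for one pivoted column:
    # `2_sum(sum_amt)` -> `sum_amt_2` (remove_agg=True)
    # `2_sum(sum_amt)` -> `sum_sum_amt_2` (remove_agg=False)
    if '(' not in column:
        return column  # grouping columns such as `id` are left untouched
    prefix = column.split('_', 1)[0]  # the pivot value
    if remove_agg:
        start, end = column.find('('), column.find(')')
        return column[start + 1:end] + '_' + prefix
    return column.replace('(', '_').replace(')', '')[len(prefix) + 1:] + '_' + prefix

# With a DataFrame this becomes a single projection:
# renamed = rename_df.toDF(*[pivot_col_name(c) for c in rename_df.columns])
```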