Nueva columna Dataframe como función genérica de otras filas (chispa)

¿Cómo creo eficientemente una nueva columna en un DataFrame que es una función de otras filas en spark ?

Esta es una implementación de spark del problema que describí aquí :

 from nltk.metrics.distance import edit_distance as edit_dist from pyspark.sql.functions import col, udf from pyspark.sql.types import IntegerType d = { 'id': [1, 2, 3, 4, 5, 6], 'word': ['cat', 'hat', 'hag', 'hog', 'dog', 'elephant'] } spark_df = sqlCtx.createDataFrame(pd.DataFrame(d)) words_list = list(spark_df.select('word').collect()) get_n_similar = udf( lambda word: len( [ w for w in words_list if (w['word'] != word) and (edit_dist(w['word'], word) < 2) ] ), IntegerType() ) spark_df.withColumn('n_similar', get_n_similar(col('word'))).show() 

Salida:

 +---+--------+---------+ |id |word |n_similar| +---+--------+---------+ |1 |cat |1 | |2 |hat |2 | |3 |hag |2 | |4 |hog |2 | |5 |dog |1 | |6 |elephant|0 | +---+--------+---------+ 

El problema aquí es que no conozco una manera de decirle a spark que compare la fila actual con las otras filas en el Dataframe sin primero recostackr los valores en una list . ¿Hay una manera de aplicar una función genérica de otras filas sin llamar a collect ?

El problema aquí es que no conozco una manera de decirle a spark que compare la fila actual con las otras filas en el Dataframe sin primero recostackr los valores en una lista.

UDF no es una opción aquí (no puede hacer referencia a DataFrame distribuido en udf ) La traducción directa de su lógica es producto cartesiano y agregado:

 from pyspark.sql.functions import levenshtein, col result = (spark_df.alias("l") .crossJoin(spark_df.alias("r")) .where(levenshtein("l.word", "r.word") < 2) .where(col("l.word") != col("r.word")) .groupBy("l.id", "l.word") .count()) 

pero en la práctica, debe intentar hacer algo más eficiente: una concordancia eficiente de cadenas en Apache Spark

Dependiendo del problema, debe tratar de encontrar otras aproximaciones para evitar el producto cartesiano completo.

Si desea mantener los datos sin coincidencias, puede omitir un filtro:

 (spark_df.alias("l") .crossJoin(spark_df.alias("r")) .where(levenshtein("l.word", "r.word") < 2) .groupBy("l.id", "l.word") .count() .withColumn("count", col("count") - 1)) 

o (más lento, pero más genérico), únase con la referencia:

 (spark_df .select("id", "word") .distinct() .join(result, ["id", "word"], "left") .na.fill(0))