Fusion de plusieurs trames de données par ligne dans PySpark


21

J'ai 10 trames de données pyspark.sql.dataframe.DataFrame, obtenues à partir randomSplitde (td1, td2, td3, td4, td5, td6, td7, td8, td9, td10) = td.randomSplit([.1, .1, .1, .1, .1, .1, .1, .1, .1, .1], seed = 100)maintenant , je veux rejoindre 9 td« s en une seule trame de données, comment dois - je faire?

J'ai déjà essayé avec unionAll, mais cette fonction n'accepte que deux arguments.

td1_2 = td1.unionAll(td2) 
# this is working fine

td1_2_3 = td1.unionAll(td2, td3) 
# error TypeError: unionAll() takes exactly 2 arguments (3 given)

Existe-t-il un moyen de combiner plus de deux trames de données en ligne?

Le but de cela est que je fais la validation croisée 10 fois manuellement sans utiliser la CrossValidatorméthode PySpark , donc en prenant 9 en formation et 1 en données de test, puis je le répéterai pour d'autres combinaisons.


1
Cela ne répond pas directement à la question, mais ici, je donne une suggestion pour améliorer la méthode de dénomination afin qu'au final, nous n'ayons pas à taper, par exemple: [td1, td2, td3, td4, td5, td6, td7 , td8, td9, td10]. Imaginez faire cela pour un CV 100 fois. Voici ce que je vais faire: portions = [0.1] * 10 cv = df7.randomSplit (portions) folds = list (range (10)) for i in range (10): test_data = cv [i] fold_no_i = folds [: i] + replis [i + 1:] train_data = cv [fold_no_i [0]] pour j dans fold_no_i [1:]: train_data = train_data.union (cv [j])
ngoc thoag

Réponses:


37

Volé à: /programming/33743978/spark-union-of-multiple-rdds

En dehors de l'enchaînement des syndicats, c'est la seule façon de le faire pour les DataFrames.

from functools import reduce  # For Python 3.x
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionAll, dfs)

unionAll(td2, td3, td4, td5, td6, td7, td8, td9, td10)

Ce qui se passe, c'est qu'il prend tous les objets que vous avez passés en tant que paramètres et les réduit à l'aide de unionAll (cette réduction provient de Python, pas de la réduction Spark bien qu'ils fonctionnent de manière similaire), ce qui le réduit finalement à un DataFrame.

Si au lieu de DataFrames ce sont des RDD normaux, vous pouvez passer une liste d'entre eux à la fonction d'union de votre SparkContext

EDIT: Pour votre objectif, je propose une méthode différente, car vous devrez répéter cette union 10 fois pour vos différents plis pour la validation croisée, j'ajouterais des étiquettes pour lesquelles appartient une ligne et je filtrerais simplement votre DataFrame pour chaque pli en fonction de l'étiquette


(+1) Une belle solution de contournement. Cependant, il doit y avoir une fonction qui permet la concaténation de plusieurs trames de données. Serait assez pratique!
Dawny33

Je ne suis pas en désaccord avec cela
Jan van der Vegt

@JanvanderVegt Merci, cela fonctionne et l'idée d'ajouter des étiquettes pour filtrer l'ensemble de données de formation et de test, je l'ai déjà fait. Merci beaucoup pour votre aide.
krishna Prasad

@Jan van der Vegt Pouvez-vous appliquer la même logique pour Join et répondre à cette question
GeorgeOfTheRF


6

Parfois, lorsque les cadres de données à combiner n'ont pas le même ordre de colonnes, il est préférable de df2.select (df1.columns) afin de s'assurer que les deux df ont le même ordre de colonnes avant l'union.

import functools 

def unionAll(dfs):
    return functools.reduce(lambda df1,df2: df1.union(df2.select(df1.columns)), dfs) 

Exemple:

df1 = spark.createDataFrame([[1,1],[2,2]],['a','b'])
# different column order. 
df2 = spark.createDataFrame([[3,333],[4,444]],['b','a']) 
df3 = spark.createDataFrame([555,5],[666,6]],['b','a']) 

unioned_df = unionAll([df1, df2, df3])
unioned_df.show() 

entrez la description de l'image ici

sinon, il générerait le résultat ci-dessous à la place.

from functools import reduce  # For Python 3.x
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionAll, dfs) 

unionAll(*[df1, df2, df3]).show()

entrez la description de l'image ici


2

Que diriez-vous d'utiliser la récursivité?

def union_all(dfs):
    if len(dfs) > 1:
        return dfs[0].unionAll(union_all(dfs[1:]))
    else:
        return dfs[0]

td = union_all([td1, td2, td3, td4, td5, td6, td7, td8, td9, td10])
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.