Comment sélectionner la première ligne de chaque groupe?


143

J'ai un DataFrame généré comme suit:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

Les résultats ressemblent à:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Comme vous pouvez le voir, le DataFrame est trié par Hourordre croissant, puis par TotalValueordre décroissant.

Je souhaite sélectionner la première ligne de chaque groupe, c'est-à-dire

  • dans le groupe Heure == 0 sélectionnez (0, cat26,30.9)
  • dans le groupe Heure == 1 sélectionnez (1, cat67,28.5)
  • dans le groupe Heure == 2 sélectionnez (2, cat56,39.6)
  • etc

Ainsi, la sortie souhaitée serait:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Il peut être utile de pouvoir également sélectionner les N premières lignes de chaque groupe.

Toute aide est grandement appréciée.

Réponses:


232

Fonctions de la fenêtre :

Quelque chose comme ça devrait faire l'affaire:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Cette méthode sera inefficace en cas de biais significatif des données.

Agrégation SQL simple suivie dejoin :

Vous pouvez également vous joindre avec un bloc de données agrégé:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Il conservera les valeurs en double (s'il y a plus d'une catégorie par heure avec la même valeur totale). Vous pouvez les supprimer comme suit:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Utilisation de la commande surstructs :

Astuce, bien que pas très bien testée, qui ne nécessite ni jointures ni fonctions de fenêtre:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Avec l'API DataSet (Spark 1.6+, 2.0+):

Spark 1.6 :

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 ou version ultérieure :

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Les deux dernières méthodes peuvent tirer parti de la combinaison côté carte et ne nécessitent pas de mélange complet, donc la plupart du temps, elles devraient présenter de meilleures performances par rapport aux fonctions de fenêtre et aux jointures. Ces cannes peuvent également être utilisées avec le streaming structuré en completedmode sortie.

N'utilisez pas :

df.orderBy(...).groupBy(...).agg(first(...), ...)

Cela peut sembler fonctionner (en particulier dans le localmode) mais ce n'est pas fiable (voir SPARK-16207 , merci à Tzach Zohar pour avoir lié le problème JIRA pertinent et SPARK-30335 ).

La même remarque s'applique à

df.orderBy(...).dropDuplicates(...)

qui utilise en interne un plan d'exécution équivalent.


3
Il semble que depuis Spark 1.6, il s'agit de row_number () au lieu de rowNumber
Adam Szałucha

À propos de Ne pas utiliser df.orderBy (...). GropBy (...). Dans quelles circonstances pouvons-nous nous fier à orderBy (...)? ou si nous ne pouvons pas être sûrs si orderBy () va donner le bon résultat, quelles alternatives avons-nous?
Ignacio Alorre du

Je suis peut-être en train d' oublier quelque chose, mais en général, il est recommandé d' éviter groupByKey , à la place, ReduceByKey devrait être utilisé. De plus, vous économiserez une ligne.
Thomas

3
@Thomas évitant groupBy / groupByKey est juste lorsqu'il s'agit de RDD, vous remarquerez que l'API Dataset n'a même pas de fonction reductionByKey.
suote


16

Pour Spark 2.0.2 avec regroupement par plusieurs colonnes:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

8

C'est exactement la même chose que la réponse de zero323 , mais en mode requête SQL.

En supposant que le dataframe est créé et enregistré comme

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Fonction de fenêtre:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Agrégation SQL simple suivie d'une jointure:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Utilisation de la commande sur les structures:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

DataSets façon et ne font pas les mêmes que dans la réponse originale


2

Le motif est groupé par clés => faire quelque chose pour chaque groupe, par exemple réduire => retourner au dataframe

Je pensais que l'abstraction Dataframe était un peu encombrante dans ce cas, j'ai donc utilisé la fonctionnalité RDD

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)

1

La solution ci-dessous ne fait qu'un groupBy et extrait les lignes de votre dataframe qui contiennent la maxValue en un seul coup. Pas besoin de jointures supplémentaires ou de Windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}

Mais il mélange tout d'abord. Ce n'est guère une amélioration (peut-être pas pire que les fonctions de fenêtre, selon les données).
Alper t. Turker

vous avez un groupe en premier lieu, qui déclenchera un mélange. Ce n'est pas pire que la fonction de fenêtre car dans une fonction de fenêtre, elle va évaluer la fenêtre pour chaque ligne unique dans le dataframe.
elghoto

1

Une bonne façon de faire cela avec l'API dataframe consiste à utiliser la logique argmax comme ceci

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+

0

Ici, vous pouvez faire comme ça -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show

-2

Nous pouvons utiliser la fonction de fenêtre rank () (où vous choisiriez le rang = 1) rank ajoute simplement un nombre pour chaque ligne d'un groupe (dans ce cas, ce serait l'heure)

voici un exemple. (de https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
En utilisant notre site, vous reconnaissez avoir lu et compris notre politique liée aux cookies et notre politique de confidentialité.
Licensed under cc by-sa 3.0 with attribution required.