PySpark DataFrame Aggregations

09.24.2021

Intro

Aggregation is one of the main features you will use in Spark. It helps with exploratory data analysis and with building dashboards that scale. In this article, we will learn how to run aggregations in PySpark.

Setting Up

The quickest way to get started with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml, paste in the following code, then run docker-compose up. The console will print a link you can open to access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Creating a PySpark Data Frame

We begin by creating a Spark session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Next, let's create some fake sales data in a data frame.

rdd = spark.sparkContext.parallelize([
    ("jan", 2019, 86000, 56),
    ("jan", 2020, 81000, 30),
    ("jan", 2021, 90000, 24),
    ("feb", 2019, 99000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 79000, 53),
    ("mar", 2019, 80000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema=["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        81000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        79000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+

To run aggregates, we can use the groupBy method and then call a summary function on the grouped data. For example, we can group our sales data by month, then call count to get the number of rows per month. Below are examples of the built-in aggregate functions: count, mean, max, min, and sum.

df.groupBy('month').count().show()
+-----+-----+
|month|count|
+-----+-----+
|  feb|    3|
|  mar|    2|
|  jan|    3|
+-----+-----+
df.groupBy('month').mean().show()
+-----+---------+------------------+-------------------------+
|month|avg(year)|avg(total_revenue)|avg(unique_products_sold)|
+-----+---------+------------------+-------------------------+
|  feb|   2020.0|           87000.0|                     43.0|
|  mar|   2019.5|           85500.0|                     37.5|
|  jan|   2020.0| 85666.66666666667|       36.666666666666664|
+-----+---------+------------------+-------------------------+
df.groupBy('month').max().show()
+-----+---------+------------------+-------------------------+
|month|max(year)|max(total_revenue)|max(unique_products_sold)|
+-----+---------+------------------+-------------------------+
|  feb|     2021|             99000|                       53|
|  mar|     2020|             91000|                       50|
|  jan|     2021|             90000|                       56|
+-----+---------+------------------+-------------------------+
df.groupBy('month').min().show()
+-----+---------+------------------+-------------------------+
|month|min(year)|min(total_revenue)|min(unique_products_sold)|
+-----+---------+------------------+-------------------------+
|  feb|     2019|             79000|                       36|
|  mar|     2019|             80000|                       25|
|  jan|     2019|             81000|                       24|
+-----+---------+------------------+-------------------------+
df.groupBy('month').sum().show()
+-----+---------+------------------+-------------------------+
|month|sum(year)|sum(total_revenue)|sum(unique_products_sold)|
+-----+---------+------------------+-------------------------+
|  feb|     6060|            261000|                      129|
|  mar|     4039|            171000|                       75|
|  jan|     6060|            257000|                      110|
+-----+---------+------------------+-------------------------+

We can also group by multiple columns by passing several column names to groupBy.

df.groupBy('month', 'year').count().show()
+-----+----+-----+
|month|year|count|
+-----+----+-----+
|  feb|2020|    1|
|  feb|2021|    1|
|  jan|2021|    1|
|  mar|2019|    1|
|  feb|2019|    1|
|  jan|2019|    1|
|  mar|2020|    1|
|  jan|2020|    1|
+-----+----+-----+

Also, note that these aggregations return DataFrames, so you can chain the aggregation with other DataFrame methods, like filter.

df.groupBy('month').sum().filter("sum(total_revenue) > 200000").show()
+-----+---------+------------------+-------------------------+
|month|sum(year)|sum(total_revenue)|sum(unique_products_sold)|
+-----+---------+------------------+-------------------------+
|  feb|     6060|            261000|                      129|
|  jan|     6060|            257000|                      110|
+-----+---------+------------------+-------------------------+