PySpark Pivot (rows to columns)

09.26.2021

Intro

Often our data is stored in a long, one-row-per-observation format. Sometimes we would like to turn the values of a categorical column into columns of their own. We can use the pivot method for this. In this article, we will learn how to use PySpark's pivot.

Setting Up

The quickest way to get started working with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml, paste in the following code, then run docker-compose up. You will then see a link in the console that opens a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Creating a PySpark DataFrame

We begin by importing SparkSession and creating a Spark session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Next, let's create some fake sales data in a data frame.

rdd = spark.sparkContext.parallelize([
    ("jan", 2019, 86000, 56),
    ("jan", 2019, 96000, 56),
    ("jan", 2019, 76000, 56),
    ("jan", 2020, 81000, 30),
    ("jan", 2020, 71000, 30),
    ("jan", 2021, 90000, 24),

    ("feb", 2019, 99000, 40),
    ("feb", 2019, 89000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 79000, 53),
    ("feb", 2021, 69000, 53),

    ("mar", 2019, 80000, 25),
    ("mar", 2019, 84000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema=["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2019|        96000|                  56|
|  jan|2019|        76000|                  56|
|  jan|2020|        81000|                  30|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2019|        89000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        79000|                  53|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2019|        84000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+

Let's say we would like to aggregate the above data to show average revenue, with the years laid out as columns so each month can be compared side by side. In the pivot below, groupBy supplies the row labels (month), pivot supplies the column labels (year), and avg fills each cell.

df.groupBy('month')\
    .pivot('year')\
    .avg('total_revenue')\
    .fillna(0).show()
+-----+-------+-------+-------+
|month|   2019|   2020|   2021|
+-----+-------+-------+-------+
|  feb|94000.0|83000.0|74000.0|
|  mar|82000.0|91000.0|    0.0|
|  jan|86000.0|76000.0|90000.0|
+-----+-------+-------+-------+
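
You are not limited to one aggregation per cell either. If, for instance, you also want to compare the average number of unique products sold, you can pass several aggregate expressions through agg(). The following is just a sketch against the same data frame; the alias names are my own.

from pyspark.sql import functions as F

df.groupBy('month')\
    .pivot('year')\
    .agg(
        F.avg('total_revenue').alias('avg_revenue'),
        F.avg('unique_products_sold').alias('avg_products')
    )\
    .fillna(0).show()

Each year now contributes two columns, named like 2019_avg_revenue and 2019_avg_products.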

One minor note: the above operation is pretty slow. When no pivot values are supplied, Spark first runs an extra job to discover the distinct values of the pivot column, so on a large data set response times degrade quickly. To speed this up, you can provide the list of pivot values yourself. Here is a quick example.

# Provide the pivot values explicitly so Spark can skip the extra distinct-value job.
years = [2019, 2020, 2021]
pivotDF = df.groupBy("month").pivot("year", years).avg('total_revenue')
pivotDF.show()
+-----+-------+-------+-------+
|month|   2019|   2020|   2021|
+-----+-------+-------+-------+
|  feb|94000.0|83000.0|74000.0|
|  mar|82000.0|91000.0|   null|
|  jan|86000.0|76000.0|90000.0|
+-----+-------+-------+-------+
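
If you do not know the pivot values up front, you can collect them yourself and then pass them in. Under the hood this is roughly what pivot does when no list is given, so it is not really faster, but it makes the resulting columns explicit and lets you control their order or drop values you do not care about. A minimal sketch:

# Gather the distinct years once, then hand the sorted list to pivot.
years = [row.year for row in df.select('year').distinct().collect()]

pivotDF = df.groupBy('month').pivot('year', sorted(years)).avg('total_revenue')
pivotDF.show()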