Often the data we work with is stored in an observation format, one row per record. Sometimes we would like to turn a categorical feature into columns, and we can use the pivot method for this. In this article, we will learn how to use PySpark's pivot.
The quickest way to get started working with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml, paste the following code, then run docker-compose up. You will then see a link in the console to open up and access a Jupyter notebook.
version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/
We begin by importing SparkSession and creating a Spark session.
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
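As a side note, if you would like to name the application or control where it runs, the builder accepts configuration before getOrCreate(). A minimal sketch, where the app name and master URL are just illustrative values:

# optional: name the app and run locally on all available cores
spark = SparkSession.builder \
    .appName("pivot-demo") \
    .master("local[*]") \
    .getOrCreate()

getOrCreate() returns the existing session if one is already running, so calling it again is safe.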
Next, let's create some fake sales data in a data frame.
rdd = spark.sparkContext.parallelize([
    ("jan", 2019, 86000, 56),
    ("jan", 2019, 96000, 56),
    ("jan", 2019, 76000, 56),
    ("jan", 2020, 81000, 30),
    ("jan", 2020, 71000, 30),
    ("jan", 2021, 90000, 24),
    ("feb", 2019, 99000, 40),
    ("feb", 2019, 89000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 79000, 53),
    ("feb", 2021, 69000, 53),
    ("mar", 2019, 80000, 25),
    ("mar", 2019, 84000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema=["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
| jan|2019| 86000| 56|
| jan|2019| 96000| 56|
| jan|2019| 76000| 56|
| jan|2020| 81000| 30|
| jan|2020| 71000| 30|
| jan|2021| 90000| 24|
| feb|2019| 99000| 40|
| feb|2019| 89000| 40|
| feb|2020| 83000| 36|
| feb|2021| 79000| 53|
| feb|2021| 69000| 53|
| mar|2019| 80000| 25|
| mar|2019| 84000| 25|
| mar|2020| 91000| 50|
+-----+----+-------------+--------------------+
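As a side note, the intermediate RDD is not strictly necessary; createDataFrame also accepts a plain Python list of tuples. A minimal sketch with a couple of the same rows:

# createDataFrame can take a list of tuples directly, no RDD needed
df2 = spark.createDataFrame(
    [("jan", 2019, 86000, 56), ("feb", 2019, 99000, 40)],
    schema=["month", "year", "total_revenue", "unique_products_sold"]
)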
Let's say we would like to aggregate the above data to show the average revenue for each month and year. And, to make comparison easy, we would like to view the years as columns.
df.groupBy('month') \
    .pivot('year') \
    .avg('total_revenue') \
    .fillna(0) \
    .show()
+-----+-------+-------+-------+
|month| 2019| 2020| 2021|
+-----+-------+-------+-------+
| feb|94000.0|83000.0|74000.0|
| mar|82000.0|91000.0| 0.0|
| jan|86000.0|76000.0|90000.0|
+-----+-------+-------+-------+
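Pivot is not limited to a single aggregate. If we also wanted, say, the average number of unique products sold in each cell, we could pass several aggregations through agg. A quick sketch building on the same DataFrame (Spark suffixes each pivoted column with the aggregate expression):

from pyspark.sql import functions as F

multi = df.groupBy('month') \
    .pivot('year') \
    .agg(F.avg('total_revenue'), F.avg('unique_products_sold'))
multi.show()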
One minor note: the above operation can be pretty slow, because Spark must first run an extra job just to discover the distinct values of the pivot column. If you have a large data set, you will start to see very poor response times. To speed this up, you can provide the list of pivot values yourself. Here is a quick example.
years = [2019, 2020, 2021]
pivotDF = df.groupBy("month").pivot("year", years).avg('total_revenue')
pivotDF.show()
+-----+-------+-------+-------+
|month| 2019| 2020| 2021|
+-----+-------+-------+-------+
| feb|94000.0|83000.0|74000.0|
| mar|82000.0|91000.0| null|
| jan|86000.0|76000.0|90000.0|
+-----+-------+-------+-------+
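Note that mar/2021 now shows null rather than 0.0, simply because we skipped the fillna(0) call this time. Also, if you do not know the pivot values ahead of time, you can still collect them yourself in a cheap first pass and then hand them to pivot, which keeps the pivot itself to a single job:

# collect the distinct years once, then pass them to pivot
years = [row['year'] for row in df.select('year').distinct().collect()]
pivotDF = df.groupBy('month').pivot('year', years).avg('total_revenue')
pivotDF.show()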