PySpark FlatMap

10.15.2021

Intro

The PySpark flatMap method allows us to iterate over rows in an RDD and transform each item. It is similar to the map method, but flattens each returned iterable, producing a single flat list of items instead of a list of new objects. This is helpful when you want a simple list of values from your rows to iterate over or encode. In this article, we will learn how to use PySpark flatMap.

Setting Up

The quickest way to get started working with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml, paste the following code, then run docker-compose up. You will then see a link in the console to open up and access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Creating the Data

Let's start by creating a Spark Session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Now, let's create a DataFrame to work with.

rdd = spark.sparkContext.parallelize([    
    ("jan", 2019, 86000, 56),
    ("jan", 2020, 71000, 30),
    ("jan", 2021, 90000, 24),
    
    ("feb", 2019, 99000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 69000, 53),
    
    ("mar", 2019, 80000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema = ["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+

Using FlatMap

To use flatMap, we first have to access the .rdd property of our DataFrame. The flatMap method applies our function to each row and flattens every returned iterable into a single sequence of items. In the example below, we return a modified tuple for each row, and flatMap merges these tuples into one flat list.

rdd = df.rdd.flatMap(lambda x: (x[0], x[1], x[2], x[3] * 100))  

rdd.collect()
['jan',
 2019,
 86000,
 5600,
 'jan',
 2020,
 71000,
 3000,
 'jan',
 2021,
 90000,
 2400,
 'feb',
 2019,
 99000,
 4000,
 'feb',
 2020,
 83000,
 3600,
 'feb',
 2021,
 69000,
 5300,
 'mar',
 2019,
 80000,
 2500,
 'mar',
 2020,
 91000,
 5000]

If we want, we can also reference columns by name inside the lambda.

rdd = df.rdd.flatMap(lambda x: (x.month, x.year))  

rdd.collect()
['jan',
 2019,
 'jan',
 2020,
 'jan',
 2021,
 'feb',
 2019,
 'feb',
 2020,
 'feb',
 2021,
 'mar',
 2019,
 'mar',
 2020]

Let's compare this to the basic map function. You can see here we get a list of tuples instead of a flat list.

rdd = df.rdd.map(lambda x: (x.month, x.year))  

rdd.collect()
[('jan', 2019),
 ('jan', 2020),
 ('jan', 2021),
 ('feb', 2019),
 ('feb', 2020),
 ('feb', 2021),
 ('mar', 2019),
 ('mar', 2020)]

One helpful parameter when working with RDDs is preservesPartitioning. Setting it to True tells Spark that your function does not change the keys used to partition the data, so the RDD's existing partitioner remains valid and downstream operations can avoid an unnecessary shuffle.

rdd = df.rdd.flatMap(lambda x: (x.month, x.year), preservesPartitioning=True)  

rdd.collect()
['jan',
 2019,
 'jan',
 2020,
 'jan',
 2021,
 'feb',
 2019,
 'feb',
 2020,
 'feb',
 2021,
 'mar',
 2019,
 'mar',
 2020]