PySpark Collect

10.03.2021

Intro

The dataframe collect method returns the rows of a dataframe as a list of PySpark Row objects. It is useful for retrieving the data in small dataframes so that you can inspect and iterate over it. It is not a good fit for large datasets, because all of the data is pulled into the driver's memory and will likely cause an out of memory error. In this article, we will learn how to use PySpark Collect.

Side note: if you only need a handful of Rows from a large dataset, you can use the take and limit methods instead. Those are covered in separate articles.
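As a quick sketch (using the df dataframe we build later in this article, so this is purely for illustration):

df.take(3)             # returns at most three Row objects to the driver
df.limit(3).collect()  # trims the dataframe first, then collects the smaller result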

Setting Up

The quickest way to get started working with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml, paste the following code, then run docker-compose up. You will then see a link in the console to open up and access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Using the PySpark Collect Method

Let's start by creating a Spark Session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
WARNING: An illegal reflective access operation has occurred
WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/spark-3.1.2-bin-hadoop3.2/jars/spark-unsafe_2.12-3.1.2.jar) to constructor java.nio.DirectByteBuffer(long,int)
WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform
WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations
WARNING: All illegal access operations will be denied in a future release
21/10/04 01:24:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
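
If you want more control over the session, the builder also accepts an application name and configuration options before getOrCreate is called. A minimal sketch (the app name and config value here are purely illustrative and not required for collect):

spark = (
    SparkSession.builder
    # purely illustrative settings
    .appName("pyspark-collect-demo")
    .config("spark.sql.shuffle.partitions", "4")
    .getOrCreate()
)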

Now, let's create a dataframe to work with.

rdd = spark.sparkContext.parallelize([
    ("jan", 2019, 86000, 56),
    ("jan", 2020, 71000, 30),
    ("jan", 2021, 90000, 24),

    ("feb", 2019, 99000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 69000, 53),

    ("mar", 2019, 80000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema=["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+
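
As a side note, you do not need the intermediate RDD; createDataFrame also accepts a plain Python list of tuples directly. A minimal sketch producing the same kind of dataframe (df2 is just an illustrative name):

data = [
    ("jan", 2019, 86000, 56),
    ("feb", 2019, 99000, 40),
]
df2 = spark.createDataFrame(data, schema=["month", "year", "total_revenue", "unique_products_sold"])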

Since our dataset is small, we can use the collect method on our dataframe. Below you can see the returned list of Row objects, followed by a loop that iterates over them.

df.collect()
[Row(month='jan', year=2019, total_revenue=86000, unique_products_sold=56),
 Row(month='jan', year=2020, total_revenue=71000, unique_products_sold=30),
 Row(month='jan', year=2021, total_revenue=90000, unique_products_sold=24),
 Row(month='feb', year=2019, total_revenue=99000, unique_products_sold=40),
 Row(month='feb', year=2020, total_revenue=83000, unique_products_sold=36),
 Row(month='feb', year=2021, total_revenue=69000, unique_products_sold=53),
 Row(month='mar', year=2019, total_revenue=80000, unique_products_sold=25),
 Row(month='mar', year=2020, total_revenue=91000, unique_products_sold=50)]
for row in df.collect():
    print(row['month'])
jan
jan
jan
feb
feb
feb
mar
mar
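
Each Row also has an asDict method, so once the data has been collected you can convert it into plain Python dictionaries. A short sketch (records is just an illustrative name):

records = [row.asDict() for row in df.collect()]
records[0]["total_revenue"]  # 86000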

We can also index the list returned by collect like a matrix: the first index selects a row and the second selects a column. Here are some examples.

rows = df.collect()

# Get first row
rows[0]
Row(month='jan', year=2019, total_revenue=86000, unique_products_sold=56)
# second value of the first row
rows[0][1]
2019
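
Besides positional indexing, Row values can also be read by column name or by attribute access:

rows[0]["year"]  # 2019
rows[0].year     # 2019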