PySpark DataFrame Select, Filter, Where

09.23.2021

Intro

Filtering and subsetting data is a common task in data science. Spark lets us perform operations similar to those in SQL and pandas, but at scale. In this article, we will learn how to use PySpark DataFrames to select and filter data.

Setting Up

The quickest way to start working with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml file, paste in the configuration below, then run docker-compose up. The console will print a link that opens a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Creating a PySpark Data Frame

We begin by importing SparkSession and creating a Spark session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Now, let's create a data frame to work with.

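from datetime import datetime

# Build an RDD of (amount, month, date) tuples, then convert it to a DataFrame.
rdd = spark.sparkContext.parallelize([
    (60000, 'jan', datetime(2000, 1, 1, 12, 0)),
    (40000, 'feb', datetime(2000, 2, 1, 12, 0)),
    (50000, 'mar', datetime(2000, 3, 1, 12, 0))
])
# Passing a list of names as the schema lets Spark infer the column types.
df = spark.createDataFrame(rdd, schema=["amount", "month", "date"])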

Selecting

We can use the select method to tell PySpark which columns to keep. It works much like SELECT in SQL: we pass the names of the columns we would like to keep, here as a list.

df.select(['month', 'amount']).show()
+-----+------+
|month|amount|
+-----+------+
|  jan| 60000|
|  feb| 40000|
|  mar| 50000|
+-----+------+

Filtering

Next, let's look at the filter method. To filter a data frame, we call filter and pass a condition; only the rows for which the condition is true are kept. If you are familiar with pandas, this works much like boolean indexing.

Notice that we chain filters together to further filter the dataset.

df.filter(df['amount'] > 4000).filter(df['month'] != 'jan').show()
+------+-----+-------------------+
|amount|month|               date|
+------+-----+-------------------+
| 40000|  feb|2000-02-01 12:00:00|
| 50000|  mar|2000-03-01 12:00:00|
+------+-----+-------------------+

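We can get the same result with a single filter call by combining conditions with logical operators. For example, & expresses an "and" query. Each condition must be wrapped in parentheses, because in Python & binds more tightly than comparison operators such as > and !=.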

df.filter((df['amount'] > 4000) & (df['month'] != 'jan')).show()
+------+-----+-------------------+
|amount|month|               date|
+------+-----+-------------------+
| 40000|  feb|2000-02-01 12:00:00|
| 50000|  mar|2000-03-01 12:00:00|
+------+-----+-------------------+

Similar to the above, we can use | for "or" queries.

df.filter((df['amount'] < 50000) | (df['month'] != 'jan')).show()
+------+-----+-------------------+
|amount|month|               date|
+------+-----+-------------------+
| 40000|  feb|2000-02-01 12:00:00|
| 50000|  mar|2000-03-01 12:00:00|
+------+-----+-------------------+

Where

The where method is an alias for filter, so anything you can do with filter you can also do with where. It exists for readability, for users who prefer the SQL-style WHERE keyword.

df.where((df['amount'] < 50000) | (df['month'] != 'jan')).show()
+------+-----+-------------------+
|amount|month|               date|
+------+-----+-------------------+
| 40000|  feb|2000-02-01 12:00:00|
| 50000|  mar|2000-03-01 12:00:00|
+------+-----+-------------------+