PySpark Select

10.02.2021

Intro

Selecting columns is one of the most common operations when working with dataframes. We can select by position or by name, and we can select a single column or several at once. In this article, we will learn how to use PySpark's select method.

Setting Up

The quickest way to get started with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml, paste the following code, then run docker-compose up. The console will print a link you can open to access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Selecting Columns By Name

Let's start by creating a Spark Session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

Now, let's create a dataframe to work with.

rdd = spark.sparkContext.parallelize([
    ("jan", 2019, 86000, 56),
    ("jan", 2020, 71000, 30),
    ("jan", 2021, 90000, 24),

    ("feb", 2019, 99000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 69000, 53),

    ("mar", 2019, 80000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema=["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+

There are many ways to select columns with the select method. We can pass column names as strings, attribute references on the dataframe, or bracket lookups on the dataframe. Here are some examples.

df.select("month", "year").show()
+-----+----+
|month|year|
+-----+----+
|  jan|2019|
|  jan|2020|
|  jan|2021|
|  feb|2019|
|  feb|2020|
|  feb|2021|
|  mar|2019|
|  mar|2020|
+-----+----+
df.select(df.month, df.year).show()
+-----+----+
|month|year|
+-----+----+
|  jan|2019|
|  jan|2020|
|  jan|2021|
|  feb|2019|
|  feb|2020|
|  feb|2021|
|  mar|2019|
|  mar|2020|
+-----+----+
df.select(df["month"], df["year"]).show()
+-----+----+
|month|year|
+-----+----+
|  jan|2019|
|  jan|2020|
|  jan|2021|
|  feb|2019|
|  feb|2020|
|  feb|2021|
|  mar|2019|
|  mar|2020|
+-----+----+

We can also use the col function from the pyspark.sql.functions package.

from pyspark.sql.functions import col
df.select(col("month"), col("year")).show()
+-----+----+
|month|year|
+-----+----+
|  jan|2019|
|  jan|2020|
|  jan|2021|
|  feb|2019|
|  feb|2020|
|  feb|2021|
|  mar|2019|
|  mar|2020|
+-----+----+

One quick note: select is a transformation, which means it always returns a new dataframe rather than modifying the original. Like all transformations, it is evaluated lazily; nothing is computed until an action such as show is called.

df.select("month", "year")
DataFrame[month: string, year: bigint]

If we want to select all columns, we can pass the "*" string to select.

df.select("*").show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+

Selecting Columns By Index

Our next option is to select columns by index. Since df.columns is a plain Python list, we can use standard slice indexing on it and pass the result to select.

df.select(df.columns[2:4]).show()
+-------------+--------------------+
|total_revenue|unique_products_sold|
+-------------+--------------------+
|        86000|                  56|
|        71000|                  30|
|        90000|                  24|
|        99000|                  40|
|        83000|                  36|
|        69000|                  53|
|        80000|                  25|
|        91000|                  50|
+-------------+--------------------+
df.select(df.columns[:2]).show()
+-----+----+
|month|year|
+-----+----+
|  jan|2019|
|  jan|2020|
|  jan|2021|
|  feb|2019|
|  feb|2020|
|  feb|2021|
|  mar|2019|
|  mar|2020|
+-----+----+
df.select(df.columns[:-1]).show()
+-----+----+-------------+
|month|year|total_revenue|
+-----+----+-------------+
|  jan|2019|        86000|
|  jan|2020|        71000|
|  jan|2021|        90000|
|  feb|2019|        99000|
|  feb|2020|        83000|
|  feb|2021|        69000|
|  mar|2019|        80000|
|  mar|2020|        91000|
+-----+----+-------------+