PySpark WithColumn

10.04.2021

Intro

The withColumn method allows us to add columns, change their types, update their values, and more. It is one of the most commonly used methods in PySpark. In this article, we will learn how to use PySpark's withColumn method.

Setting Up

The quickest way to get started working with PySpark is to use the following Docker Compose file. Simply create a docker-compose.yml file, paste the following code into it, then run docker-compose up. You will then see a link in the console that you can open to access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Using PySpark withColumn

Let's start by creating a Spark Session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
WARNING: An illegal reflective access operation has occurred
WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/spark-3.1.2-bin-hadoop3.2/jars/spark-unsafe_2.12-3.1.2.jar) to constructor java.nio.DirectByteBuffer(long,int)
WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform
WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations
WARNING: All illegal access operations will be denied in a future release
21/10/04 05:38:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
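
These warnings are harmless. If the console output is too noisy, the log level can be raised on the Spark context, as the log message above suggests.

spark.sparkContext.setLogLevel("ERROR")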

Now, let's create a dataframe to work with.

rdd = spark.sparkContext.parallelize([    
    ("jan", 2019, 86000,56),
    ("jan", 2020, 71000,30),
    ("jan", 2021, 90000,24),
    
    ("feb", 2019, 99000,40),
    ("feb", 2020, 83000,36),
    ("feb", 2021, 69000,53),
    
    ("mar", 2019, 80000,25),
    ("mar", 2020, 91000,50)
])
df = spark.createDataFrame(rdd, schema = ["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+

Creating a new Column

To create a new column, we can use the withColumn method. First we pass the name of the new column, then we pass lit, which takes a literal value and turns it into a PySpark Column.

from pyspark.sql.functions import lit

df.withColumn("Source", lit("Web")).show()
+-----+----+-------------+--------------------+------+
|month|year|total_revenue|unique_products_sold|Source|
+-----+----+-------------+--------------------+------+
|  jan|2019|        86000|                  56|   Web|
|  jan|2020|        71000|                  30|   Web|
|  jan|2021|        90000|                  24|   Web|
|  feb|2019|        99000|                  40|   Web|
|  feb|2020|        83000|                  36|   Web|
|  feb|2021|        69000|                  53|   Web|
|  mar|2019|        80000|                  25|   Web|
|  mar|2020|        91000|                  50|   Web|
+-----+----+-------------+--------------------+------+

We can also create a new column from an existing one by using the col function to reference the existing column.

from pyspark.sql.functions import col

df.withColumn("AdjustedRevenue", col("total_revenue") / 1000).show()
+-----+----+-------------+--------------------+---------------+
|month|year|total_revenue|unique_products_sold|AdjustedRevenue|
+-----+----+-------------+--------------------+---------------+
|  jan|2019|        86000|                  56|           86.0|
|  jan|2020|        71000|                  30|           71.0|
|  jan|2021|        90000|                  24|           90.0|
|  feb|2019|        99000|                  40|           99.0|
|  feb|2020|        83000|                  36|           83.0|
|  feb|2021|        69000|                  53|           69.0|
|  mar|2019|        80000|                  25|           80.0|
|  mar|2020|        91000|                  50|           91.0|
+-----+----+-------------+--------------------+---------------+
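
It is worth noting that withColumn does not modify the existing DataFrame; it returns a new one, so the result has to be assigned to a variable (or the calls chained) if we want to keep it. A small sketch to illustrate, reusing the columns above:

df_extended = df.withColumn("Source", lit("Web")) \
    .withColumn("AdjustedRevenue", col("total_revenue") / 1000)

# df still has the original four columns; df_extended has the two new ones as well
print(df.columns)
print(df_extended.columns)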

Changing the Type of a Column

To change the type of a column, we can use the col function to access the column and then call the cast method on the resulting Column to change its type.

from pyspark.sql.functions import col

df.withColumn("total_revenue", col("total_revenue").cast("Integer")).show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+
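
The show output looks the same because only the type changed, not the values. A quick way to confirm the cast is to print the schema before and after:

# the numeric columns are inferred as longs when the DataFrame is created;
# after the cast, total_revenue is reported as an integer instead
df.printSchema()
df.withColumn("total_revenue", col("total_revenue").cast("Integer")).printSchema()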

Update the Value of a Column

Similar to how we added a new column based on an existing one, we can overwrite an existing column to update its values.

from pyspark.sql.functions import col

df.withColumn("total_revenue", col("total_revenue") * 100).show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|      8600000|                  56|
|  jan|2020|      7100000|                  30|
|  jan|2021|      9000000|                  24|
|  feb|2019|      9900000|                  40|
|  feb|2020|      8300000|                  36|
|  feb|2021|      6900000|                  53|
|  mar|2019|      8000000|                  25|
|  mar|2020|      9100000|                  50|
+-----+----+-------------+--------------------+
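
Updates do not have to apply to every row. For a conditional update, withColumn can be combined with when and otherwise; a minimal sketch, assuming we only want to scale the 2019 figures:

from pyspark.sql.functions import when, col

# multiply revenue by 100 only for 2019; other years keep their original value
df.withColumn(
    "total_revenue",
    when(col("year") == 2019, col("total_revenue") * 100)
    .otherwise(col("total_revenue"))
).show()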

Rename a Column

If we would like to rename a column, we can use a similar method: withColumnRenamed.

df.withColumnRenamed("total_revenue", "Total Revenue").show()
+-----+----+-------------+--------------------+
|month|year|Total Revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+
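
Note that withColumnRenamed only renames one column at a time. To rename several columns, the calls can simply be chained, for example:

df.withColumnRenamed("total_revenue", "Total Revenue") \
  .withColumnRenamed("unique_products_sold", "Unique Products Sold") \
  .show()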