PySpark UDF (User Defined Function)

10.13.2021

Intro

Similar to most SQL databases such as Postgres, MySQL, and SQL Server, PySpark allows for user defined functions (UDFs) on its scalable platform. These functions can be applied to dataframes or registered so they can be used in SQL queries. In this article, we will learn how to use PySpark UDFs.

Setting Up

The quickest way to get started working with Python is to use the following Docker Compose file. Simply create a docker-compose.yml, paste the following code, then run docker-compose up. You will then see a link in the console to open up and access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

Creating a Spark Session and DataFrame

Let's start by creating a Spark Session.

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
WARNING: An illegal reflective access operation has occurred
WARNING: Illegal reflective access by org.apache.spark.unsafe.Platform (file:/usr/local/spark-3.1.2-bin-hadoop3.2/jars/spark-unsafe_2.12-3.1.2.jar) to constructor java.nio.DirectByteBuffer(long,int)
WARNING: Please consider reporting this to the maintainers of org.apache.spark.unsafe.Platform
WARNING: Use --illegal-access=warn to enable warnings of further illegal reflective access operations
WARNING: All illegal access operations will be denied in a future release
21/10/14 12:27:28 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
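
The startup warnings are harmless, but they clutter the notebook. As the log message above suggests, we can raise the log level to quiet them down. A minimal sketch:

# Only log errors from here on; suppresses the WARN noise above.
spark.sparkContext.setLogLevel("ERROR")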

Now, let's create a dataframe to work with.

rdd = spark.sparkContext.parallelize([
    ("jan", 2019, 86000, 56),
    ("jan", 2020, 71000, 30),
    ("jan", 2021, 90000, 24),
    
    ("feb", 2019, 99000, 40),
    ("feb", 2020, 83000, 36),
    ("feb", 2021, 69000, 53),
    
    ("mar", 2019, 80000, 25),
    ("mar", 2020, 91000, 50)
])
df = spark.createDataFrame(rdd, schema = ["month", "year", "total_revenue", "unique_products_sold"])
df.show()
+-----+----+-------------+--------------------+
|month|year|total_revenue|unique_products_sold|
+-----+----+-------------+--------------------+
|  jan|2019|        86000|                  56|
|  jan|2020|        71000|                  30|
|  jan|2021|        90000|                  24|
|  feb|2019|        99000|                  40|
|  feb|2020|        83000|                  36|
|  feb|2021|        69000|                  53|
|  mar|2019|        80000|                  25|
|  mar|2020|        91000|                  50|
+-----+----+-------------+--------------------+
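
As an aside, for small local data like this the RDD step is optional; createDataFrame also accepts a plain Python list of tuples directly. A minimal sketch of the equivalent call:

df = spark.createDataFrame(
    [("jan", 2019, 86000, 56), ("feb", 2019, 99000, 40)],  # same kind of tuples as above
    schema=["month", "year", "total_revenue", "unique_products_sold"]
)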

Creating a UDF

To start creating a UDF, we simply create a normal python function.

def convertRevenue(number):
    return number/1000

Next, we wrap our function in the udf function from pyspark.

from pyspark.sql.functions import udf, col
convertRevenueUdf = udf(lambda value: convertRevenue(value))
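
Note that udf uses StringType as its return type by default, so the revenue column produced below is technically a string column even though it prints like a number. If you want a proper numeric column, you can pass an explicit return type (and skip the lambda, since udf accepts the function directly). For example:

from pyspark.sql.types import DoubleType

# Alternative definition: the UDF now returns a double instead of a string.
convertRevenueUdf = udf(convertRevenue, DoubleType())

The displayed output in the examples below looks the same either way.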

Finally, we can use the udf in a few ways. We can use the select method.

df.select(
    col("month"),
    convertRevenueUdf(col("total_revenue")).alias("revenue")
).show()
+-----+-------+
|month|revenue|
+-----+-------+
|  jan|   86.0|
|  jan|   71.0|
|  jan|   90.0|
|  feb|   99.0|
|  feb|   83.0|
|  feb|   69.0|
|  mar|   80.0|
|  mar|   91.0|
+-----+-------+

We can also use the withColumn method.

df.withColumn("revenue", convertRevenueUdf(col("total_revenue"))).show()
+-----+----+-------------+--------------------+-------+
|month|year|total_revenue|unique_products_sold|revenue|
+-----+----+-------------+--------------------+-------+
|  jan|2019|        86000|                  56|   86.0|
|  jan|2020|        71000|                  30|   71.0|
|  jan|2021|        90000|                  24|   90.0|
|  feb|2019|        99000|                  40|   99.0|
|  feb|2020|        83000|                  36|   83.0|
|  feb|2021|        69000|                  53|   69.0|
|  mar|2019|        80000|                  25|   80.0|
|  mar|2020|        91000|                  50|   91.0|
+-----+----+-------------+--------------------+-------+
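
One caveat: a plain Python UDF serializes every row between the JVM and the Python worker, which can be slow on large dataframes. If pyarrow is available in the image (an assumption worth verifying), a pandas UDF does the same work on whole batches at a time. A rough sketch of the same conversion as a pandas UDF:

import pandas as pd
from pyspark.sql.functions import pandas_udf

@pandas_udf("double")
def convertRevenuePandas(revenue: pd.Series) -> pd.Series:
    # Operates on a whole pandas Series per batch instead of row by row.
    return revenue / 1000

df.withColumn("revenue", convertRevenuePandas(col("total_revenue"))).show()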

Registering a UDF with SQL

If we would like to run our UDF with PySpark SQL, we can register our UDF instead. Then, when we call spark.sql, the UDF will be available inside the query.

spark.udf.register("convertRevenueUdf", convertRevenue)
<function __main__.convertRevenue(number)>
df.createOrReplaceTempView("temp_table")
spark.sql("select month, convertRevenueUdf(total_revenue) as revenue from temp_table") \
     .show()
+-----+-------+
|month|revenue|
+-----+-------+
|  jan|   86.0|
|  jan|   71.0|
|  jan|   90.0|
|  feb|   99.0|
|  feb|   83.0|
|  feb|   69.0|
|  mar|   80.0|
|  mar|   91.0|
+-----+-------+
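
As with udf, spark.udf.register also accepts an optional return type as its third argument, which is handy if downstream SQL expects a numeric column rather than a string. A small sketch:

from pyspark.sql.types import DoubleType

# Re-register with an explicit return type so the SQL result is a double.
spark.udf.register("convertRevenueUdf", convertRevenue, DoubleType())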