PySpark provides two major classes, along with several minor ones, to help define schemas. This lets us interact with Spark's distributed environment in a type-safe way. In this article, we will learn how to use StructType and StructField in PySpark.
To start, let's create a PySpark session as usual.
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
Next, let's import everything from the pyspark.sql.types module so we can use the classes directly.
from pyspark.sql.types import *
To start, let's discuss the difference between StructType and StructField. A StructType is simply a collection of StructFields. A StructField lets us define a field's name, its type, and whether it is nullable. This is similar to a column definition in SQL.
schema = StructType([
    StructField("amount", IntegerType(), True),
])
schema
StructType(List(StructField(amount,IntegerType,true)))
In the example above, we created a schema using StructType. We included only one field, called amount, which is of IntegerType, and we allow it to be nullable.
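As an aside, the nullable flag is enforced when we create a dataframe: by default, createDataFrame verifies each row against the schema, so a None value in a non-nullable field is rejected. Here is a minimal sketch of that behavior (the exact error message varies by Spark version):

# Hypothetical strict schema: amount may not be null
strict_schema = StructType([
    StructField("amount", IntegerType(), False),
])
try:
    spark.createDataFrame([(None,)], schema=strict_schema)
except ValueError as e:
    print(e)  # e.g. "field amount: This field is not nullable, but got None"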
You can find a full list of types here: https://spark.apache.org/docs/latest/sql-ref-datatypes.html
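For illustration, here is a small, hypothetical schema exercising a few more of those types:

# A hypothetical schema using several other built-in types
other_schema = StructType([
    StructField("price", DoubleType(), True),            # double-precision float
    StructField("is_paid", BooleanType(), True),         # boolean
    StructField("due", DateType(), True),                # date without a time component
    StructField("tags", ArrayType(StringType()), True),  # array of strings
])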
Let's continue this example and build out our schema a bit more. Here we will add two more fields, month and date. We will let month be a string and date be a timestamp.
schema = StructType([
    StructField("amount", IntegerType(), True),
    StructField("month", StringType(), True),
    StructField("date", TimestampType(), True),
])
schema
StructType(List(StructField(amount,IntegerType,true),StructField(month,StringType,true),StructField(date,TimestampType,true)))
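As a side note, StructType also has an add method, so the same schema can be built incrementally. A minimal sketch of that alternative:

# Equivalent schema built field by field with StructType.add
schema_alt = (
    StructType()
    .add("amount", IntegerType(), True)
    .add("month", StringType(), True)
    .add("date", TimestampType(), True)
)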
Now, let's use this schema when creating a dataframe to ensure that our dataframe conforms to our types. We do this by passing our schema variable to the schema named parameter of the createDataFrame method.
from datetime import datetime
from pyspark.sql import Row
data = [
    Row(amount=20000, month='jan', date=datetime(2000, 1, 1, 12, 0)),
    Row(amount=40000, month='feb', date=datetime(2000, 2, 1, 12, 0)),
    Row(amount=50000, month='mar', date=datetime(2000, 3, 1, 12, 0)),
]
df = spark.createDataFrame(data=data, schema=schema)
df.show()
+------+-----+-------------------+
|amount|month| date|
+------+-----+-------------------+
| 20000| jan|2000-01-01 12:00:00|
| 40000| feb|2000-02-01 12:00:00|
| 50000| mar|2000-03-01 12:00:00|
+------+-----+-------------------+
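Because we supplied an explicit schema, rows that do not match it are rejected when the dataframe is created rather than silently coerced. A rough sketch (the exact error message varies by Spark version):

# A string where the schema expects an integer fails schema verification
bad_data = [Row(amount='oops', month='jan', date=datetime(2000, 1, 1, 12, 0))]
try:
    spark.createDataFrame(data=bad_data, schema=schema)
except TypeError as e:
    print(e)  # e.g. "field amount: IntegerType() can not accept object 'oops' ..."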
To view our dataframe's schema, we can use the printSchema method.
df.printSchema()
root
|-- amount: integer (nullable = true)
|-- month: string (nullable = true)
|-- date: timestamp (nullable = true)
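The schema is also available programmatically through the schema attribute, which returns the StructType itself. Since we supplied the schema explicitly, it should compare equal to the one we defined:

# df.schema returns the StructType backing the dataframe
print(df.schema == schema)  # True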