PySpark provides a stats library, located in pyspark.ml.stat, that gives us a handful of tests and classes for common statistical workflows. They operate on vector columns and scale across a cluster. In this article, we will learn how to use PySpark's stats package.
The quickest way to get started is with the following Docker Compose file. Simply create a docker-compose.yml, paste in the code below, and run docker-compose up. You will then see a link in the console that opens a Jupyter notebook.
version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/
Let's start by creating a Spark Session.
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
The first tool we will look at is ChiSquareTest, which runs a chi-squared test of independence between each feature and the label.
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import ChiSquareTest
dataset = [[0, Vectors.dense([0, 0, 1])],
           [0, Vectors.dense([1, 0, 1])],
           [1, Vectors.dense([2, 1, 1])],
           [1, Vectors.dense([3, 1, 1])]]
dataset = spark.createDataFrame(dataset, ["label", "features"])
chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
chiSqResult.select("degreesOfFreedom").collect()[0]
Row(degreesOfFreedom=[3, 1, 0])
chiSqResult = ChiSquareTest.test(dataset, 'features', 'label', True)
chiSqResult.show()
+------------+-------------------+----------------+---------+
|featureIndex|             pValue|degreesOfFreedom|statistic|
+------------+-------------------+----------------+---------+
|           0| 0.2614641299491107|               3|      4.0|
|           1|0.04550026389635764|               1|      4.0|
|           2|                1.0|               0|      0.0|
+------------+-------------------+----------------+---------+
rows = chiSqResult.orderBy("featureIndex").collect()
rows[0].statistic
4.0
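When flatten is left at its default of False, the per-feature p-values come back packed into a single pValues column rather than one row per feature. A small sketch of pulling them out of the un-flattened result:
# Re-run the test without flattening and grab the vector of p-values.
chiSqUnflat = ChiSquareTest.test(dataset, 'features', 'label')
chiSqUnflat.select("pValues").collect()[0]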
Next is the Correlation class, which computes the correlation matrix of a vector column using either the Pearson or Spearman method.
from pyspark.ml.stat import Correlation
pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()
pearsonCorr
pearsonCorr[0][0]
21/10/30 04:31:43 WARN PearsonCorrelation: Pearson correlation matrix contains NaN values.
DenseMatrix(3, 3, [1.0, 0.8944, nan, 0.8944, 1.0, nan, nan, nan, 1.0], False)
spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()
spearmanCorr
spearmanCorr[0][0]
21/10/30 04:32:10 WARN PearsonCorrelation: Pearson correlation matrix contains NaN values.
DenseMatrix(3, 3, [1.0, 0.8944, nan, 0.8944, 1.0, nan, nan, nan, 1.0], False)
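The collected matrices are DenseMatrix objects; calling toArray() turns one into a NumPy array, which is a bit easier to read:
# Convert the Pearson correlation matrix into a plain NumPy array.
pearsonCorr[0][0].toArray()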
KolmogorovSmirnovTest checks whether a numeric sample follows a given theoretical distribution.
from pyspark.ml.stat import KolmogorovSmirnovTest
dataset = [[-1.0], [0.0], [1.0]]
dataset = spark.createDataFrame(dataset, ['sample'])
ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first()
ksResult
Row(pValue=0.9999753186701124, statistic=0.1746780794018764)
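The positional arguments after 'norm' are the parameters of the theoretical distribution (currently only the normal distribution is supported), so testing against a non-standard normal just means passing its mean and standard deviation. A small sketch, assuming a mean of 1.0 and a standard deviation of 2.0:
# Same sample, but compared against N(1.0, 2.0) instead of the standard normal.
KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 1.0, 2.0).first()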
MultivariateGaussian represents a multivariate normal distribution built from a mean vector and a covariance matrix.
from pyspark.ml.stat import MultivariateGaussian
from pyspark.ml.linalg import DenseMatrix, Vectors
m = MultivariateGaussian(Vectors.dense([11,12]), DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
(m.mean, m.cov.toArray())
(DenseVector([11.0, 12.0]),
 array([[1., 5.],
        [3., 2.]]))
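The matrix values are read in column-major order, which is why the tuple (1.0, 3.0, 5.0, 2.0) shows up above as the rows [1, 5] and [3, 2]:
# DenseMatrix fills column by column: first column (1.0, 3.0), second column (5.0, 2.0).
DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)).toArray()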
The Summarizer class computes summary statistics for a vector column, optionally weighted by another column, directly inside a DataFrame query.
from pyspark.ml.stat import Summarizer
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors
summarizer = Summarizer.metrics("mean", "count")
sc = spark.sparkContext
df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
                     Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()
df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+-----------------------------------+
|aggregate_metrics(features, weight)|
+-----------------------------------+
|{[1.0,1.0,1.0], 1}                 |
+-----------------------------------+
df.select(summarizer.summary(df.features)).show(truncate=False)
+--------------------------------+
|aggregate_metrics(features, 1.0)|
+--------------------------------+
|{[1.0,1.5,2.0], 2}              |
+--------------------------------+
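Summarizer is not limited to mean and count; the supported metric names also include min, max, sum, variance, std, numNonZeros, normL1 and normL2. A quick sketch requesting a few more metrics on the same DataFrame:
# Ask for extra metrics; weights are optional, so we omit them here.
moreMetrics = Summarizer.metrics("min", "max", "variance")
df.select(moreMetrics.summary(df.features)).show(truncate=False)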
To finish, let's apply these tools to a real dataset. We will read the heart disease data with pandas and then convert it to a Spark DataFrame.
import pandas as pd
# https://stackoverflow.com/questions/57014043/reading-data-from-url-using-spark-databricks-platform
dataurl = 'https://raw.githubusercontent.com/drujensen/heart-disease/master/data/heart-disease.csv'
df = pd.read_csv(dataurl)
df.head()
|   | age | sex | cp | trestbps | chol | fbs | restecg | thalach | exang | oldpeak | slope | ca | thal | num |
|---|-----|-----|----|----------|------|-----|---------|---------|-------|---------|-------|----|------|-----|
| 0 | 28  | 1   | 2  | 130      | 132  | 0   | 2       | 185     | 0     | 0.0     | ?     | ?  | ?    | 0   |
| 1 | 29  | 1   | 2  | 120      | 243  | 0   | 0       | 160     | 0     | 0.0     | ?     | ?  | ?    | 0   |
| 2 | 29  | 1   | 2  | 140      | ?    | 0   | 0       | 170     | 0     | 0.0     | ?     | ?  | ?    | 0   |
| 3 | 30  | 0   | 1  | 170      | 237  | 0   | 1       | 170     | 0     | 0.0     | ?     | ?  | 6    | 0   |
| 4 | 31  | 0   | 2  | 100      | 219  | 0   | 1       | 150     | 0     | 0.0     | ?     | ?  | ?    | 0   |
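Going through pandas works fine for a small file. The Stack Overflow answer linked in the comment above also shows how to let Spark fetch the file itself via SparkFiles; a sketch of that alternative (assuming the URL is reachable from the Spark process):
from pyspark import SparkFiles

# Distribute the file and read it with Spark's own CSV reader.
spark.sparkContext.addFile(dataurl)
sdf_direct = spark.read.csv("file://" + SparkFiles.get("heart-disease.csv"),
                            header=True, inferSchema=True)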
sdf = spark.createDataFrame(df)
sdf.show()
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+----------+
|age|sex| cp|trestbps|chol|fbs|restecg|thalach|exang|oldpeak|slope| ca|thal|num       |
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+----------+
| 28|  1|  2|     130| 132|  0|      2|    185|    0|    0.0|    ?|  ?|   ?|         0|
| 29|  1|  2|     120| 243|  0|      0|    160|    0|    0.0|    ?|  ?|   ?|         0|
| 29|  1|  2|     140|   ?|  0|      0|    170|    0|    0.0|    ?|  ?|   ?|         0|
| 30|  0|  1|     170| 237|  0|      1|    170|    0|    0.0|    ?|  ?|   6|         0|
| 31|  0|  2|     100| 219|  0|      1|    150|    0|    0.0|    ?|  ?|   ?|         0|
| 32|  0|  2|     105| 198|  0|      0|    165|    0|    0.0|    ?|  ?|   ?|         0|
| 32|  1|  2|     110| 225|  0|      0|    184|    0|    0.0|    ?|  ?|   ?|         0|
| 32|  1|  2|     125| 254|  0|      0|    155|    0|    0.0|    ?|  ?|   ?|         0|
| 33|  1|  3|     120| 298|  0|      0|    185|    0|    0.0|    ?|  ?|   ?|         0|
| 34|  0|  2|     130| 161|  0|      0|    190|    0|    0.0|    ?|  ?|   ?|         0|
| 34|  1|  2|     150| 214|  0|      1|    168|    0|    0.0|    ?|  ?|   ?|         0|
| 34|  1|  2|      98| 220|  0|      0|    150|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  0|  1|     120| 160|  0|      1|    185|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  0|  4|     140| 167|  0|      0|    150|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  1|  2|     120| 308|  0|      2|    180|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  1|  2|     150| 264|  0|      0|    168|    0|    0.0|    ?|  ?|   ?|         0|
| 36|  1|  2|     120| 166|  0|      0|    180|    0|    0.0|    ?|  ?|   ?|         0|
| 36|  1|  3|     112| 340|  0|      0|    184|    0|    1.0|    2|  ?|   3|         0|
| 36|  1|  3|     130| 209|  0|      0|    178|    0|    0.0|    ?|  ?|   ?|         0|
| 36|  1|  3|     150| 160|  0|      0|    172|    0|    0.0|    ?|  ?|   ?|         0|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+----------+
only showing top 20 rows
Spark DataFrames also expose a corr method that returns the correlation between two numeric columns; correlating a column with itself gives 1.0, which makes for a quick sanity check.
sdf.corr('age', 'age')
1.0
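To wrap up, the pieces combine nicely: VectorAssembler can pack several numeric columns into one vector column, which Correlation then turns into a full correlation matrix. This is only a sketch; the column list is my own pick, and the cast/dropna step is there because this dataset uses '?' for missing values:
from pyspark.sql.functions import col
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

# Cast a handful of columns to double ('?' becomes null) and drop missing rows.
cols = ["age", "sex", "cp", "trestbps", "thalach", "oldpeak"]
clean = sdf.select(*[col(c).cast("double") for c in cols]).dropna()

# Assemble the columns into a single vector column and compute the Pearson matrix.
assembler = VectorAssembler(inputCols=cols, outputCol="features")
corrDf = Correlation.corr(assembler.transform(clean), "features", "pearson")
print(corrDf.head()[0].toArray())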