PySpark Statistics Package

11.02.2021

Intro

PySpark ships a statistics package, pyspark.ml.stat, that provides a handful of tests and helper classes for common statistical workflows. They operate on vector columns and scale across a cluster. In this article, we will learn how to use PySpark's stats package.

Setting Up

The quickest way to get started is with the following Docker Compose file. Simply create a docker-compose.yml, paste in the code below, then run docker-compose up. A link will appear in the console that you can open to access a Jupyter notebook.

version: '3'
services:
  spark:
    image: jupyter/pyspark-notebook
    ports:
      - "8888:8888"
      - "4040-4080:4040-4080"
    volumes:
      - ./notebooks:/home/jovyan/work/notebooks/

ChiSquareTest

Let's start by creating a Spark session. ChiSquareTest runs Pearson's chi-squared test of independence between every feature in the features vector and the label.

from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.stat import ChiSquareTest

spark = SparkSession.builder.getOrCreate()

dataset = [[0, Vectors.dense([0, 0, 1])],
           [0, Vectors.dense([1, 0, 1])],
           [1, Vectors.dense([2, 1, 1])],
           [1, Vectors.dense([3, 1, 1])]]
dataset = spark.createDataFrame(dataset, ["label", "features"])
chiSqResult = ChiSquareTest.test(dataset, 'features', 'label')
chiSqResult.select("degreesOfFreedom").collect()[0]
Row(degreesOfFreedom=[3, 1, 0])
chiSqResult = ChiSquareTest.test(dataset, 'features', 'label', True)
chiSqResult.show()
+------------+-------------------+----------------+---------+
|featureIndex|             pValue|degreesOfFreedom|statistic|
+------------+-------------------+----------------+---------+
|           0| 0.2614641299491107|               3|      4.0|
|           1|0.04550026389635764|               1|      4.0|
|           2|                1.0|               0|      0.0|
+------------+-------------------+----------------+---------+
row = chiSqResult.orderBy("featureIndex").collect()
row[0].statistic
4.0
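
Without flatten=True the test result comes back as a single row whose columns hold one value per feature. A small sketch of pulling those values out directly:

# Single-row result: vector/array-valued columns, one entry per feature.
result = ChiSquareTest.test(dataset, 'features', 'label').first()
result.pValues           # DenseVector of per-feature p-values
result.degreesOfFreedom  # list of per-feature degrees of freedom
result.statistics        # DenseVector of per-feature chi-squared statistics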

Correlation

Correlation.corr computes the correlation matrix of a vector column, using either Pearson (the default) or Spearman correlation. Note that the third feature in our dataset is constant, so its correlations are undefined and come back as nan (Spark also logs a warning about NaN values in the matrix).

from pyspark.ml.stat import Correlation

pearsonCorr = Correlation.corr(dataset, 'features', 'pearson').collect()
pearsonCorr[0][0]
DenseMatrix(3, 3, [1.0, 0.8944, nan, 0.8944, 1.0, nan, nan, nan, 1.0], False)

spearmanCorr = Correlation.corr(dataset, 'features', method='spearman').collect()
spearmanCorr[0][0]
DenseMatrix(3, 3, [1.0, 0.8944, nan, 0.8944, 1.0, nan, nan, nan, 1.0], False)
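
Each result is a DenseMatrix; converting it to a NumPy array can make it easier to read or to hand off to other tools. A quick sketch:

# Convert the correlation DenseMatrix into a plain NumPy array.
pearsonCorr[0][0].toArray()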

KolmogorovSmirnovTest

KolmogorovSmirnovTest checks whether a sample was drawn from a given continuous distribution; currently only the normal distribution ('norm') is supported, parameterized by its mean and standard deviation.

from pyspark.ml.stat import KolmogorovSmirnovTest

dataset = [[-1.0], [0.0], [1.0]]
dataset = spark.createDataFrame(dataset, ['sample'])

ksResult = KolmogorovSmirnovTest.test(dataset, 'sample', 'norm', 0.0, 1.0).first()
ksResult
Row(pValue=0.9999753186701124, statistic=0.1746780794018764)
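
As a quick sanity check, data actually drawn from a standard normal should also produce a large p-value. A small sketch using randomly generated data (the seed is arbitrary):

from pyspark.sql.functions import randn

# 1,000 samples from a standard normal; the KS test should not reject normality.
normal_df = spark.range(0, 1000).select(randn(seed=42).alias('sample'))
KolmogorovSmirnovTest.test(normal_df, 'sample', 'norm', 0.0, 1.0).first()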

MultivariateGaussian

MultivariateGaussian is a simple container for the parameters of a multivariate normal distribution: a mean vector and a covariance matrix.

from pyspark.ml.stat import MultivariateGaussian
from pyspark.ml.linalg import DenseMatrix, Vectors

m = MultivariateGaussian(Vectors.dense([11, 12]), DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
(m.mean, m.cov.toArray())
(DenseVector([11.0, 12.0]),
 array([[1., 5.],
        [3., 2.]]))

Summarizer

Summarizer computes summary statistics for a vector column in a single pass. You pick the metrics you want (for example mean, variance, count, min, max, numNonZeros, normL1, normL2), optionally supply a weight column, and get the results back as a struct.

from pyspark.sql import Row
from pyspark.ml.stat import Summarizer
from pyspark.ml.linalg import Vectors

summarizer = Summarizer.metrics("mean", "count")

sc = spark.sparkContext
df = sc.parallelize([Row(weight=1.0, features=Vectors.dense(1.0, 1.0, 1.0)),
                     Row(weight=0.0, features=Vectors.dense(1.0, 2.0, 3.0))]).toDF()

df.select(summarizer.summary(df.features, df.weight)).show(truncate=False)
+-----------------------------------+
|aggregate_metrics(features, weight)|
+-----------------------------------+
|{[1.0,1.0,1.0], 1}                 |
+-----------------------------------+

df.select(summarizer.summary(df.features)).show(truncate=False)
+--------------------------------+
|aggregate_metrics(features, 1.0)|
+--------------------------------+
|{[1.0,1.5,2.0], 2}              |
+--------------------------------+
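
Summarizer also exposes single-metric helpers such as Summarizer.mean; a small sketch comparing the weighted and unweighted mean in one select:

# Weighted vs. unweighted mean of the features column.
df.select(Summarizer.mean(df.features, df.weight),
          Summarizer.mean(df.features)).show(truncate=False)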
Finally, let's try this out on a real dataset. We'll load a heart-disease CSV with pandas and convert it to a Spark DataFrame.

import pandas as pd

# https://stackoverflow.com/questions/57014043/reading-data-from-url-using-spark-databricks-platform
dataurl = 'https://raw.githubusercontent.com/drujensen/heart-disease/master/data/heart-disease.csv'
df = pd.read_csv(dataurl)
df.head()
   age  sex  cp  trestbps chol  fbs  restecg  thalach  exang  oldpeak slope  ca thal  num
0   28    1   2       130  132    0        2      185      0      0.0     ?   ?    ?    0
1   29    1   2       120  243    0        0      160      0      0.0     ?   ?    ?    0
2   29    1   2       140    ?    0        0      170      0      0.0     ?   ?    ?    0
3   30    0   1       170  237    0        1      170      0      0.0     ?   ?    6    0
4   31    0   2       100  219    0        1      150      0      0.0     ?   ?    ?    0
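
Note that the '?' placeholders prevent pandas from parsing several columns as numbers, which is why those columns come through as strings in the Spark DataFrame below. A quick way to see which columns are affected:

# Columns containing '?' are read with object (string) dtype by pandas.
df.dtypes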

sdf = spark.createDataFrame(df)
sdf.show()
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+----------+
|age|sex| cp|trestbps|chol|fbs|restecg|thalach|exang|oldpeak|slope| ca|thal|num       |
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+----------+
| 28|  1|  2|     130| 132|  0|      2|    185|    0|    0.0|    ?|  ?|   ?|         0|
| 29|  1|  2|     120| 243|  0|      0|    160|    0|    0.0|    ?|  ?|   ?|         0|
| 29|  1|  2|     140|   ?|  0|      0|    170|    0|    0.0|    ?|  ?|   ?|         0|
| 30|  0|  1|     170| 237|  0|      1|    170|    0|    0.0|    ?|  ?|   6|         0|
| 31|  0|  2|     100| 219|  0|      1|    150|    0|    0.0|    ?|  ?|   ?|         0|
| 32|  0|  2|     105| 198|  0|      0|    165|    0|    0.0|    ?|  ?|   ?|         0|
| 32|  1|  2|     110| 225|  0|      0|    184|    0|    0.0|    ?|  ?|   ?|         0|
| 32|  1|  2|     125| 254|  0|      0|    155|    0|    0.0|    ?|  ?|   ?|         0|
| 33|  1|  3|     120| 298|  0|      0|    185|    0|    0.0|    ?|  ?|   ?|         0|
| 34|  0|  2|     130| 161|  0|      0|    190|    0|    0.0|    ?|  ?|   ?|         0|
| 34|  1|  2|     150| 214|  0|      1|    168|    0|    0.0|    ?|  ?|   ?|         0|
| 34|  1|  2|      98| 220|  0|      0|    150|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  0|  1|     120| 160|  0|      1|    185|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  0|  4|     140| 167|  0|      0|    150|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  1|  2|     120| 308|  0|      2|    180|    0|    0.0|    ?|  ?|   ?|         0|
| 35|  1|  2|     150| 264|  0|      0|    168|    0|    0.0|    ?|  ?|   ?|         0|
| 36|  1|  2|     120| 166|  0|      0|    180|    0|    0.0|    ?|  ?|   ?|         0|
| 36|  1|  3|     112| 340|  0|      0|    184|    0|    1.0|    2|  ?|   3|         0|
| 36|  1|  3|     130| 209|  0|      0|    178|    0|    0.0|    ?|  ?|   ?|         0|
| 36|  1|  3|     150| 160|  0|      0|    172|    0|    0.0|    ?|  ?|   ?|         0|
+---+---+---+--------+----+---+-------+-------+-----+-------+-----+---+----+----------+
only showing top 20 rows
Spark DataFrames also expose a corr method directly (currently Pearson only) for two numeric columns. Correlating a column with itself trivially returns 1.0:

sdf.corr('age', 'age')
1.0
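
Correlating two different columns is more interesting, but the '?' missing-value markers have to be handled first. A minimal sketch, assuming we re-read the CSV treating '?' as missing and look at age versus resting blood pressure (trestbps):

# Re-read the CSV so '?' becomes NaN, drop incomplete rows for the two
# columns of interest, then compute the Pearson correlation in Spark.
clean_df = pd.read_csv(dataurl, na_values='?').dropna(subset=['age', 'trestbps'])
clean_sdf = spark.createDataFrame(clean_df[['age', 'trestbps']])
clean_sdf.corr('age', 'trestbps')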