Scala UDF in PySpark

Spark DataFrames provide a rich set of built-in functions. Sometimes, though, you need to handle a special case that is simpler to express as a user-defined function (UDF). PySpark lets you define a UDF in Python, but a Python UDF runs in a Python worker process rather than in Spark's executor JVM. Every row has to be serialized and handed off between the JVM and the Python process, which makes execution slower. Processing can be much faster if the UDF is written in Scala and called from PySpark just like Spark's built-in functions. In the example below, the job time dropped from 28 mins to 4.2 mins.

For the following demo I used a machine with 8 cores and 64 GB of RAM, running Spark 2.2.0.

Here is a sample case based on the RITA airline dataset.

There are 5 delay fields in the dataset: CarrierDelay, WeatherDelay, NASDelay, SecurityDelay and LateAircraftDelay. If a flight is delayed, at least one of them is populated. The task is to find the cause of each delay, defined as the field with the largest non-zero value among these five fields.
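To make the rule concrete, here is a minimal plain-Python sketch of it that runs without Spark. The names `parse_delay` and `find_delay_cause` and the sample rows are hypothetical, made up for illustration. As in the Python UDF below, null, blank or non-numeric values count as -1; in this sketch ties go to the first field in the list.

```python
import re
from typing import Mapping, Optional

# The five delay columns, in the order used throughout this post.
FIELDS = ["CarrierDelay", "WeatherDelay", "NASDelay", "SecurityDelay", "LateAircraftDelay"]
IS_NUMBER = re.compile(r"^\d+(\.\d+)?$")


def parse_delay(value: Optional[str]) -> float:
    """Convert a raw delay string to float; None, blank or non-numeric -> -1.0."""
    s = "" if value is None else value.strip()
    return float(s) if IS_NUMBER.match(s) else -1.0


def find_delay_cause(row: Mapping[str, Optional[str]]) -> Optional[str]:
    """Hypothetical helper: name of the field with the largest positive delay, else None."""
    delays = [parse_delay(row.get(f)) for f in FIELDS]
    best = max(delays)
    if best <= 0:
        return None  # no positive delay recorded
    return FIELDS[delays.index(best)]  # first field wins on ties


# Made-up sample rows.
print(find_delay_cause({"CarrierDelay": "0", "WeatherDelay": "15", "NASDelay": "7"}))  # WeatherDelay
print(find_delay_cause({"CarrierDelay": "20", "LateAircraftDelay": "20"}))  # CarrierDelay
print(find_delay_cause({"CarrierDelay": "0", "NASDelay": "", "SecurityDelay": "NA"}))  # None
```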

Load the data.

Note: I converted the data from the original CSV format to Parquet and partitioned it by year.

path = "RITA/data-parquet"
data = spark.read.load(path)
data.count()

107,511,022

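import re
import numpy as np
import pyspark.sql.functions as F
from pyspark.sql.types import StringType

# Matches non-negative decimal numbers such as "12" or "12.5"
is_number = re.compile(r"^\d+(\.\d+)?$")

fields = "CarrierDelay,WeatherDelay,NASDelay,SecurityDelay,LateAircraftDelay".split(",")

def cause_of_delay_(row):
    # Pull the five delay values (stored as strings) out of the struct
    delays = [row[f] for f in fields]
    # Treat null, blank and non-numeric values as -1
    delays = ["" if s is None else s.strip() for s in delays]
    delays = [s if is_number.match(s) else "-1" for s in delays]
    delays = [float(s) for s in delays]
    # Name of the field with the largest positive delay (first one wins on ties), else None
    cause = fields[np.argmax(delays)] if max(delays) > 0 else None
    return cause

# Python UDF: rows are shipped from the JVM to a Python worker and back
cause_of_delay = F.udf(cause_of_delay_, StringType())
columns = [F.col(c) for c in data.columns]
data.withColumn("causeOfDelay", cause_of_delay(F.struct(*columns))).groupBy("causeOfDelay").count().show()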

+-----------------+--------+
|     causeOfDelay|   count|
+-----------------+--------+
|     WeatherDelay|  633927|
|    SecurityDelay|   39506|
|             null|90051857|
|LateAircraftDelay| 6368701|
|         NASDelay| 5593921|
|     CarrierDelay| 4823110|
+-----------------+--------+
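As a quick sanity check on the output above, the group counts should add up to the row count returned by `data.count()`. The snippet below redoes that arithmetic in plain Python using only the numbers printed in this post (the variable names are illustrative) and also computes each cause's share among the rows that have a cause.

```python
# Group counts copied from the output above.
counts = {
    "WeatherDelay": 633927,
    "SecurityDelay": 39506,
    None: 90051857,  # rows where no delay field is positive
    "LateAircraftDelay": 6368701,
    "NASDelay": 5593921,
    "CarrierDelay": 4823110,
}
total_rows = 107511022  # result of data.count()

# Every row lands in exactly one group, so the counts must sum to the row count.
assert sum(counts.values()) == total_rows

# Rows with an identified cause, and the share of each cause among them.
attributed = total_rows - counts[None]
shares = {cause: n / attributed for cause, n in counts.items() if cause is not None}
print(attributed)  # 17459165
for cause, share in sorted(shares.items(), key=lambda kv: -kv[1]):
    print(f"{cause:<18}{share:.1%}")
```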

To speed up the processing, let's define the UDF in Scala and call it from PySpark.

Create a build.sbt file:

name := "SparkUDFs"
version := "0.1"
scalaVersion := "2.11.8"
// Spark is already on the classpath at runtime, so mark it "provided"
libraryDependencies ++= Seq(
  "org.apache.spark" %% "spark-sql" % "2.2.0" % "provided"
)

Create the Scala class src/main/scala/com/einext/airlines/findDelayCause.scala:

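package com.einext.airlines

import java.util.regex.Pattern

import org.apache.spark.sql.Row
import org.apache.spark.sql.api.java.UDF1

/*
* Usage: find the cause of an airline delay based on 5 fields, i.e. the name of the field
* with the highest non-zero value. Use this UDF for faster execution on larger datasets.
*
* sqlContext.registerJavaFunction("findDelayCause", "com.einext.airlines.findDelayCause")
* data.selectExpr("findDelayCause(struct(*))").show()
*
* */
class findDelayCause extends UDF1[Row, String] {
  // Delay columns to compare, built once per UDF instance rather than once per row
  private val fields: Array[String] =
    "CarrierDelay,WeatherDelay,NASDelay,SecurityDelay,LateAircraftDelay".split(",").map(_.trim)

  // Same pattern as the Python UDF: non-negative decimals such as "12" or "12.5"
  private val isNumber: Pattern = Pattern.compile("""^\d+(\.\d+)?$""")

  @throws[Exception]
  override def call(row: Row): String = {
    // Null, blank or non-numeric values become -1.0 instead of throwing NumberFormatException
    val values: Array[Double] = fields.map { field =>
      val v: String = row.getString(row.fieldIndex(field))
      if (v != null && isNumber.matcher(v.trim).matches()) v.trim.toDouble else -1.0
    }
    val maxValue = values.max
    // Name of the first field holding the largest positive delay, or null if none is positive
    if (maxValue > 0) fields(values.indexOf(maxValue)) else null
  }
}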

Compile and package

$ sbt clean package 

Launch PySpark with the new jar:

$ $SPARK_HOME/bin/pyspark --jars target/scala-2.11/sparkudfs_2.11-0.1.jar

Then, in the PySpark shell, register the Scala class as a SQL function and use it:

import pyspark.sql.functions as F
path = "RITA/data-parquet"
data = spark.read.load(path)
# Spark 2.2 API; from Spark 2.3 on, spark.udf.registerJavaFunction(...) is preferred
sqlContext.registerJavaFunction("findDelayCause", "com.einext.airlines.findDelayCause")
data.withColumn("causeOfDelay", F.expr("findDelayCause(struct(*))")).groupBy("causeOfDelay").count().show()

+-----------------+--------+
|     causeOfDelay|   count|
+-----------------+--------+
|     WeatherDelay|  633927|
|    SecurityDelay|   39506|
|             null|90051857|
|LateAircraftDelay| 6368701|
|         NASDelay| 5593921|
|     CarrierDelay| 4823110|
+-----------------+--------+

The median task duration dropped from 6.6 mins to 57 secs, which brought the total job time down from 28 mins (Python UDF) to 4.2 mins (Scala UDF).
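For reference, the speedups implied by these timings can be worked out directly. The snippet below only restates the numbers reported in this post; the variable names are illustrative.

```python
# Timings reported above, converted to seconds.
median_task_python_s = 6.6 * 60  # 6.6 mins
median_task_scala_s = 57         # 57 secs
job_python_s = 28 * 60           # 28 mins
job_scala_s = 4.2 * 60           # 4.2 mins

task_speedup = median_task_python_s / median_task_scala_s
job_speedup = job_python_s / job_scala_s
print(f"median task speedup: {task_speedup:.1f}x")  # median task speedup: 6.9x
print(f"total job speedup:   {job_speedup:.1f}x")   # total job speedup:   6.7x
```

Both ratios land at roughly 7x, so the per-task gain carries over almost entirely to the total job time.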