Java UDF with pyspark

2020-11-03
Python Spark




0. Problem

When using Spark it's generally advised to stick to the built-in Spark functions. However, sometimes we need more advanced functionality or external libraries.

As an example, we might want to use the Uber H3 library to work with geospatial data.

The library is implemented in C, with bindings for multiple languages, including Java and Python.

In this post we will explore different options for using this library with Spark.

1. Creating a Python UDF

The first option is to create a Python UDF using the h3 Python library (installable with pip install h3). Before creating the UDF we can try the library with:

import h3
h3.geo_to_h3(0, 0, 8)
Out: '88754e6499fffff'

To create the UDF we only need to use the udf decorator and specify the output type. It's also a good idea to handle cases where latitude or longitude might be null/None.

import h3
import pyspark.sql.functions as F
import pyspark.sql.types as T

@F.udf(T.StringType())
def geo_to_h3_py(latitude, longitude, resolution):
    # Return null for rows with missing coordinates
    if latitude is None or longitude is None:
        return None

    return h3.geo_to_h3(latitude, longitude, resolution)

Once it's defined it can be called with:

(
    sdf
    .withColumn("h8", geo_to_h3_py("latitude", "longitude", F.lit(8)))
    .show()
)

2. Use a Java function with Spark

2.1. Use H3 directly

The first thing needed to work with the Java H3 library is the jar, which can be downloaded from maven/h3.

Once you have the jar file you need to add it to the $SPARK_HOME/jars folder so that Spark can use it without extra configuration.
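
If you prefer not to copy jars around, a minimal sketch of an alternative (assuming the jar was saved locally as jars/h3.jar) is to point Spark at it when building the session:

from pyspark.sql import SparkSession

# The jar path is an assumption; adjust it to wherever you saved the file
spark = (
    SparkSession.builder
    .config("spark.jars", "jars/h3.jar")
    .getOrCreate()
)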

Then you can directly access the Java code through the JVM gateway: create the h3 instance and call the geoToH3Address function.

h3 = spark.sparkContext._jvm.com.uber.h3core.H3Core.newInstance()
h3.geoToH3Address(0.0, 0.0, 8)
Out: '88754e6499fffff'

However it is not possible to create a UDF using this Java instance. The problem is that it's a Java object (a py4j JavaObject), so it can't be pickled and sent to the executors.
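
As a minimal sketch of the failure (the wrapper below is hypothetical, only meant to show the problem):

# The closure captures `h3`, a py4j JavaObject living in the driver's JVM,
# so Spark raises a pickling error when it tries to serialize this UDF
# and ship it to the executors.
@F.udf(T.StringType())
def geo_to_h3_jvm(latitude, longitude, resolution):
    return h3.geoToH3Address(latitude, longitude, resolution)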

2.2. Create a java callable function

Let's create a simple static Java function that we can use without needing to create an instance from Python.

We can create the file SimpleH3.java inside the path /com/villoro/ (so that it's better organized). Remember to define the package so that it matches the path.

2.2.1 Create the java file

/com/villoro/SimpleH3.java
package com.villoro;

import java.io.IOException;
import com.uber.h3core.H3Core;

public class SimpleH3 {

    private static H3Core h3;

    public static String toH3Address(Double latitude, Double longitude, int resolution){

        // Lazy instantiation
        if (h3 == null) {
            try {
                h3 = H3Core.newInstance();
            }
            catch(IOException e) {
                return null;
            }
        }

        // Check that coordinates are not null
        if (longitude == null || latitude == null) {
            return null;
        } else {
            return h3.geoToH3Address(latitude, longitude, resolution);
        }
    }
} 

The only difference from the Python implementation is that we are doing a lazy instantiation: the first time the function runs it creates the h3 instance, and subsequent calls reuse it.

Lazy instantiation makes the code faster, since H3Core.newInstance() (which loads a native library) is not called for every row.

2.2.2. Compile the code

Once we have our java code we need to compile it.

Since we are importing things from com.uber.h3core.H3Core we need that jar. You can copy the jar you downloaded and save it as jars/h3.jar.

Then you can compile the code with:

javac -cp "jars/h3.jar" com\villoro\*.java

2.2.3. Create the jar

To create the jar we can simply specify that we want to create SimpleH3.jar with all the .class files from com\villoro\ with:

jar cvf SimpleH3.jar com\villoro\*.class

Once you have created the jar, remember to copy it to $SPARK_HOME/jars.

2.2.4. Call the static java function with spark

We can do something similar to call the new static function:

# Using our static function
spark.sparkContext._jvm.com.villoro.SimpleH3.toH3Address(0.0, 0.0, 8)

However this still can't be pickled since it's a Java class, so we can't create a UDF with Python.

3. Create a Java UDF

The solution is to create the UDF directly in Java.

The first thing you will need is to get the spark-sql jar from $SPARK_HOME/jars and copy it as jars/spark-sql.jar. Without this jar you won't be able to compile the UDF.

3.1. Create the java function

Then, to create the UDF, you need to implement one of the UDFx interfaces. Since it will have three inputs we will use UDF3. We will need to declare all the input types (Double, Double and Integer) as well as the output (String), so the class will implement UDF3<Double, Double, Integer, String>.

The rest of the code is really similar to the previous function.

/com/villoro/toH3AddressUDF.java
package com.villoro;

import java.io.IOException;
import com.uber.h3core.H3Core;
import org.apache.spark.sql.api.java.UDF3;

public class toH3AddressUDF implements UDF3<Double, Double, Integer, String> {

    private H3Core h3;

    @Override
    public String call(Double latitude, Double longitude, Integer resolution) throws Exception {

        // Lazy instantiation
        if (h3 == null) {
            try {
                h3 = H3Core.newInstance();
            }
            catch(IOException e) {
                return null;
            }
        }

        // Check that coordinates are not null
        if (longitude == null || latitude == null) {
            return null;
        } else {
            return h3.geoToH3Address(latitude, longitude, resolution);
        }

    }
}

3.2. Compile the code

The command is really similar to the one we used before:

javac -cp "jars/h3.jar;jars/spark-sql.jar" com\villoro\*.java

Make sure the names of the jars in the classpath match your actual file names.

3.3. Create the jar

We can use exactly the same command as before:

jar cvf SimpleH3.jar com\villoro\*.class

Once you have created the jar, remember to overwrite it in $SPARK_HOME/jars.

3.4. Call the java UDF with spark

Before calling the function we need to register it with:

spark.udf.registerJavaFunction("geo_to_h3", "com.villoro.toH3AddressUDF", T.StringType())

Then we can call it using F.expr:

(
    sdf
    .withColumn("h8", F.expr("geo_to_h3(latitude, longitude, 8)"))
    .show()
)

And we can also use it with SQL:

spark.sql("""
    SELECT
        latitude, longitude, geo_to_h3(latitude, longitude, 8)
    FROM my_table
""").show()

4. Performance results

In order to test it I created 6 datasets of different sizes. The total execution times for the python_udf and the java_udf are:

num_rows      python [s]     java [s]
10^3            2.348551     0.492574
10^4            3.140185     0.579370
10^5            3.876405     0.795821
10^6            5.552880     1.958956
10^7           24.151182    11.507726
10^8          195.032587    88.875919

As expected, the java_udf is faster than the python_udf.

If we take a closer look at the values we can see that creating the python_udf has a fixed overhead, which is why the results are so bad for the small datasets. Once the dataset is bigger that fixed cost is no longer visible, but there is still the cost of moving the data between the JVM and Python. That is why the python_udf remains around 2 times slower.
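
For reference, a minimal sketch of how such a benchmark could look (the random dataset, the noop sink and the helper name are assumptions, not the exact code used):

import time

import pyspark.sql.functions as F

def benchmark(num_rows):
    # Hypothetical dataset: one row per id with random coordinates
    sdf = (
        spark.range(num_rows)
        .withColumn("latitude", F.rand() * 180 - 90)
        .withColumn("longitude", F.rand() * 360 - 180)
    )

    for name, col in [
        ("python", geo_to_h3_py("latitude", "longitude", F.lit(8))),
        ("java", F.expr("geo_to_h3(latitude, longitude, 8)")),
    ]:
        start = time.time()
        # The "noop" sink (Spark 3+) forces full evaluation without writing anything
        sdf.withColumn("h8", col).write.format("noop").mode("overwrite").save()
        print(f"{name}: {time.time() - start:.2f} s")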