How do I create a Spark UDF that returns a complex type in Java / Kotlin?

Updated: 2022-10-14 22:47:37





I'm trying to write a UDF which returns a complex type:

private val toPrice = UDF1<String, Map<String, String>> { s ->
    val elements = s.split(" ")
    mapOf("value" to elements[0], "currency" to elements[1])
}


val type = DataTypes.createStructType(listOf(
        DataTypes.createStructField("value", DataTypes.StringType, false),
        DataTypes.createStructField("currency", DataTypes.StringType, false)))
df.sqlContext().udf().register("toPrice", toPrice, type)

but any time I use this:

df = df.withColumn("price", callUDF("toPrice", col("price")))

I get a cryptic error:

Caused by: org.apache.spark.SparkException: Failed to execute user defined function($anonfun$28: (string) => struct<value:string,currency:string>)
    at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage1.processNext(Unknown Source)
    at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
    at org.apache.spark.sql.execution.WholeStageCodegenExec$$anonfun$10$$anon$1.hasNext(WholeStageCodegenExec.scala:614)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:253)
    at org.apache.spark.sql.execution.SparkPlan$$anonfun$2.apply(SparkPlan.scala:247)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.RDD$$anonfun$mapPartitionsInternal$1$$anonfun$apply$25.apply(RDD.scala:830)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:38)
    at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
    at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:87)
    at org.apache.spark.scheduler.Task.run(Task.scala:109)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:345)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at java.lang.Thread.run(Thread.java:748)
Caused by: scala.MatchError: {value=138.0, currency=USD} (of class java.util.LinkedHashMap)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:236)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$StructConverter.toCatalystImpl(CatalystTypeConverters.scala:231)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$CatalystTypeConverter.toCatalyst(CatalystTypeConverters.scala:103)
    at org.apache.spark.sql.catalyst.CatalystTypeConverters$$anonfun$createToCatalystConverter$2.apply(CatalystTypeConverters.scala:379)
    ... 19 more

I tried to use a custom data type:

class Price(val value: Double, val currency: String) : Serializable

with a UDF which returns that type:

private val toPrice = UDF1<String, Price> { s ->
    val elements = s.split(" ")
    Price(elements[0].toDouble(), elements[1])
}

but then I get another MatchError which complains about the Price type.

How do I properly write a UDF which can return a complex type?

TL;DR The function should return an object of class org.apache.spark.sql.Row.

Spark provides two main variants of UDF definitions.

  1. udf variants using Scala reflection:

    • def udf[RT](f: () ⇒ RT)(implicit arg0: TypeTag[RT]): UserDefinedFunction
    • def udf[RT, A1](f: (A1) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1]): UserDefinedFunction
    • ...
    • def udf[RT, A1, A2, ..., A10](f: (A1, A2, ..., A10) ⇒ RT)(implicit arg0: TypeTag[RT], arg1: TypeTag[A1], arg2: TypeTag[A2], ..., arg10: TypeTag[A10]): UserDefinedFunction

    which define

    Scala closure of ... arguments as user-defined function (UDF). The data types are automatically inferred based on the Scala closure's signature.

    These variants are used without an explicit schema, with atomic or algebraic data types. For example, the function in question would be defined in Scala as:

    case class Price(value: Double, currency: String) 
    
    val df = Seq("1 USD").toDF("price")
    
    val toPrice = udf((s: String) => scala.util.Try { 
      s split(" ") match {
        case Array(price, currency) => Price(price.toDouble, currency)
      }
    }.toOption)
    
    df.select(toPrice($"price")).show
    // +----------+
    // |UDF(price)|
    // +----------+
    // |[1.0, USD]|
    // +----------+
    

    In this variant the return type is automatically encoded.

    Due to its dependence on reflection, this variant is intended primarily for Scala users.

  2. udf variants providing a schema definition (the one you use here). The return type for this variant should be the same as for Dataset[Row]:

    • As pointed out in the other answer you can use only the types listed in the SQL types mapping table (atomic types either boxed or unboxed, java.sql.Timestamp / java.sql.Date, as well as high level collections).

    • Complex structures (structs / StructTypes) are expressed using org.apache.spark.sql.Row. No mixing with algebraic data types or equivalent is allowed. For example (Scala code)

      struct<_1:int,_2:struct<_1:string,_2:struct<_1:double,_2:int>>>
      

      should be expressed as

      Row(1, Row("foo", Row(-1.0, 42)))
      

      not

      (1, ("foo", (-1.0, 42)))
      

      or any mixed variant, like

      Row(1, Row("foo", (-1.0, 42)))
      

    This variant is provided primarily to ensure Java interoperability.

    In this case (equivalent to the one in question) the definition should be similar to the following one:

    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions.udf
    import org.apache.spark.sql.Row
    
    
    val schema = StructType(Seq(
      StructField("value", DoubleType, false),
      StructField("currency", StringType, false)
    ))
    
    val toPrice = udf((s: String) => scala.util.Try { 
      s split(" ") match {
        case Array(price, currency) => Row(price.toDouble, currency)
      }
    }.getOrElse(null), schema)
    
    df.select(toPrice($"price")).show
    // +----------+
    // |UDF(price)|
    // +----------+
    // |[1.0, USD]|
    // |      null|
    // +----------+
    

    Excluding all the nuances of exception handling (in general UDFs should control for null input and by convention gracefully handle malformed data), the Java equivalent should look more or less like this:

    UserDefinedFunction price = udf((String s) -> {
        String[] split = s.split(" ");
        return RowFactory.create(Double.parseDouble(split[0]), split[1]);
    }, DataTypes.createStructType(new StructField[]{
        DataTypes.createStructField("value", DataTypes.DoubleType, true),
        DataTypes.createStructField("currency", DataTypes.StringType, true)
    }));
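
    Since the question itself uses Kotlin, a rough Kotlin sketch of the same approach (untested, and assuming the df and the UDF1-based registration style from the question) could look like this:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.RowFactory
    import org.apache.spark.sql.api.java.UDF1
    import org.apache.spark.sql.functions.callUDF
    import org.apache.spark.sql.functions.col
    import org.apache.spark.sql.types.DataTypes

    // Schema describing the struct the UDF returns.
    val priceType = DataTypes.createStructType(listOf(
        DataTypes.createStructField("value", DataTypes.DoubleType, true),
        DataTypes.createStructField("currency", DataTypes.StringType, true)
    ))

    // Return a Row (built with RowFactory.create) rather than a Map or a data class.
    val toPrice = UDF1<String, Row> { s ->
        val parts = s.split(" ")
        RowFactory.create(parts[0].toDouble(), parts[1])
    }

    df.sqlContext().udf().register("toPrice", toPrice, priceType)
    df.withColumn("price", callUDF("toPrice", col("price")))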
    

Context:

To give you some context, this distinction is reflected in other parts of the API as well. For example, you can create a DataFrame from a schema and a sequence of Rows:

def createDataFrame(rows: List[Row], schema: StructType): DataFrame 

or using reflection with a sequence of Products

def createDataFrame[A <: Product](data: Seq[A])(implicit arg0: TypeTag[A]): DataFrame 

but no mixed variants are supported.

In other words you should provide input that can be encoded using RowEncoder.
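
As an illustration, a minimal Kotlin sketch of the Row-plus-schema variant (assuming a SparkSession named spark; the field names simply mirror the price example above):

import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.types.DataTypes

// Rows paired with an explicit schema, i.e. input that RowEncoder can encode.
val priceSchema = DataTypes.createStructType(listOf(
    DataTypes.createStructField("value", DataTypes.DoubleType, false),
    DataTypes.createStructField("currency", DataTypes.StringType, false)
))
val pricesDf = spark.createDataFrame(
    listOf(RowFactory.create(1.0, "USD"), RowFactory.create(2.0, "EUR")),
    priceSchema
)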

Of course, you wouldn't normally use udf for a task like this one:

import org.apache.spark.sql.functions._

df.withColumn("price", struct(
  split($"price", " ")(0).cast("double").alias("price"),
  split($"price", " ")(1).alias("currency")
))
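
The same built-in functions are available from Kotlin or Java; a rough Kotlin equivalent of the snippet above (a sketch, using the df from the question and the value/currency field names from its schema):

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.split
import org.apache.spark.sql.functions.struct

// Build the struct column directly with column functions instead of a UDF.
df.withColumn("price", struct(
    split(col("price"), " ").getItem(0).cast("double").alias("value"),
    split(col("price"), " ").getItem(1).alias("currency")
))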
