Updated: 2023-02-15 09:20:34
With a Python UDF:
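from pyspark.sql.functions import udf, size
from pyspark.sql.types import ArrayType, IntegerType

# Build an intersection UDF for arrays of the given element type.
# The UDF returns None when either input array is None.
intersect = lambda element_type: udf(
    lambda x, y: (
        list(set(x) & set(y)) if x is not None and y is not None else None),
    ArrayType(element_type))

# sc is assumed to be an existing SparkContext (for example, the one in the PySpark shell).
df = sc.parallelize([([1, 2, 3], [1, 2]), ([3, 4], [5, 6])]).toDF(["xs", "ys"])

integer_intersect = intersect(IntegerType())

df.select(
    integer_intersect("xs", "ys"),
    size(integer_intersect("xs", "ys"))).show()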
+----------------+----------------------+
|<lambda>(xs, ys)|size(<lambda>(xs, ys))|
+----------------+----------------------+
|          [1, 2]|                     2|
|              []|                     0|
+----------------+----------------------+
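The per-row logic that the UDF wraps can be checked without a Spark session. The following is a minimal plain-Python sketch; `intersect_lists` is a hypothetical name for the inner lambda, and the `(None, [1])` row is an added illustration of the null case, not part of the original data.

```python
def intersect_lists(x, y):
    """Plain-Python equivalent of the lambda wrapped by the UDF (hypothetical helper)."""
    # Propagate None, as the UDF does for null array columns.
    if x is None or y is None:
        return None
    # Set intersection drops duplicates and does not preserve element order.
    return list(set(x) & set(y))


# The first two rows match the DataFrame above; the third is an extra null case.
rows = [([1, 2, 3], [1, 2]), ([3, 4], [5, 6]), (None, [1])]
results = [intersect_lists(xs, ys) for xs, ys in rows]

for (xs, ys), result in zip(rows, results):
    # sorted() is used only to make the printed order stable.
    print(xs, ys, None if result is None else sorted(result))
```

Because the intersection goes through `set`, duplicate elements are removed and the order of the returned list should not be relied on.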
With literals:
from pyspark.sql.functions import array, lit
df.select(integer_intersect("xs", array(lit(1), lit(5)))).show()
+-------------------------+
|<lambda>(xs, array(1, 5))|
+-------------------------+
|                      [1]|
|                       []|
+-------------------------+
or
df.where(size(integer_intersect("xs", array(lit(1), lit(5)))) > 0).show()
+---------+------+
|       xs|    ys|
+---------+------+
|[1, 2, 3]|[1, 2]|
+---------+------+
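The filter above keeps a row only when `xs` shares at least one element with the literal array. As a rough plain-Python illustration (no Spark involved; `has_overlap` is a hypothetical helper, and treating a `None` array as "no overlap" is an assumption made for this sketch):

```python
def has_overlap(xs, literals):
    """Return True when xs shares at least one element with literals (hypothetical helper)."""
    # Assumption for this sketch: a missing array never matches.
    if xs is None:
        return False
    return len(set(xs) & set(literals)) > 0


rows = [([1, 2, 3], [1, 2]), ([3, 4], [5, 6])]

# Keep only the rows whose xs intersects the literal values 1 and 5.
kept = [row for row in rows if has_overlap(row[0], [1, 5])]
print(kept)
```

Only the first row survives, matching the single row shown in the `where` output.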