更新时间:2022-02-07 09:20:19
以下方法可行：
from pyspark.sql import functions as F, Row, SparkSession, SQLContext, Window
from pyspark.sql.types import BooleanType
# Local Spark session. Setting spark.sql.autoBroadcastJoinThreshold to -1
# disables automatic broadcast joins.
_session_builder = SparkSession.builder.master("local").appName("Octopus")
_session_builder = _session_builder.config('spark.sql.autoBroadcastJoinThreshold', -1)
spark = _session_builder.getOrCreate()

# Sample intervals; every row shares idx=0. Column order is fixed by
# _SAMPLE_COLUMNS and passed to Row as keyword arguments in that order.
_SAMPLE_COLUMNS = ('idx', 'interval_start', 'interval_end', 'rate', 'updated_at')
_SAMPLE_VALUES = [
    # OVERLAP: (1,4) and (2,3) and (3,5) and rate=10/20
    (0, '2018-01-01 00:00:00', '2018-01-04 00:00:00', 10, '2021-02-25 00:00:00'),
    # OVERLAP: full overlap for (2,3) with (1,4)
    (0, '2018-01-02 00:00:00', '2018-01-03 00:00:00', 10, '2021-02-25 00:00:00'),
    # OVERLAP: (3,5) and (1,4) and rate=10/20
    (0, '2018-01-03 00:00:00', '2018-01-05 00:00:00', 20, '2021-02-20 00:00:00'),
    # NO OVERLAP: hole between (5,6)
    (0, '2018-01-06 00:00:00', '2018-01-07 00:00:00', 30, '2021-02-25 00:00:00'),
    # NO OVERLAP
    (0, '2018-01-07 00:00:00', '2018-01-08 00:00:00', 30, '2021-02-25 00:00:00'),
]
input_rows = [Row(**dict(zip(_SAMPLE_COLUMNS, values))) for values in _SAMPLE_VALUES]

df = spark.createDataFrame(input_rows)
df.show()

sc = spark.sparkContext
# Legacy SQLContext handle wrapping the existing session (SQLContext is
# deprecated since Spark 3.0 in favour of SparkSession).
sql_context = SQLContext(sc, spark)
def overlap(start_first, end_first, start_second, end_second):
    """Return whether two intervals overlap; intended as a Spark SQL UDF.

    The intervals overlap when an endpoint of one lies strictly inside the
    other. Consequently:
      * intervals that merely touch (one ends where the other starts) do NOT
        overlap;
      * two intervals with identical start and end do NOT overlap, so a row
        compared with itself in a self-join is never reported.

    Args:
        start_first, end_first: bounds of the first interval.
        start_second, end_second: bounds of the second interval.
        Bounds only need to be mutually comparable (numbers, or the
        'YYYY-MM-DD HH:MM:SS' strings used in this script). Assumes
        start < end for each interval.

    Returns:
        True or False. Returns None if any bound is None.
    """
    # Spark hands SQL NULLs to Python UDFs as None. Comparing None with a
    # value raises TypeError in Python 3 and would fail the whole task, so
    # propagate NULL instead; a NULL in a WHERE clause simply drops the row.
    if any(bound is None for bound in (start_first, end_first, start_second, end_second)):
        return None
    return ((start_first < start_second < end_first)
            or (start_first < end_second < end_first)
            or (start_second < start_first < end_second)
            or (start_second < end_first < end_second))
# Resolve overlapping intervals per idx:
#   1. collect every interval boundary of the same idx,
#   2. cut the timeline into elementary segments between consecutive
#      boundaries,
#   3. attach every source interval that covers each segment,
#   4. keep the rate of the most recently updated covering interval.
# Non-overlapping intervals come out whole, because no other boundary falls
# strictly inside them. Segments lying in a gap between intervals match no
# source row and disappear in the inner join. No Python UDF and no pairwise
# self-join are needed.
#
# NOTE(review): bounds are compared as-is. In this script they are
# 'YYYY-MM-DD HH:MM:SS' strings, whose lexicographic order equals
# chronological order. Cast to timestamp first if other formats can occur.
df = df.cache()  # read twice below (boundaries and range join)

# 1. Distinct boundary points per idx.
boundaries = (df.select('idx', F.col('interval_start').alias('p'))
              .union(df.select('idx', F.col('interval_end').alias('p')))
              .distinct())

# 2. Consecutive boundaries form segments [p_i, p_{i+1}). The explicit
#    ORDER BY makes `lead` well defined; a window without ordering gives no
#    guarantee about which row is "next".
boundary_window = Window.partitionBy('idx').orderBy('p')
segments = (boundaries
            .withColumn('interval_end', F.lead('p').over(boundary_window))
            .dropna(subset=['p', 'interval_end'])  # last boundary has no successor
            .withColumnRenamed('p', 'interval_start'))

# 3. Point-in-interval range join on the segment start, half-open
#    [start, end), so an interval ending where the segment starts does not
#    match. Works for any sub-second precision and avoids implicit
#    string->timestamp casts.
#    https://docs.databricks.com/delta/join-performance/range-join.html
df.createOrReplaceTempView('intervals')
segments.createOrReplaceTempView('segments')
df4 = spark.sql("""
    SELECT s.idx, s.interval_start, s.interval_end, i.rate, i.updated_at
    FROM segments s
    JOIN intervals i
      ON s.idx = i.idx
     AND s.interval_start >= i.interval_start
     AND s.interval_start < i.interval_end
""")
df4.sort('idx', 'interval_start').show()

# 4. One row per segment holding the rate of the latest update. The struct
#    compares updated_at first; ties are broken by the larger rate.
resolved = (df4.groupBy('idx', 'interval_start', 'interval_end')
            .agg(F.max(F.struct('updated_at', 'rate'))['rate'].alias('rate'))
            .orderBy('idx', 'interval_start'))
resolved.show()
+-------------------+-------------------+----+
| interval_start| interval_end|rate|
+-------------------+-------------------+----+
|2018-01-01 00:00:00|2018-01-02 00:00:00| 10|
|2018-01-02 00:00:00|2018-01-03 00:00:00| 10|
|2018-01-03 00:00:00|2018-01-04 00:00:00| 10|
|2018-01-04 00:00:00|2018-01-05 00:00:00| 20|
|2018-01-06 00:00:00|2018-01-07 00:00:00| 30|
|2018-01-07 00:00:00|2018-01-08 00:00:00| 30|
+-------------------+-------------------+----+