更新时间:2023-11-18 22:30:28
类似下面这样的代码应该就能满足需求。
Something like this should do it.
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.SparkConf
import org.apache.spark.rdd.RDD
// Spark configuration for a local, single-threaded scratch run.
val conf: SparkConf = new SparkConf()
  .setMaster("local")
  .setAppName("spark-scratch")
// Entry point used below to create RDDs.
val sco: SparkContext = new SparkContext(conf)
// How many nearest neighbours to keep for each point.
val k: Int = 3
// Five random points in two dimensions. Each coordinate comes from
// Math.random, so it lies in [0.0, 1.0).
val rows: List[List[Double]] = List.fill(5) {
  List.fill(2)(Math.random)
}
// Distribute the points as an RDD with exactly one partition.
val dataRDD: RDD[List[Double]] = sco.parallelize(rows, numSlices = 1)
// Squared Euclidean distance between two points given as coordinate lists.
// The square root is deliberately omitted. sqrt is monotonic, so ranking
// neighbours by squared distance gives the same order at lower cost.
// Note: `zip` stops at the shorter list, so points of different
// dimensionality are compared only on their shared leading coordinates.
// `.sum` is called explicitly: postfix `xs sum` needs
// scala.language.postfixOps and is rejected by Scala 2.13+/3 without it.
def euclidean(a: List[Double], b: List[Double]): Double =
  a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum
// Every ordered pair of points: the Cartesian product of the RDD with itself.
// NOTE(review): this includes each point paired with itself, so a point's own
// entry (distance 0) is among the candidates for its neighbour list. Confirm
// whether that is intended.
val pairs: RDD[(List[Double], List[Double])] =
  dataRDD.cartesian(dataRDD)
// One candidate neighbour of some point. It holds the neighbour's coordinates
// and its distance from that point. In this script the distance is filled in
// from `euclidean`, i.e. it is the squared distance. Marked `final` because
// case classes should not be extended.
final case class Entry(neighbor: List[Double], dist: Double)
// Key every pair by its first point. The value records the second point
// together with its distance from the first.
val pairsWithDist = pairs.map { pair =>
  val point = pair._1
  val other = pair._2
  point -> Entry(other, euclidean(point, other))
}
// Fold one more candidate into a point's running neighbour list (used within a
// partition). It prepends the candidate, orders by ascending distance and keeps
// the closest k.
def mergeOne(u: List[Entry], v: Entry): List[Entry] = {
  val candidates = v :: u
  candidates.sortBy(entry => entry.dist).take(k)
}
// Combine two partial neighbour lists for the same point (produced in
// different partitions). Only the k entries with the smallest distance are kept.
def mergeList(u: List[Entry], v: List[Entry]): List[Entry] = {
  val combined = u ++ v
  combined.sortBy(entry => entry.dist).take(k)
}
// Reduce each point's candidate entries to its k nearest. Start from an empty
// list, apply mergeOne per entry within a partition, and use mergeList to
// combine partial results across partitions.
// NOTE(review): points are keyed by coordinate value, so duplicate points
// would share one result list. Confirm that is acceptable.
val nearestNeighbors = pairsWithDist
  .aggregateByKey(List.empty[Entry])(mergeOne, mergeList)