且构网

Randomly sample a PySpark DataFrame with a column condition

Updated: 2023-11-18 23:09:28

I'm trying to randomly sample a PySpark DataFrame where a column value meets a certain condition. I would like to use the sample method to randomly select rows based on a column value. Let's say I have the following DataFrame:

+---+----+------+-------------+------+
| id|code|   amt|flag_outliers|result|
+---+----+------+-------------+------+
|  1|   a|  10.9|            0|   0.0|
|  2|   b|  20.7|            0|   0.0|
|  3|   c|  30.4|            0|   1.0|
|  4|   d| 40.98|            0|   1.0|
|  5|   e| 50.21|            0|   2.0|
|  6|   f|  60.7|            0|   2.0|
|  7|   g|  70.8|            0|   2.0|
|  8|   h| 80.43|            0|   3.0|
|  9|   i| 90.12|            0|   3.0|
| 10|   j|100.65|            0|   3.0|
+---+----+------+-------------+------+

I would like to sample only 1 row (or any fixed number of rows) for each of the values 0, 1, 2, 3 in the result column, so I'd end up with this:

+---+----+------+-------------+------+
| id|code|   amt|flag_outliers|result|
+---+----+------+-------------+------+
|  1|   a|  10.9|            0|   0.0|
|  3|   c|  30.4|            0|   1.0|
|  5|   e| 50.21|            0|   2.0|
|  8|   h| 80.43|            0|   3.0|
+---+----+------+-------------+------+

Is there a good programmatic way to achieve this, i.e., to take the same number of rows for each of the values in a given column? Any help is really appreciated!

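You can use sampleBy(), which returns a stratified sample without replacement based on the fraction given for each stratum. Any stratum not listed in fractions is treated as having fraction 0, and because each fraction is a per-row sampling probability, the resulting counts per stratum are approximate rather than exact.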

>>> from pyspark.sql.functions import col
>>> dataset = sqlContext.range(0, 100).select((col("id") % 3).alias("result"))
>>> sampled = dataset.sampleBy("result", fractions={0: 0.1, 1: 0.2}, seed=0)
>>> sampled.groupBy("result").count().orderBy("result").show()

+------+-----+
|result|count|
+------+-----+
|     0|    5|
|     1|    9|
+------+-----+
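
Since sampleBy() works with fractions, it does not by itself guarantee the same number of rows per value. The following is a minimal sketch of two possible ways to get closer to what the question asks, not part of the original answer. The helper names fractions_for_target and sample_exactly_n_per_group, the temporary column _rn, and the variable df are all assumptions made for illustration. The first helper turns per-group row counts into fractions for sampleBy() (about n rows per group in expectation); the second swaps in a different technique, a window with row_number() over a random ordering, to keep exactly n rows per group.

```python
def fractions_for_target(counts, n):
    """Map each stratum to the fraction that yields about n rows in expectation.

    counts: dict of {stratum value: number of rows in that stratum}
    n:      desired number of rows per stratum
    Fractions are capped at 1.0 because sampleBy() samples without replacement.
    """
    return {key: min(1.0, n / count) for key, count in counts.items() if count > 0}


def sample_exactly_n_per_group(df, column, n, seed=0):
    """Return n random rows per distinct value of `column`.

    Groups with fewer than n rows are returned in full.
    """
    # Imported here so fractions_for_target stays usable without a Spark install.
    from pyspark.sql import Window
    from pyspark.sql import functions as F

    # Number the rows of each group in random order, then keep the first n.
    w = Window.partitionBy(column).orderBy(F.rand(seed))
    return (
        df.withColumn("_rn", F.row_number().over(w))  # "_rn" is a hypothetical temp column
        .filter(F.col("_rn") <= n)
        .drop("_rn")
    )


# Hypothetical usage, assuming `df` is the DataFrame from the question:
#
#   counts = {r["result"]: r["count"] for r in df.groupBy("result").count().collect()}
#   approx = df.sampleBy("result", fractions=fractions_for_target(counts, 1), seed=0)
#   exact = sample_exactly_n_per_group(df, "result", 1)
```

The fraction-based variant stays a single pass of Bernoulli sampling but can return zero or several rows for a group; the window variant gives exact counts at the cost of shuffling and sorting every group, which may matter on large data.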