我有一个pysparkDataframe,其模式如下所示:
root
|-- query: string (nullable = true)
|-- collect_list(docId): array (nullable = true)
| |-- element: string (containsNull = true)
|-- prod_count_dict: map (nullable = true)
| |-- key: string
| |-- value: integer (valueContainsNull = true)
数据框如下所示:
+--------------------+--------------------+--------------------+
| query| collect_list(docId)| prod_count_dict|
+--------------------+--------------------+--------------------+
|1/2 inch plywood ...|[471097-153-12CC,...|[530320-62634-100...|
| 1416445|[1416445-83-HHM5S...|[1054482-2251-FFC...
请注意,列 prod_count_dict
是一个包含键值对的字典,如:
{x: 12, a: 16, b:1, f:3, ....}
我想做的是我只想选择 keys
的 top n
最大的 values
从key:value对,并将其存储在另一列中,作为与该行对应的列表,如:[x,a,…]。
我尝试了下面的代码,但它给了我一个错误,有没有办法我可以解决这个特殊的问题?
@F.udf(StringType())
def create_label(x):
# If the length of dictionary is less then 20, I want to return the keys of all the items in the dict.
if len(x) >= 20:
val_sort = sorted(list(x.values()), reverse = True)
cutoff = {k: v for (k, v) in x.items() if v > val_sort[20]}
return cutoff.keys()
else:
return x.keys()
label_df = label_count_df.withColumn("label", create_label("prod_count_dict"))
label_df.show()
2条答案
按热度按时间huwehgph1#
首先我要把这句话爆了:
之后,可以使用window函数获取每个键的前n个值
b09cbbtk2#
你写的自定义项是正确的。您只需更改实际使用的代码。如果您使用
.map
在rdd
:您需要更改的部分是:
这应该管用。