BERTopic 在 _hierarchy.py 中的 custom_labels 的 bug

j5fpnvbx  于 1个月前  发布在  其他
关注(0)|答案(3)|浏览(30)

你好,

在使用方面作为自定义标签时,hierarchy.py似乎存在一个bug。在第151行,fig.layout[axis]["ticktext"]返回一个字符串列表,但是topics在topic_model.topic_aspects中是以整数索引的。我通过在第151行引入类型转换来进行了本地修复:

new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][int(x)] for x in fig.layout[axis]["ticktext"]]
mjqavswn

mjqavswn1#

你能展示一下这个bug的完整示例吗?没有它,我很难复现或提供支持。

jbose2ul

jbose2ul2#

当然,这是它:

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.representation import MaximalMarginalRelevance

from sklearn.datasets import fetch_20newsgroups

docs = fetch_20newsgroups(subset='test',  remove=('headers', 'footers', 'quotes'))['data']

# Additional ways of representing a topic
aspect_model1 = MaximalMarginalRelevance()
aspect_model2 = KeyBERTInspired()

# Add all models together to be run in a single `fit`
representation_model = {
   "Aspect1":  aspect_model1,
   "Aspect2":  aspect_model2 
}
topic_model = BERTopic(representation_model=representation_model,verbose=True).fit(docs)

hierarchical_topics = topic_model.hierarchical_topics(docs)

topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics,custom_labels="Aspect1")

我得到错误:

KeyError                                  Traceback (most recent call last)

[<ipython-input-3-0e4010dde1ae>](https://localhost:8080/#) in <cell line: 23>()
     21 hierarchical_topics = topic_model.hierarchical_topics(docs)
     22 
---> 23 topic_model.visualize_hierarchy(hierarchical_topics=hierarchical_topics,custom_labels="Aspect1")

2 frames

[/usr/local/lib/python3.10/dist-packages/bertopic/plotting/_hierarchy.py](https://localhost:8080/#) in <listcomp>(.0)
    149     axis = "yaxis" if orientation == "left" else "xaxis"
    150     if isinstance(custom_labels, str):
--> 151         new_labels = [[[str(x), None]] + topic_model.topic_aspects_[custom_labels][x] for x in fig.layout[axis]["ticktext"]]
    152         new_labels = ["_".join([label[0] for label in labels[:4]]) for labels in new_labels]
    153         new_labels = [label if len(label) < 30 else label[:27] + "..." for label in new_labels]

KeyError: '98'
ifmq2ha2

ifmq2ha23#

是的,可以完全重现这个问题,谢谢!如果你愿意,非常感谢你提交PR #1504 ,如果你有时间的话😄否则,我自己创建PR没有问题。无论如何,感谢你分享这个问题并提出解决方案!

相关问题