BERTopic bug with custom_labels in _topics_over_time.py

kqhtkvqz  posted 2 months ago  in Other

Hello,
When using custom_labels, there appears to be a bug in _topics_over_time.py. Line 67 iterates over topics, but I think it should iterate over selected_topics. My local fix for that line is:

topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in selected_topics]

Line 70 also has a problem: it iterates over all topics, while topic_names may not contain all of them. I changed it to:

topic_names = {key: topic_names[index] for index, key in enumerate(selected_topics)}
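The two changes can be tried together without a fitted model. The following is a minimal sketch, assuming plain dictionaries as stand-ins for `topic_model.topic_aspects_`; the data and the helper name `build_topic_names` are hypothetical and not part of BERTopic.

```python
# Hypothetical stand-in for topic_model.topic_aspects_ (not real BERTopic output).
topic_aspects = {
    "Aspect1": {
        -1: [("the", 0.1), ("and", 0.1)],
        0: [("election", 0.5), ("vote", 0.4)],
        1: [("china", 0.6), ("trade", 0.3)],
        2: [("border", 0.7), ("wall", 0.2)],
    }
}


def build_topic_names(selected_topics, custom_labels="Aspect1"):
    """Build {topic id: label} only for the topics that will be plotted."""
    # Proposed line 67: iterate over selected_topics instead of topics.
    names = [
        [[str(topic), None]] + topic_aspects[custom_labels][topic]
        for topic in selected_topics
    ]
    # Same label formatting as lines 68-69 in the traceback.
    names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
    names = [label if len(label) < 30 else label[:27] + "..." for label in names]
    # Proposed line 70: key the labels by the same selected_topics.
    return {key: names[index] for index, key in enumerate(selected_topics)}


print(build_topic_names([1, 2]))
```

Because the list and the dictionary are built from the same `selected_topics`, each label stays attached to its own topic id.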

wz3gfoph1#

Could you share an example of this issue? That would make it easier for me to see what goes wrong and whether it affects other parts of the codebase.


b09cbbtk2#

Sure, this is what it looks like:

from bertopic import BERTopic
from bertopic.representation import KeyBERTInspired
from bertopic.representation import MaximalMarginalRelevance
import pandas as pd
import re

# Additional ways of representing a topic
aspect_model1 = MaximalMarginalRelevance()
aspect_model2 = KeyBERTInspired()

# Add all models together to be run in a single `fit`
representation_model = {
   "Aspect1":  aspect_model1,
   "Aspect2":  aspect_model2 
}
trump = pd.read_csv('https://drive.google.com/uc?export=download&id=1xRKHaP-QwACMydlDnyFPEaFdtskJuBa6')
trump.text = trump.apply(lambda row: re.sub(r"http\S+", "", row.text).lower(), 1)
trump.text = trump.apply(lambda row: " ".join(filter(lambda x:x[0]!="@", row.text.split())), 1)
trump.text = trump.apply(lambda row: " ".join(re.sub("[^a-zA-Z]+", " ", row.text).split()), 1)
trump = trump.loc[(trump.isRetweet == "f") & (trump.text != ""), :]
timestamps = trump.date.to_list()
tweets = trump.text.to_list()

# Create topics over time
model = BERTopic(verbose=True,representation_model=representation_model)
topics, probs = model.fit_transform(tweets)
topics_over_time = model.topics_over_time(tweets, timestamps)
model.visualize_topics_over_time(topics_over_time,custom_labels='Aspect1')

Then I get this error:

TypeError                                 Traceback (most recent call last)

<ipython-input-5-d9d3aafd27f3> in <cell line: 29>()
     27 topics, probs = model.fit_transform(tweets)
     28 topics_over_time = model.topics_over_time(tweets, timestamps)
---> 29 model.visualize_topics_over_time(topics_over_time,custom_labels='Aspect1')

1 frames

/usr/local/lib/python3.10/dist-packages/bertopic/plotting/_topics_over_time.py in visualize_topics_over_time(topic_model, topics_over_time, top_n_topics, topics, normalize_frequency, custom_labels, title, width, height)
     65     # Prepare data
     66     if isinstance(custom_labels, str):
---> 67         topic_names = [[[str(topic), None]] + topic_model.topic_aspects_[custom_labels][topic] for topic in topics]
     68         topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names]
     69         topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names]

TypeError: 'NoneType' object is not iterable

This happens because topics was not passed and is therefore None, so I think the loop should use selected_topics.
Also, if I pass topics:

model.visualize_topics_over_time(topics_over_time,custom_labels="Aspect1",topics=[9, 10, 72, 83, 87, 91])

I get this error:

---------------------------------------------------------------------------

IndexError                                Traceback (most recent call last)

<ipython-input-10-49fa7a976f72> in <cell line: 1>()
----> 1 model.visualize_topics_over_time(topics_over_time,custom_labels="Aspect1",topics=[1,2])

2 frames

/usr/local/lib/python3.10/dist-packages/bertopic/plotting/_topics_over_time.py in <dictcomp>(.0)
     68         topic_names = ["_".join([label[0] for label in labels[:4]]) for labels in topic_names]
     69         topic_names = [label if len(label) < 30 else label[:27] + "..." for label in topic_names]
---> 70         topic_names = {key: topic_names[index] for index, key in enumerate(topic_model.topic_labels_.keys())}
     71     elif topic_model.custom_labels_ is not None and custom_labels:
     72         topic_names = {key: topic_model.custom_labels_[key + topic_model._outliers] for key, _ in topic_model.topic_labels_.items()}

IndexError: list index out of range
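Both failures can be reproduced without training a model. This is a simplified replica of lines 67 to 70 as they appear in the tracebacks above, assuming small hypothetical dictionaries in place of `topic_model.topic_aspects_` and `topic_model.topic_labels_`; the function names are illustrative only.

```python
# Hypothetical stand-ins for the fitted model's attributes.
topic_aspects = {"Aspect1": {0: [("a", 0.5)], 1: [("b", 0.4)], 2: [("c", 0.3)]}}
topic_labels = {0: "0_a", 1: "1_b", 2: "2_c"}


def original_logic(topics, custom_labels="Aspect1"):
    """Simplified replica of lines 67-70 from the traceback."""
    # Line 67: iterates over `topics`, which is None when the caller omits it.
    names = [[[str(t), None]] + topic_aspects[custom_labels][t] for t in topics]
    names = ["_".join([label[0] for label in labels[:4]]) for labels in names]
    # Line 70: indexes `names` once per known topic, even if fewer were built.
    return {key: names[index] for index, key in enumerate(topic_labels.keys())}


def error_name(topics):
    """Return the name of the exception raised, or None if the call succeeds."""
    try:
        original_logic(topics)
    except (TypeError, IndexError) as exc:
        return type(exc).__name__
    return None


print(error_name(None))       # topics omitted
print(error_name([1, 2]))     # subset of topics
print(error_name([0, 1, 2]))  # every topic, in label order
```

In this replica the call only succeeds when `topics` lists every topic in the same order as the label dictionary. A subset fails on the index lookup, and a full list given in a different order would attach labels to the wrong topic ids without raising any error.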

afdcj2ne3#

Thanks for the code! If you want, a PR would be much appreciated 😄
