Formatting issue with BERTopic logger.warning() in topic_model.save()

Posted by y1aodyip 23 days ago in Other

Hi Maarten,

I ran into the following error when trying to save a topic model without a pointer to an embedding model.

from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic

# Documents to train on
docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data'][0:500]
topic_model = BERTopic().fit(docs)

topic_model.save("model_dir", serialization="safetensors", save_ctfidf=True, save_embedding_model=False)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[1], line 8
      5 docs = fetch_20newsgroups(subset='all',  remove=('headers', 'footers', 'quotes'))['data'][0:500]
      6 topic_model = BERTopic().fit(docs)
----> 8 topic_model.save("model_dir", serialization="safetensors", save_ctfidf=True, save_embedding_model=False)

File c:\path\lib\site-packages\bertopic\_bertopic.py:2998, in BERTopic.save(self, path, serialization, save_embedding_model, save_ctfidf)
   2996     save_embedding_model = self.embedding_model._hf_model
   2997 elif not save_embedding_model:
-> 2998     logger.warning("You are saving a BERTopic model without explicitly defining an embedding model."
   2999                    "If you are using a sentence-transformers model or a HuggingFace model supported"
   3000                    "by sentence-transformers, please save the model by using a pointer towards that model."
   3001                    "For example, `save_embedding_model=sentence-transformers/all-mpnet-base-v2`", RuntimeWarning)
   3003 # Minimal
   3004 save_utils.save_hf(model=self, save_directory=save_directory, serialization=serialization)

TypeError: warning() takes 2 positional arguments but 3 were given
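The TypeError comes from the call signature rather than the message text: the traceback shows logger.warning(...) receiving a message plus RuntimeWarning, while the logger's warning() only accepts self and one message ("takes 2 positional arguments"). The sketch below reproduces this with a hypothetical SimpleLogger wrapper; the class is an assumption modeled on the error message, not BERTopic's actual logger code.

```python
import logging


class SimpleLogger:
    """Hypothetical wrapper whose warning() accepts exactly one message."""

    def __init__(self, name: str = "demo") -> None:
        self.logger = logging.getLogger(name)

    def warning(self, message: str) -> None:
        self.logger.warning(f"WARNING: {message}")


logger = SimpleLogger()
error_text = ""

try:
    # Mirrors the failing call: a message *and* a warning category.
    logger.warning("Some message.", RuntimeWarning)
except TypeError as exc:
    # self + message + RuntimeWarning = 3 positional arguments
    error_text = str(exc)
    print(error_text)
```

The exact prefix of the message (e.g. whether the class name is included) depends on the Python version, but the "takes 2 positional arguments but 3 were given" part matches the traceback above.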
v2g6jxz6 (answer #1)

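Thanks for sharing! This looks like an easy fix: make sure only a single message string is passed to logger.warning (the extra RuntimeWarning argument is what triggers the TypeError, since the logger's warning() accepts just one message). If you'd like, I'd be happy to accept a PR.
For now, simply setting save_embedding_model to True should prevent this, and there doesn't seem to be any downside to it.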
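A minimal sketch of the fix, assuming a hypothetical logger wrapper whose warning() takes one message: drop the extra RuntimeWarning argument and add the spaces missing between the implicitly concatenated string literals (as written in the traceback, the pieces run together, e.g. "embedding model.If you are"). The warns_on_save helper is also hypothetical; it only mirrors the `elif not save_embedding_model:` branch visible in the traceback, which shows why save_embedding_model=True (or a model pointer string) never reaches the warning.

```python
import logging


class SimpleLogger:
    """Hypothetical wrapper whose warning() accepts exactly one message."""

    def __init__(self, name: str = "demo") -> None:
        self.logger = logging.getLogger(name)

    def warning(self, message: str) -> None:
        self.logger.warning(f"WARNING: {message}")


# Adjacent string literals are joined at compile time, so each piece needs
# its own trailing space or the sentences run together.
BROKEN = ("You are saving a BERTopic model without explicitly defining an embedding model."
          "If you are using a sentence-transformers model")

FIXED = ("You are saving a BERTopic model without explicitly defining an embedding model. "
         "If you are using a sentence-transformers model or a HuggingFace model supported "
         "by sentence-transformers, please save the model by using a pointer towards that model. "
         "For example, `save_embedding_model=sentence-transformers/all-mpnet-base-v2`")


def warns_on_save(save_embedding_model) -> bool:
    """Hypothetical mirror of the `elif not save_embedding_model:` branch."""
    return not save_embedding_model


logger = SimpleLogger()
if warns_on_save(False):
    logger.warning(FIXED)  # a single positional argument: no TypeError
```

Passing one pre-joined string keeps the call compatible with a single-argument warning() while also fixing the run-together sentences.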

brjng4g3 (answer #2)

Thanks Maarten! I'm about to wrap up work for this year, but if this issue is still open in January, I'll submit a PR.
