BERTopic: reducing outliers with the embeddings strategy throws an error when setting custom labels


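I've been following your tutorial on using Llama to get better topic names.
The difference between us is that I'm using Alibaba's Qwen-7B model, whereas I believe your tutorial used the 7B or 13B Llama model. I set the labels after reducing outliers with the `embeddings` strategy.
The problem: if I reduce outliers with `embeddings`, the -1 topic disappears, and I get the error `Make sure that topic_labels contains the same number of labels as that there are topics.` If I reduce outliers with the `c-tf-idf` or `distributions` strategy instead, there is no problem.
Do you have any suggestions?
Here is the code: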

```python
# Imports needed by the snippet below; `docs` is assumed to be a list[str] loaded earlier
import transformers
from bertopic import BERTopic
from bertopic.representation import (
    KeyBERTInspired,
    MaximalMarginalRelevance,
    TextGeneration,
)
from hdbscan import HDBSCAN
from sentence_transformers import SentenceTransformer
from umap import UMAP

## Embedding model
embedding_model = SentenceTransformer(
    "BAAI/bge-large-en"
)

embeddings = embedding_model.encode(
    docs, normalize_embeddings=True, device="cuda:0", show_progress_bar=True
)

## Representation Model
# MMR
mmr = MaximalMarginalRelevance(diversity=0.7)

# KeyBERT inspired
kbi = KeyBERTInspired()

# Generative model
model_id = "Qwen/Qwen-7B"
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
generation_config = transformers.GenerationConfig.from_pretrained(model_id)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map="cuda:2",
)
model.eval()
model.tie_weights()

generator = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    temperature=0.01,
    max_new_tokens=50,
    repetition_penalty=1.15,
    top_p=0.95,
    generation_config=generation_config,
)

# Prompt
system_prompt = """
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant for labeling topics.
<</SYS>>
"""

# Example prompt demonstrating the output we are looking for
example_prompt = """
I have a topic that contains the following documents:
- Traditional diets in most cultures were primarily plant-based with a little meat on top, but with the rise of industrial style meat production and factory farming, meat has become a staple food.
- Meat, but especially beef, is the worst food in terms of emissions.
- Eating meat doesn't make you a bad person, not eating meat doesn't make you a good one.

The topic is described by the following keywords: 'meat, beef, eat, eating, emissions, steak, food, health, processed, chicken'.

Based on the information about the topic above, please create a short label of this topic. Make sure to only return the label and nothing more.
[/INST] Environmental impacts of eating meat
"""

main_prompt = """
[INST]
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: '[KEYWORDS]'.
Based on the information about the topic above, please create a short label of this topic. Make sure to only return the label and nothing more.
[/INST]
"""
prompt = system_prompt + example_prompt + main_prompt

# Text generation with Qwen
qwen = TextGeneration(generator, prompt=prompt)

# All representation models
representation_model = {
    "KeyBERT": kbi,
    "Qwen": qwen,
    "MMR": mmr,
}

# UMAP
umap_model = UMAP(
    n_neighbors=15, n_components=5, min_dist=0.0, metric="cosine", random_state=50
)

# HDBSCAN
hdbscan_model = HDBSCAN(
    core_dist_n_jobs=-1,
    min_cluster_size=20,
    metric="euclidean",
    cluster_selection_method="leaf",
    prediction_data=True,
)

## Topic Model
topic_model = BERTopic(
    # Sub-models
    embedding_model=embedding_model,
    umap_model=umap_model,
    hdbscan_model=hdbscan_model,
    representation_model=representation_model,
    # Hyperparameters
    top_n_words=10,
    verbose=True,
    nr_topics="auto",
)

# Train model
topics, probs = topic_model.fit_transform(docs, embeddings)

# Reduce outliers
new_topics = topic_model.reduce_outliers(
    docs, topics, probabilities=probs, strategy="embeddings"
)

topic_model.update_topics(docs, topics=new_topics)

# Set LLM labels
qwen_labels = [
    label[0][0].split("\n")[0].strip()
    for label in topic_model.get_topics(full=True)["Qwen"].values()
]

# With strategy="embeddings", this call raises the "same number of labels" error
topic_model.set_topic_labels(qwen_labels)
```
Reply #1

This might just be an issue with the -1 class caused by the outlier reduction. I would suggest the following:

```python
# Reduce outliers
new_topics = topic_model.reduce_outliers(
    docs, topics, probabilities=probs, strategy="embeddings"
)
topic_model.update_topics(docs, topics=new_topics)

# Update the attribute that checks whether there are still outliers
topic_model._outliers = 0

# Set LLM labels
qwen_labels = [
    label[0][0].split("\n")[0].strip()
    for label in topic_model.get_topics(full=True)["Qwen"].values()
]

topic_model.set_topic_labels(qwen_labels)
```

I believe this is a known issue and there is a PR available for it, but I need to look into it more closely.
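Why the -1 topic vanishes with `strategy="embeddings"` can be illustrated with a simplified NumPy sketch of the idea. This is not BERTopic's actual implementation, and `reassign_outliers` with its `threshold` argument are hypothetical names: each outlier document is compared by cosine similarity with every topic embedding and moved to the closest topic. With a threshold of 0, almost every outlier is reassigned, so no document is left in -1.

```python
import numpy as np


def reassign_outliers(doc_emb, topic_emb, topics, threshold=0.0):
    """Simplified sketch (hypothetical helper): give each -1 document its most similar topic.

    doc_emb:   (n_docs, dim) document embeddings
    topic_emb: (n_topics, dim) topic embeddings, row i belongs to topic i
    topics:    current topic ID per document, -1 marks an outlier
    """
    # Normalise rows so that a dot product equals cosine similarity
    d = doc_emb / np.linalg.norm(doc_emb, axis=1, keepdims=True)
    t = topic_emb / np.linalg.norm(topic_emb, axis=1, keepdims=True)
    sims = d @ t.T  # shape (n_docs, n_topics)

    new_topics = []
    for i, topic in enumerate(topics):
        if topic != -1:
            new_topics.append(topic)  # non-outliers keep their topic
            continue
        best = int(np.argmax(sims[i]))
        # Reassign only if the best similarity reaches the threshold
        new_topics.append(best if sims[i, best] >= threshold else -1)
    return new_topics


# Toy data: two topics, two clustered documents, two outliers
doc_emb = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.2, 0.8]])
topic_emb = np.array([[1.0, 0.0], [0.0, 1.0]])
topics = [0, 1, -1, -1]

print(reassign_outliers(doc_emb, topic_emb, topics))  # [0, 1, 0, 1]
```

With a strict threshold such as `0.999`, both outliers in this toy data stay at -1, which shows how a threshold controls whether the outlier topic survives.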

Reply #2

Thanks, but it still throws the same error.

Reply #3

Could you check whether `qwen_labels` actually contains fewer labels than `topic_model.topic_labels_`?
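Comparing the topic IDs behind the two label collections, rather than just their lengths, shows exactly which entry is extra or missing. `compare_topic_ids` below is a hypothetical helper, not part of BERTopic; it assumes you pass it the keys of `topic_model.get_topics(full=True)["Qwen"]` and the keys of `topic_model.topic_labels_`, both of which are keyed by topic ID. The IDs in the example are made up to illustrate an aspect that still carries a -1 entry the model no longer has.

```python
from typing import Dict, Iterable, List


def compare_topic_ids(
    label_topic_ids: Iterable[int], model_topic_ids: Iterable[int]
) -> Dict[str, List[int]]:
    """Report topic IDs that appear in only one of the two collections.

    label_topic_ids: IDs the generated labels were built from (e.g. keys of the "Qwen" aspect)
    model_topic_ids: IDs the model currently knows about (e.g. keys of topic_labels_)
    """
    labels = set(label_topic_ids)
    model = set(model_topic_ids)
    return {
        "extra_in_labels": sorted(labels - model),   # labels with no matching topic
        "missing_in_labels": sorted(model - labels),  # topics with no generated label
    }


# Hypothetical IDs: the label source still has -1, the model does not
report = compare_topic_ids([-1, 0, 1, 2], [0, 1, 2])
print(report)  # {'extra_in_labels': [-1], 'missing_in_labels': []}
```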

Reply #4

It's the other way around: `qwen_labels` has one more label than `topic_model.topic_labels_`.

Reply #5

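In that case, I'd suggest checking whether the order of `qwen_labels` matches `topic_model.topic_labels_` or `topic_model.custom_labels_`. I expect the input to `qwen_labels` contains one redundant label that should be removed. I think it might be the outlier class, which can be dropped, but you'll have to verify that.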

Reply #6

Hi,
I'm running into the same problem.
How do I remove the outlier class from `qwen_labels`?

Reply #7

@Keamww2021 Simply remove the first (outlier) label from the list and I believe it should work. Do note, though, that it is difficult to say without seeing your exact code, versions, environment, etc.
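Instead of relying on the outlier label being first in the list, you can key each label by its topic ID and drop -1 explicitly. `labels_by_topic` and `first_line` are hypothetical helpers; the sketch assumes the aspect has the shape `{topic_id: [(generated_label, weight), ...]}`, which is what the original code's `label[0][0]` indexing implies for `get_topics(full=True)["Qwen"]`.

```python
from typing import Dict, Mapping, Sequence, Tuple

OUTLIER_TOPIC = -1  # BERTopic's ID for the outlier topic


def first_line(text: str) -> str:
    """Keep only the first line of an LLM completion, trimmed of whitespace."""
    return text.split("\n")[0].strip()


def labels_by_topic(
    aspect: Mapping[int, Sequence[Tuple[str, float]]],
    drop_outlier: bool = True,
) -> Dict[int, str]:
    """Map topic ID -> cleaned label, optionally skipping the -1 outlier topic."""
    return {
        topic: first_line(words[0][0])
        for topic, words in sorted(aspect.items())
        if not (drop_outlier and topic == OUTLIER_TOPIC)
    }


# Made-up aspect data standing in for get_topics(full=True)["Qwen"]
fake_aspect = {
    -1: [("Outlier label\nextra text", 1)],
    0: [("Environmental impacts of eating meat\n", 1)],
    1: [("  Renewable energy policy", 1)],
}
labels = labels_by_topic(fake_aspect)
print(labels)  # {0: 'Environmental impacts of eating meat', 1: 'Renewable energy policy'}
```

Recent BERTopic versions should accept a `{topic_id: label}` dict in `topic_model.set_topic_labels(...)`, which sidesteps the length check; if your installed version only takes a list, `list(labels.values())` is the original label list with the outlier entry removed. Check the documentation for your version before relying on either form.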
