CTranslate2 Whisper batched generation is not faster than a loop

Posted by brjng4g3 6 months ago in Other

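With CTranslate2's Whisper model, batched generation is not faster than looping over the samples one at a time. I tried the same thing with a Translator model, and there batching was far superior (much faster). I used the ct2 converter to convert Whisper small to int8. GPU memory usage is also higher when batching, so I believe CTranslate2 is doing "proper" batching (rather than wrapping a loop). Here is my simple Whisper code.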

import time

import numpy as np
from ctranslate2 import StorageView
from ctranslate2.models import Whisper
from transformers import WhisperProcessor

def make_prompts(tokenizer, n: int) -> list[list[int]]:
    prompt = tokenizer.convert_tokens_to_ids(
        [
            "<|startoftranscript|>",
            "<|en|>",
            "<|transcribe|>",
        ]
    )
    return [prompt] * n

def loop(
    whisper: Whisper,
    features: list[StorageView],
    prompts: list[list[int]],
):
    for feat, prompt in zip(features, prompts):
        _ = whisper.generate(
            feat,
            [prompt],
            return_scores=True,
            return_no_speech_prob=True,
        )

def batch(
    whisper: Whisper,
    features: StorageView,
    prompts: list[list[int]],
):
    _ = whisper.generate(
        features,
        prompts,
        return_scores=True,
        return_no_speech_prob=True,
    )

def main():
    N_SAMPLES = 8
    N_SEC = 5
    SR = 16000

    # load model and processor
    whisper = Whisper("models/whisper-small", device="cuda")
    processor = WhisperProcessor.from_pretrained("openai/whisper-small")
    tokenizer = processor.tokenizer

    # generate required data
    chunks = np.random.random((N_SAMPLES, N_SEC * SR)).astype(np.float32)
    inputs = processor(chunks, return_tensors="np", sampling_rate=SR)
    mels = inputs["input_features"]
    features_loop = [StorageView.from_array(m[None, :]) for m in mels]
    features_batch = StorageView.from_array(mels)
    prompts = make_prompts(tokenizer, N_SAMPLES)

    # warm up
    print("warming up... ", end="", flush=True)
    for _ in range(7):
        loop(whisper, features_loop, prompts)
        batch(whisper, features_batch, prompts)
    print("done")

    N = 20
    print(f"benchmarking each method for {N} iterations")

    # loop time
    t0 = time.perf_counter()
    for _ in range(N):
        loop(whisper, features_loop, prompts)
    elapsed = time.perf_counter() - t0
    print(f"loop time: {elapsed:0.3f} secs")

    # batch time
    t0 = time.perf_counter()
    for _ in range(N):
        batch(whisper, features_batch, prompts)
    elapsed = time.perf_counter() - t0
    print(f"batch time: {elapsed:0.3f} secs")

if __name__ == "__main__":
    main()

When I run the code on Colab (T4 GPU), it outputs:

benchmarking each method for 20 iterations
loop time: 25.311 secs
batch time: 30.086 secs

Is there any way to make Whisper batched generation faster?

Answer #1 (wkyowqbh)

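Sure! Glad you asked, haha. CTranslate2 does in fact support true batching, but at the C++ level. I'll point you to my repo, which uses the amazing WhisperS2T, and also link directly to that repo. As far as I know there has been quite a lot of discussion about "batch" processing, but given the state of the repository it isn't feasible yet. On the other hand, faster-whisper has other features that WhisperS2T lacks. Keep in mind that my repo is a bit outdated, since I haven't updated it to the latest WhisperS2T, so if the API has changed, check upstream.
However, if you use my repo as a sample script and keep the versions the same, you should be fine. I have a lot of experience with WhisperS2T now, so feel free to reach out.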
https://github.com/BBC-Esq/WhisperS2T-transcriber
...and the amazing...
https://github.com/shashikg/WhisperS2T
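At around 150 stars it's flying under the radar... but it beats Hugging Face's "insanely" (hate that name) fast Whisper implementation, which has thousands of stars. It just goes to show that the star count of a typical Hugging Face repo has little to do with the quality of the product and more to do with marketing and referrals from their network. ...Credit where credit is due. Give WhisperS2T a try; I look forward to your feedback!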

Answer #2 (9wbgstp7)

By the way, I haven't had time to update this bad boy in my WhisperS2T batch repo yet, so hang tight. ;-)

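It lets you specify the task, choose any CTranslate2 quantization, recursively process all subdirectories, exclude certain file extensions from processing, change the beam size and batch size (thanks to WhisperS2T), and more.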
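
Two of the features listed above, recursive traversal of subdirectories and excluding file extensions, can be sketched with pathlib. This is a hypothetical helper for illustration; `collect_audio_files` and its parameters are not the actual API of WhisperS2T-transcriber.

```python
from __future__ import annotations

from pathlib import Path
from typing import Iterable


def collect_audio_files(
    root: str | Path,
    recursive: bool = True,
    exclude_exts: Iterable[str] = (),
) -> list[Path]:
    """Return files under `root`, skipping any whose suffix is in `exclude_exts`.

    Hypothetical helper: extensions are compared case-insensitively and may be
    given with or without a leading dot.
    """
    excluded = {
        ext.lower() if ext.startswith(".") else "." + ext.lower()
        for ext in exclude_exts
    }
    # "**/*" walks every subdirectory (including root itself); "*" stays at the top level.
    pattern = "**/*" if recursive else "*"
    return sorted(
        p for p in Path(root).glob(pattern)
        if p.is_file() and p.suffix.lower() not in excluded
    )
```

For example, `collect_audio_files("recordings", exclude_exts=["txt", ".json"])` would return every file below `recordings` except text and JSON files, in a stable sorted order.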

Answer #3 (ssm49v7z)

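As promised in my last post... here is my analysis of WhisperS2T. I believe my repo uses a traditional "loop" to process files with WhisperS2T... but you can also send a whole batch directly to ctranslate2 in one call, which is essentially how WhisperS2T works. However, I chose the "loop" approach because if you send all the audio files at once and one of them fails, they all fail and you get zero transcriptions. I found that when processing, say, 500 audio files, one of them might contain corrupted data, and the whole run would raise an error...
According to this discussion, though, that issue should be fixed by now: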
shashikg/WhisperS2T#50
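
The trade-off described above (one corrupted file failing a whole batch versus slower per-file processing) can be softened by trying each batch as a unit and falling back to single-file calls only when that batch raises. A minimal sketch, assuming a hypothetical `transcribe_batch(paths) -> list[str]` callable that raises on bad input; it is not WhisperS2T's API.

```python
from typing import Callable, Sequence


def transcribe_with_fallback(
    files: Sequence[str],
    transcribe_batch: Callable[[Sequence[str]], list[str]],
    batch_size: int = 16,
) -> tuple[dict[str, str], dict[str, Exception]]:
    """Transcribe in batches; if a batch fails, retry its files one at a time.

    Returns (results, errors) keyed by file path, so a single bad file costs
    one extra pass over its own batch instead of losing every transcription.
    """
    results: dict[str, str] = {}
    errors: dict[str, Exception] = {}
    for start in range(0, len(files), batch_size):
        batch = list(files[start:start + batch_size])
        try:
            # Fast path: one batched call for the whole chunk.
            for path, text in zip(batch, transcribe_batch(batch)):
                results[path] = text
        except Exception:
            # Slow path: isolate the failing file(s) with single-item calls.
            for path in batch:
                try:
                    results[path] = transcribe_batch([path])[0]
                except Exception as exc:
                    errors[path] = exc
    return results, errors
```

The common case keeps the batching speedup; only a batch that actually contains a bad file pays the per-file cost.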
Anyway, below is my analysis of the library (not of the latest version, but still):
My personal summary

transcribe_with_vad (backends/__init__.py)

The transcribe_with_vad method in WhisperModel transcribes audio files using voice activity detection (VAD). It first normalizes the batch parameters (language codes, tasks, and initial prompts) with fix_batch_param. It then iterates over the batches produced by WhisperDataLoader, which segments the audio based on voice activity; each batch is converted to log-mel spectrograms by the LogMelSpectogram instance in self.preprocessor and passed to generate_segment_batched, an abstract method that subclasses implement. Progress is tracked with tqdm.

fix_batch_param (backends/__init__.py)

This helper is used by the WhisperModel class to prepare batch parameters such as language codes, tasks, and initial prompts, ensuring that each one has an entry for every audio file being processed.
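
A minimal sketch of that kind of normalization, assuming a parameter may be passed as None, as a single value, or as a list with one entry (or one entry per file). `broadcast_batch_param` is a hypothetical name; WhisperS2T's actual fix_batch_param signature and rules may differ.

```python
from collections.abc import Sequence
from typing import Any


def broadcast_batch_param(param: Any, default: Any, n: int) -> list[Any]:
    """Broadcast a per-file parameter to a list of length n (illustrative sketch).

    None -> [default] * n; scalar or str -> [param] * n;
    list of length 1 -> repeated n times; list of length n -> copied.
    """
    if param is None:
        return [default] * n
    # Strings are Sequences too, so treat them as scalars explicitly.
    if isinstance(param, str) or not isinstance(param, Sequence):
        return [param] * n
    if len(param) == 1:
        return list(param) * n
    if len(param) != n:
        raise ValueError(f"expected 1 or {n} values, got {len(param)}")
    return list(param)
```

With this, `broadcast_batch_param(None, "en", 3)` gives one "en" per file, so the downstream loader can always index parameters by file position.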

WhisperDataLoader (data.py)

The WhisperDataLoader class prepares audio data for transcription by segmenting and batching it. It relies on configuration constants (SAMPLE_RATE, N_SAMPLES), utility functions (pad_or_trim, audio_batch_generator), helpers from the same module such as BasicSegmenter and stitch_speech_segments, and torch and numpy. Depending on the input flags, it uses either voice activity detection (speech_segmenter) or basic segmentation (basic_segmenter). The class segments the audio files, optionally merges speech segments while respecting a maximum speech length, and creates batches containing the processed audio signals, prompts, and metadata. Its data_collate_fn method assembles these into model-ready input, padding or trimming the audio to a uniform length and organizing the prompts. It supports dynamic adjustment of the time axis for batch processing and yields batches ready for transcription, tracking progress with tqdm.
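
The optional merge step can be illustrated with a greedy sketch: sorted VAD segments are merged into one chunk as long as the chunk stays within a maximum length (Whisper's input window is 30 seconds). This is an assumption about how stitch_speech_segments behaves, not a copy of it; `merge_segments` and `max_len` are illustrative names.

```python
def merge_segments(
    segments: list[tuple[float, float]], max_len: float = 30.0
) -> list[tuple[float, float]]:
    """Greedily merge (start, end) speech segments into chunks of at most max_len seconds.

    A single segment longer than max_len is kept as its own chunk.
    """
    merged: list[tuple[float, float]] = []
    for start, end in sorted(segments):
        if merged and end - merged[-1][0] <= max_len:
            # Extending the current chunk still fits in the window.
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            # Start a new chunk at this segment.
            merged.append((start, end))
    return merged
```

Fewer, fuller chunks mean fewer 30-second windows to pad and decode, which is one reason merging helps batch throughput.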

LogMelSpectrogram (audio.py)

The LogMelSpectogram class, a subclass of nn.Module, converts audio signals into log-mel spectrogram features. It is initialized with the mel-spectrogram parameters (n_mels, n_fft, hop_length, padding), loads the mel filter banks from a predefined file, and registers them as a buffer. It also holds a TorchSTFT instance for computing the short-time Fourier transform (STFT). The get_seq_len method computes sequence lengths from hop_length, and the forward method applies padding (if required), computes the STFT, projects the power spectrogram onto the mel scale with the loaded filters, applies logarithmic scaling, and finally clips and rescales the log-mel spectrogram. It uses torch, numpy, F.pad from PyTorch's functional API, and custom configuration values (N_MELS, N_FFT, HOP_LENGTH, BASE_PATH).
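
The forward pass described above follows the reference Whisper feature pipeline. A NumPy sketch of the same steps, assuming Whisper's standard parameters (n_fft=400, hop_length=160 at 16 kHz) and a mel filter bank loaded elsewhere (WhisperS2T reads it from a file). This is not WhisperS2T's TorchSTFT code, and framing details may differ.

```python
import numpy as np

N_FFT = 400       # 25 ms window at 16 kHz
HOP_LENGTH = 160  # 10 ms hop at 16 kHz


def power_spectrogram(audio: np.ndarray, n_fft: int = N_FFT, hop_length: int = HOP_LENGTH) -> np.ndarray:
    """|STFT|^2 with centered reflect padding; shape (n_fft // 2 + 1, n_frames)."""
    window = np.hanning(n_fft + 1)[:-1]  # periodic Hann window
    padded = np.pad(audio, n_fft // 2, mode="reflect")
    n_frames = 1 + (len(padded) - n_fft) // hop_length
    frames = np.stack(
        [padded[i * hop_length : i * hop_length + n_fft] * window for i in range(n_frames)]
    )
    return (np.abs(np.fft.rfft(frames, axis=1)) ** 2).T


def log_mel_spectrogram(audio: np.ndarray, mel_filters: np.ndarray) -> np.ndarray:
    """Whisper-style log-mel features; mel_filters has shape (n_mels, n_fft // 2 + 1)."""
    power = power_spectrogram(audio)[:, :-1]               # drop the last frame, as Whisper does
    mel = mel_filters @ power                              # project onto the mel filter bank
    log_spec = np.log10(np.maximum(mel, 1e-10))            # log scaling with a floor
    log_spec = np.maximum(log_spec, log_spec.max() - 8.0)  # clip to an 80 dB dynamic range
    return (log_spec + 4.0) / 4.0                          # Whisper's fixed rescaling
```

One second of 16 kHz audio yields 100 frames, so an 80-filter bank gives an (80, 100) feature matrix.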

WhisperModelCT2 (model.py)

WhisperModelCT2 is the subclass that provides a concrete implementation of the abstract generate_segment_batched method. Here is a brief overview of how it handles transcription once generate_segment_batched is reached:
1. Initialization and Configuration: The WhisperModelCT2 constructor initializes the model with the ASR (automatic speech recognition) options, loads the model and tokenizer, and sets up the parameters used to generate transcriptions.
2. Model and Tokenizer Loading: It loads the Whisper model with ctranslate2 from a local path or a model name, and loads the tokenizer from a specified file.
3. Transcription Process (generate_segment_batched):
   o Converts the features (audio already processed into log-mel spectrograms) into the format the ctranslate2 model expects.
   o Calls the ctranslate2 model's generate method with these features and the configured generation options. This is the step that actually performs ASR, generating text from the input audio features.
   o Decodes the model output into human-readable text with the loaded tokenizer.
   o Optionally computes additional metrics, such as the average log probability of each sequence and the no-speech probability, if the generation options request them.
   o If word timestamps are requested, aligns the generated text with the audio features to produce word-level timestamps by calling align_words (which uses the ctranslate2 model's align function) and then assign_word_timings.
4. Word Timings and Alignment: When word timestamps are enabled, WhisperModelCT2 uses aligner_model (a separate ctranslate2 model instance) to align the words in the transcript with their positions in the audio, producing timing information for each word.
5. Returning the Transcription: The method returns a list of dictionaries. Each dictionary contains the transcribed text and, depending on the configuration, the average log probability, the no-speech probability, and word-level timings.
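
Steps 3 and 5 (decode each result, optionally attach scores, return one dict per segment) can be sketched as below. `GenerationResult` is a stand-in dataclass assumed to mirror the `sequences_ids`, `scores`, and `no_speech_prob` fields of CTranslate2's Whisper generation result, and `decode` is a hypothetical tokenizer callable; this is not WhisperS2T's code.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class GenerationResult:
    """Stand-in for the per-segment result fields used here (assumption)."""
    sequences_ids: list[list[int]]
    scores: list[float]
    no_speech_prob: float


def package_results(
    results: list[GenerationResult],
    decode: Callable[[list[int]], str],
    include_scores: bool = True,
    include_no_speech_prob: bool = True,
) -> list[dict]:
    """Turn one generation result per segment into output dicts (best hypothesis only)."""
    out = []
    for res in results:
        item = {"text": decode(res.sequences_ids[0]).strip()}
        if include_scores:
            # Assumption: the score is treated as an average log-probability;
            # check how your CTranslate2 version normalizes scores (length_penalty).
            item["avg_logprob"] = res.scores[0]
        if include_no_speech_prob:
            item["no_speech_prob"] = res.no_speech_prob
        out.append(item)
    return out
```

Keeping the optional fields behind flags mirrors the "depending on the configuration" behaviour described in step 5.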

model.transcribe_with_vad

model = whisper_s2t.load_model(model_identifier=model_identifier, backend='CTranslate2', device=self.device, compute_type=self.quantization, asr_options={'beam_size': self.beam_size}, cpu_threads=os.cpu_count())

MODEL TRANSCRIPTION PROCESS
----------------------------

model.transcribe_with_vad
│
├─ fix_batch_param (Adjustment of Parameters)
│  └─ Applies to: lang_codes, tasks, initial_prompts
│
└─ WhisperDataLoader (Data Preparation and Loading)
   │
   ├─ Conditional Branching: use_vad flag
   │  ├─ speech_segmenter (if VAD enabled)
   │  └─ basic_segmenter (if VAD disabled)
   │     └─ Optional: stitch_speech_segments (Merge if merge_chunks=True)
   │
   ├─ get_segmented_audio_signal (Audio Segmentation)
   │  ├─ tokenizer.sot_sequence (Start-of-Transcript Sequence)
   │  └─ tokenizer.encode (Encoding Initial Prompts)
   │
   ├─ data_collate_fn (Data Collation)
   │  ├─ pad_or_trim (Adjust Audio Signal Lengths)
   │  └─ External: Torch Operations (Stacking and Tensor Creation)
   │
   └─ Yields Batches to transcribe_with_vad (Transcription)
      │
      ├─ preprocessor (Feature Extraction)
      │  ├─ TorchSTFT (Spectral Transformation)
      │  └─ Mel Filter Application & Log Scaling
      │
      └─ generate_segment_batched (in WhisperModelCT2)
         │
         ├─ ctranslate2 Model's `generate` (ASR Generation)
         │  ├─ Decoding Output to Text
         │  ├─ Average Log Probability (Optional)
         │  └─ No-Speech Probability (Optional)
         │
         ├─ align_words (Word Alignment)
         │  ├─ aligner_model's `align` (Alignment Process)
         │  └─ assign_word_timings (Timestamp Assignment)
         │
         └─ Structured Transcription Output
            └─ Includes: Text, avg_logprob, no_speech_prob, word_timestamps
               └─ Creation of "out" variable (Result Packaging)

The difference between the original version of my transcriber and the new version that processes each file separately:

First Snippet: Processing Files Individually in a Loop
•	Sequential Processing: Each audio file is processed one at a time in a while loop, which continues until the file_queue is empty or enumeration_done is set. This approach allows for real-time updates and handling of files as they become available, which can be particularly useful in scenarios where files are being added to the queue dynamically.
Second Snippet: Batch Processing Multiple Files
•	Batch Processing: Processes a list of audio files (audio_files_str) in a single call to transcribe_with_vad. This approach is more efficient if all audio files are available at the start, as it can leverage batch processing optimizations.
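
A minimal sketch of the queue-driven loop from the first bullet, assuming a producer thread fills file_queue and sets enumeration_done once it has finished listing files; `transcribe_one` is a hypothetical per-file callable. The loop exits only when enumeration is finished and the queue is empty, so no queued file is skipped.

```python
import queue
import threading
from typing import Callable


def process_queue(
    file_queue: "queue.Queue[str]",
    enumeration_done: threading.Event,
    transcribe_one: Callable[[str], str],
) -> dict[str, str]:
    """Consume files one at a time as the producer makes them available."""
    results: dict[str, str] = {}
    while not (enumeration_done.is_set() and file_queue.empty()):
        try:
            # Short timeout so the loop re-checks enumeration_done regularly.
            path = file_queue.get(timeout=0.1)
        except queue.Empty:
            continue
        results[path] = transcribe_one(path)
        file_queue.task_done()
    return results
```

The batch alternative would instead wait for the full file list and make a single call over it, trading these incremental updates for batched throughput.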

Answer #4 (esbemjvw)

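Thanks for telling me about WhisperS2T; I'll check it out later. Right now I'm not using faster-whisper but CTranslate2 directly. I was hoping batching could speed up generation, but at the moment it gives no speedup over a standard loop.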

Answer #5 (kqhtkvqz)

WhisperS2T basically uses CTranslate2 directly.
