WERによる文字起こし精度チェック

はじめに

WER(Word Error Rate) は、音声認識システムや文字起こしサービスの精度を評価するための代表的な指標です。文字起こし結果と正解テキスト(ゴールドスタンダード)を比較し、誤りの割合を定量化します。

WER の定義と計算式

WER は次の3種類の誤りをすべて合計し、正解語数で割って算出します。

  • S(Substitutions):誤って別の単語に置き換えられた単語数
  • D(Deletions):正解ではあるが出力に欠落した単語数
  • I(Insertions):出力に余分に挿入された単語数
  • N:正解テキスト中の総単語数

たとえば、正解テキストが「今日はいい天気ですね」、文字起こし結果が「今日はいい天気です」の場合:

  • 置換(Substitution):“ですね”→“です” ⇒ S=1
  • 挿入(Insertion):なし ⇒ I=0
  • 欠落(Deletion):なし ⇒ D=0
  • 正解語数:4(「今日は」「いい」「天気」「ですね」) ⇒ N=4

なぜ WER を使うのか?

  • 客観的な比較が可能
    複数のモデルやサービスを同一のコーパスで比較し、どれが誤りが少ないか一目でわかる。
  • 定量的な改善指標
    モデル更新やパラメータ調整後に WER がどれだけ下がったかで効果を評価できる。
  • 多言語・多領域で応用可能
    日本語の文字起こしでも英語の音声認識でも同じ考え方で使える。

WER 評価時の注意点

  • 前処理の統一
    大文字小文字の統一、句読点や記号の扱い、助詞の分割など、正解と予測結果のテキスト正規化ルールを揃えること。
  • コーパスの多様性
    騒音環境、話者数、専門用語の有無など、評価対象のシナリオを広くカバーしたデータで測定すること。
  • リアルタイム係数(RTF)との併用
    精度だけでなく、「処理時間 ÷ 音声長」である RTF も同時に計測し、実運用時の性能を総合評価する。

ここで、RTF(Real Time Factor) は、音声処理システムの処理速度を評価する指標です。

  • RTF < 1.0:リアルタイム以上の速度
    (1秒の音声を1秒未満で処理できる)
  • RTF = 1.0:リアルタイムと同等の速度
  • RTF > 1.0:リアルタイム未満の速度
    (1秒の音声を1秒以上かけて処理する)

ポイント

  • 用途別の目安
    • ストリーミング処理:RTF ≤ 1 を目指す
    • バッチ処理:RTF が多少大きくても問題ない場合がある
  • 測定方法
    1. 処理開始前にタイムスタンプを取得
    2. 処理終了後にタイムスタンプを取得
    3. 「終了–開始」を音声長で割る

RTF を用いることで、同じモデルや環境であっても「処理の速さ」を定量的に比較・最適化できます。

WER を計算するツール例

  • jiwer(Python)
from jiwer import wer

ground_truth = "今日はいい天気ですね"
hypothesis  = "今日はいい天気です"
print(f"WER: {wer(ground_truth, hypothesis):.2%}")
  • SCTK(NIST 標準のスコアリングツール)

まとめ

  1. WER は音声認識の精度を示す基本指標
  2. S, D, I の総和を正解語数で割って算出
  3. 評価ルールやコーパスの設計を適切に行うことが重要

Assembly AIのWER評価用コード

LibriSpeech test-clean データセットのダウンロード

wget http://www.openslr.org/resources/12/test-clean.tar.gz
tar zxvf test-clean.tar.gz

実行

import os
from jiwer import wer
import requests
import re, time
import soundfile as sf

def call_assemblyai_api(audio_file):
    base_url = "https://api.assemblyai.com"
    headers = {"authorization": "APIキー"}

    with open(audio_file, "rb") as f:
        response = requests.post(base_url + "/v2/upload", headers=headers, data=f)

        if response.status_code != 200:
            print(f"Error: {response.status_code}, Response: {response.text}")
            response.raise_for_status()

        upload_json = response.json()
        upload_url = upload_json["upload_url"]

    data = {
        "audio_url": upload_url,
        "speech_model": "slam-1"
    }

    response = requests.post(base_url + "/v2/transcript", headers=headers, json=data)

    if response.status_code != 200:
        print(f"Error: {response.status_code}, Response: {response.text}")
        response.raise_for_status()

    transcript_json = response.json()
    transcript_id = transcript_json["id"]
    polling_endpoint = f"{base_url}/v2/transcript/{transcript_id}"

    while True:
        transcript = requests.get(polling_endpoint, headers=headers).json()
        if transcript["status"] == "completed":
            # print(f" \nFull Transcript: \n\n{transcript['text']}")
            return transcript["text"]
        elif transcript["status"] == "error":
            raise RuntimeError(f"Transcription failed: {transcript['error']}")
        else:
            time.sleep(3)


def remove_punctuation(text):
    return re.sub(r'[^\w\s]', '', text)



transcripts = {}
durations = {}
for root, _, files in os.walk("LibriSpeech/test-clean"):
    for fn in files:
        if fn.endswith(".txt"):
            with open(os.path.join(root, fn)) as f:
                for line in f:
                    utt, text = line.strip().split(" ", 1)
                    transcripts[f"{utt}.flac"] = text.lower()

                    path = os.path.join(root, f"{utt}.flac")
                    with sf.SoundFile(path) as f2:
                        # frame数 / サンプリングレート(Hz)
                        durations[f"{utt}.flac"] = len(f2) / f2.samplerate

print("start")

# テスト実行例
audio_file = "LibriSpeech/test-clean/61/70968/61-70968-0000.flac"
start = time.time()
hypothesis = call_assemblyai_api(audio_file)
end = time.time()
duration_process = end - start

gt = transcripts[os.path.basename(audio_file)]

duration_audio = durations[os.path.basename(audio_file)] 

gt_clean = remove_punctuation(gt.lower().strip())
hypothesis_clean = remove_punctuation(hypothesis.lower().strip())

print(gt_clean)
print(hypothesis_clean)
print(f"WER: {wer(gt_clean, hypothesis_clean):.2%}")

print(duration_process)
print(duration_audio)
print(f"RTF: {duration_process / duration_audio:.2f}")

関連記事

カテゴリー

アーカイブ

Lang »