diff options
| author | Pinapelz <yukais@pinapelz.com> | 2026-04-26 20:55:53 -0700 |
|---|---|---|
| committer | Pinapelz <yukais@pinapelz.com> | 2026-04-26 21:02:29 -0700 |
| commit | 58c5449a35c51d9edea0fedacec964a9e5196e8f (patch) | |
| tree | d549f9004ccddd9e1829380144fee52ecbba4d73 /server.py | |
| parent | 1ff76ea759b966bc7683fcaff17314788687c950 (diff) | |
feat: use ollama to cleanup context window
Diffstat (limited to 'server.py')
| -rw-r--r-- | server.py | 190 |
1 files changed, 185 insertions, 5 deletions
@@ -3,14 +3,20 @@ import threading import json import queue import os -from typing import Any, Dict, Optional, Set, List, Iterator +from collections import Counter, deque +import re +from typing import Any, Deque, Dict, Optional, Set, List, Iterator from flask import Flask from flask_cors import CORS +import ollama as _ollama +from ollama import chat +from ollama import ChatResponse import numpy as np import sounddevice as sd from faster_whisper import WhisperModel from gui import select_settings, prompt_input_sample_rate from routes import register_routes +from config import _SYSTEM_PROMPT TARGET_SAMPLE_RATE: int = 16000 CAPTURE_SAMPLE_RATE: int = 0 @@ -20,6 +26,12 @@ PROCESS_INTERVAL_SECONDS: float = 2 SSE_EVENT_SUBTITLE: str = "subtitle" SSE_KEEPALIVE_SECONDS: int = 15 +USE_OLLAMA_CLEANUP: bool = True +OLLAMA_MODEL: str = "qwen2.5:7b-instruct" +OLLAMA_CONTEXT_WINDOW: int = 6 # number of recent cleaned segments kept as context +OLLAMA_OPTIONS: Dict[str, Any] = {"num_gpu": 1} +RAW_BATCH_SIZE: int = 2 # accumulate this many raw Whisper lines before calling the LLM + SETTINGS_PATH: str = os.path.join(os.path.dirname(__file__), "settings.json") DEFAULT_SETTINGS: Dict[str, Any] = { @@ -32,6 +44,10 @@ DEFAULT_SETTINGS: Dict[str, Any] = { "language": "", "context_seconds": 10, "update_interval_seconds": 2, + "use_ollama_cleanup": True, + "ollama_device": "GPU", + "ollama_context_window": 5, + "ollama_raw_batch_size": 2, } MODEL_CHOICES: List[str] = ["tiny", "base", "small", "medium", "large-v2", "large-v3", "distil-large-v3"] @@ -54,6 +70,12 @@ SERVER_PORT: int = 5000 app: Flask = Flask(__name__) CORS(app) +# OLLAMA stuff +llm_input_queue: queue.Queue = queue.Queue(maxsize=1) +subtitle_context: Deque[str] = deque(maxlen=OLLAMA_CONTEXT_WINDOW) # sliding window context +subtitle_context_lock: threading.Lock = threading.Lock() +_raw_batch: List[str] = [] +_raw_batch_lock: threading.Lock = threading.Lock() def resample_audio(audio_np: np.ndarray, src_rate: int, dst_rate: int) -> np.ndarray: if src_rate == dst_rate: @@ -90,18 +112,163 @@ def save_settings(settings: Dict[str, Any]) -> None: except OSError as exc: print(f"Failed to save settings: {exc}") +def cleanup_subtitle_with_ollama(raw_text: str, context: List[str]) -> Optional[str]: + if context: + context_block = "\n".join(f"- {seg}" for seg in context) + else: + context_block = "(none yet)" + + user_message = ( + f"ALREADY SHOWN:\n{context_block}\n\n" + "RAW INPUT (multiple consecutive transcriptions of the same rolling window — " + f"deduplicate and extract only the genuinely new spoken content as one subtitle):\n{raw_text}" + ) + + try: + response: ChatResponse = chat( + model=OLLAMA_MODEL, + messages=[ + {"role": "system", "content": _SYSTEM_PROMPT}, + {"role": "user", "content": user_message}, + ], + options=OLLAMA_OPTIONS, + ) + return response.message.content.strip() + except Exception as exc: + print(f"⚠️ OLLAMA cleanup error: {exc}") + return None + + +def ensure_ollama_ready() -> None: + try: + local = _ollama.list() + except Exception as exc: + raise RuntimeError( + f"Cannot reach Ollama — is the server running? ({exc})" + ) from exc + model_names: List[str] = [m.model for m in local.models] + if not any(name.startswith(OLLAMA_MODEL) for name in model_names): + print(f" '{OLLAMA_MODEL}' not found locally — pulling (this may take a while) ...") + try: + _ollama.pull(OLLAMA_MODEL) + print(" Pull complete.") + except Exception as exc: + raise RuntimeError(f"Failed to pull model '{OLLAMA_MODEL}': {exc}") from exc + else: + print(f" Model found locally.") + print(" Warming up model, almost done ...") + try: + chat( + model=OLLAMA_MODEL, + messages=[{"role": "user", "content": "Ready?"}], + options=OLLAMA_OPTIONS, + ) + print(" ✅ Ollama is ready.") + except Exception as exc: + raise RuntimeError(f"Ollama warm-up failed: {exc}") from exc + +_LLM_EMPTY_SENTINELS: frozenset = frozenset({ + "empty string", "empty", "(empty)", "[empty]", + "(empty string)", "[empty string]", "(none)", "none", "n/a", +}) + + +def normalize_llm_output(text: str) -> str: + if text.strip().lower().rstrip(".") in _LLM_EMPTY_SENTINELS: + return "" + return text + + +def is_hallucination(text: str) -> bool: + words = text.split() + if not words: + return False + max_expected = int(BUFFER_SECONDS * 4.5) + if len(words) > max_expected: + print(f"🔴 Hallucination (too long: {len(words)} words > {max_expected}): {text[:60]!r}") + return True + clean = [re.sub(r"[^\w']+", "", w).lower() for w in words] + clean = [w for w in clean if w] + for n in (2, 3): + if len(clean) < n * 3: + continue + ngrams = [" ".join(clean[i : i + n]) for i in range(len(clean) - n + 1)] + top, count = Counter(ngrams).most_common(1)[0] + if count >= 3: + print(f"🔴 Hallucination (\'{top}\' x{count}): {text[:60]!r}") + return True + top, count = Counter(clean).most_common(1)[0] + if count >= 4 and count / len(clean) > 0.40: + print(f"🔴 Hallucination (\'{top}\' x{count}, {count/len(clean):.0%}): {text[:60]!r}") + return True + return False + + +def llm_processing_loop() -> None: + print(f"LLM cleanup thread started (model={OLLAMA_MODEL})") + while True: + try: + raw_text: str = llm_input_queue.get(timeout=1) + except queue.Empty: + continue + + with subtitle_context_lock: + context = list(subtitle_context) + + cleaned: Optional[str] = cleanup_subtitle_with_ollama(raw_text, context) + + if cleaned is None: + cleaned = raw_text + else: + cleaned = normalize_llm_output(cleaned) + + if cleaned: + with subtitle_context_lock: + subtitle_context.append(cleaned) + print(f"🔵 (cleaned) {cleaned}") + broadcast_subtitle(cleaned) + else: + print("🟡 (LLM: no new content)") + def run_whisper(audio_np: np.ndarray) -> str: transcribe_kwargs: Dict[str, Any] = {"task": WHISPER_TASK, "beam_size": WHISPER_BEAM_SIZE} if WHISPER_LANGUAGE: transcribe_kwargs["language"] = WHISPER_LANGUAGE - # model is expected to be initialized in main() assert model is not None, "Whisper model is not initialized" segments, _info = model.transcribe(audio_np, **transcribe_kwargs) text = " ".join(seg.text for seg in segments).strip() - if text: - print("🟢", text) + if not text: + return text + + print(f"🟢 (raw) {text}") + + if is_hallucination(text): + return text + + if USE_OLLAMA_CLEANUP: + with _raw_batch_lock: + _raw_batch.append(text) + if len(_raw_batch) >= RAW_BATCH_SIZE: + batch_text = "\n".join(_raw_batch) + _raw_batch.clear() + else: + batch_text = None + if batch_text is not None: + try: + llm_input_queue.put_nowait(batch_text) + except queue.Full: + try: + llm_input_queue.get_nowait() + except queue.Empty: + pass + try: + llm_input_queue.put_nowait(batch_text) + except queue.Full: + pass + else: broadcast_subtitle(text) + return text @@ -225,7 +392,8 @@ def select_input_sample_rate(device_index: int, preferred_rate: int) -> int: def main() -> None: global CAPTURE_SAMPLE_RATE, MAX_SAMPLES, model, WHISPER_TASK, WHISPER_BEAM_SIZE, WHISPER_LANGUAGE - global BUFFER_SECONDS, PROCESS_INTERVAL_SECONDS + global BUFFER_SECONDS, PROCESS_INTERVAL_SECONDS, USE_OLLAMA_CLEANUP + global OLLAMA_CONTEXT_WINDOW, RAW_BATCH_SIZE, subtitle_context start_subtitle_server() settings: Dict[str, Any] = load_settings() @@ -242,6 +410,16 @@ def main() -> None: ) save_settings(settings) + USE_OLLAMA_CLEANUP = bool(settings.get("use_ollama_cleanup", True)) + OLLAMA_OPTIONS["num_gpu"] = 0 if settings.get("ollama_device", "CPU").upper() == "CPU" else 1 + OLLAMA_CONTEXT_WINDOW = int(settings.get("ollama_context_window", 6)) + subtitle_context = deque(maxlen=OLLAMA_CONTEXT_WINDOW) + RAW_BATCH_SIZE = int(settings.get("ollama_raw_batch_size", 3)) + if USE_OLLAMA_CLEANUP: + ensure_ollama_ready() + llm_thread = threading.Thread(target=llm_processing_loop, daemon=True) + llm_thread.start() + device_name: str = settings.get("audio_device_name", "") matched_index: Optional[int] = None for idx, dev in enumerate(devices): @@ -277,6 +455,8 @@ def main() -> None: print(f"Model: {model_name} | task={WHISPER_TASK} | beam_size={WHISPER_BEAM_SIZE}") print(f"Compute: device={whisper_device} | compute_type={compute_type}") print(f"Capture sample rate: {CAPTURE_SAMPLE_RATE} Hz (resampling to {TARGET_SAMPLE_RATE} Hz)") + print(f"Ollama cleanup: {'enabled' if USE_OLLAMA_CLEANUP else 'disabled'} (model={OLLAMA_MODEL})") + processing_thread = threading.Thread(target=processing_loop, daemon=True) processing_thread.start() with sd.InputStream( |
