diff options
Diffstat (limited to 'gui.py')
| -rw-r--r-- | gui.py | 394 |
1 files changed, 235 insertions, 159 deletions
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, cast

import sounddevice as sd
from PySide6.QtCore import Qt
from PySide6.QtGui import QFont
from PySide6.QtWidgets import (
    QApplication,
    QCheckBox,
    QComboBox,
    QDialog,
    QDialogButtonBox,
    QFormLayout,
    QGroupBox,
    QHBoxLayout,
    QInputDialog,
    QLabel,
    QLineEdit,
    QMessageBox,
    QTabWidget,
    QVBoxLayout,
    QWidget,
)

# Sample rates suggested in the prompt when the caller supplies none.
_DEFAULT_SAMPLE_RATES = (48000, 44100, 32000, 24000, 22050, 16000, 12000, 8000)


def _ensure_app() -> QApplication:
    """Return the running QApplication, creating one if none exists yet."""
    app = QApplication.instance()
    if app is None:
        app = QApplication([])
    return cast(QApplication, app)


def _select_combo_text(combo: QComboBox, text: str) -> None:
    """Preselect *text* in *combo*.

    If *text* is one of the items it becomes the current item.  Otherwise an
    editable combo shows *text* as free-form input, and a read-only combo
    falls back to its first item (if it has any).
    """
    index = combo.findText(text)
    if index >= 0:
        combo.setCurrentIndex(index)
    elif combo.isEditable():
        combo.setEditText(text)
    elif combo.count() > 0:
        combo.setCurrentIndex(0)


def _positive_int(text: str) -> int:
    """Parse *text* as an integer > 0; raise ValueError otherwise."""
    value = int(text.strip())
    if value <= 0:
        raise ValueError("value must be positive")
    return value


def _positive_float(text: str) -> float:
    """Parse *text* as a finite float > 0; raise ValueError otherwise."""
    value = float(text.strip())
    # ``nan <= 0`` is False, so non-finite input must be rejected explicitly.
    if not math.isfinite(value) or value <= 0:
        raise ValueError("value must be a finite positive number")
    return value


class _SettingsDialog(QDialog):
    """Modal settings dialog with a "Whisper" tab and an "Ollama" tab.

    When the user confirms with OK and every field validates, the chosen
    values are stored in ``selected_settings``.  The dict stays empty if the
    dialog is cancelled.
    """

    def __init__(
        self,
        settings: Dict[str, Any],
        input_devices: List[Tuple[int, Dict[str, Any]]],
        default_settings: Dict[str, Any],
        model_choices: Iterable[str],
        device_choices: Iterable[str],
        compute_choices: Iterable[str],
        task_choices: Iterable[str],
    ) -> None:
        """Build the dialog widgets and prefill them.

        Each field's initial value is looked up in *settings* first, then in
        *default_settings*, then a hard-coded fallback.
        """
        super().__init__()
        self.setWindowTitle("Settings")
        self.setModal(True)
        self.setMinimumWidth(700)

        self.selected_settings: Dict[str, Any] = {}
        # Parallel to the device combo's items: combo index -> device name.
        self.device_names = [dev["name"] for _idx, dev in input_devices]

        def get_value(key: str, fallback: Any) -> Any:
            return settings.get(key, default_settings.get(key, fallback))

        root_layout = QVBoxLayout(self)

        tabs = QTabWidget(self)
        root_layout.addWidget(tabs)
        tabs.addTab(
            self._build_whisper_tab(
                get_value,
                input_devices,
                model_choices,
                device_choices,
                compute_choices,
                task_choices,
            ),
            "Whisper",
        )
        tabs.addTab(self._build_ollama_tab(get_value), "Ollama")

        button_layout = QHBoxLayout()
        root_layout.addLayout(button_layout)
        button_box = QDialogButtonBox(
            QDialogButtonBox.StandardButton.Ok | QDialogButtonBox.StandardButton.Cancel,
            self,
        )
        button_box.accepted.connect(self.accept)
        button_box.rejected.connect(self.reject)
        button_layout.addWidget(button_box)

    def _build_whisper_tab(
        self,
        get_value: Callable[[str, Any], Any],
        input_devices: List[Tuple[int, Dict[str, Any]]],
        model_choices: Iterable[str],
        device_choices: Iterable[str],
        compute_choices: Iterable[str],
        task_choices: Iterable[str],
    ) -> QWidget:
        """Create the "Whisper" tab page and its input widgets."""
        whisper_tab = QWidget(self)
        whisper_tab_layout = QVBoxLayout(whisper_tab)

        whisper_layout = QFormLayout()
        whisper_layout.setLabelAlignment(Qt.AlignmentFlag.AlignLeft)

        device_options = [
            f"[{idx}] {dev['name']} ({dev.get('max_input_channels', 0)} ch)"
            for idx, dev in input_devices
        ]
        self.device_combo = QComboBox(whisper_tab)
        self.device_combo.addItems(device_options)
        self.device_combo.setEditable(False)
        # Item labels are decorated, so match on the raw device name instead.
        default_device_name = get_value("audio_device_name", "")
        if default_device_name in self.device_names:
            self.device_combo.setCurrentIndex(self.device_names.index(default_device_name))
        else:
            self.device_combo.setCurrentIndex(0)
        whisper_layout.addRow(QLabel("Audio input device:"), self.device_combo)

        self.model_combo = QComboBox(whisper_tab)
        self.model_combo.addItems(list(model_choices))
        self.model_combo.setEditable(True)
        _select_combo_text(self.model_combo, str(get_value("model_name", "medium")))
        whisper_layout.addRow(QLabel("Model:"), self.model_combo)

        self.device_type_combo = QComboBox(whisper_tab)
        self.device_type_combo.addItems(list(device_choices))
        self.device_type_combo.setEditable(False)
        _select_combo_text(self.device_type_combo, str(get_value("device", "cpu")))
        whisper_layout.addRow(QLabel("Compute device:"), self.device_type_combo)

        self.task_combo = QComboBox(whisper_tab)
        self.task_combo.addItems(list(task_choices))
        self.task_combo.setEditable(False)
        _select_combo_text(self.task_combo, str(get_value("task", "translate")))
        whisper_layout.addRow(QLabel("Task:"), self.task_combo)

        whisper_tab_layout.addLayout(whisper_layout)

        whisper_advanced_group = QGroupBox("Advanced settings", whisper_tab)
        whisper_advanced_layout = QFormLayout(whisper_advanced_group)
        whisper_advanced_layout.setLabelAlignment(Qt.AlignmentFlag.AlignLeft)

        self.compute_type_combo = QComboBox(whisper_tab)
        self.compute_type_combo.addItems(list(compute_choices))
        self.compute_type_combo.setEditable(True)
        _select_combo_text(self.compute_type_combo, str(get_value("compute_type", "int8")))
        whisper_advanced_layout.addRow(QLabel("Compute type:"), self.compute_type_combo)

        self.beam_size_edit = QLineEdit(str(get_value("beam_size", 3)), whisper_tab)
        whisper_advanced_layout.addRow(QLabel("Beam size:"), self.beam_size_edit)

        self.language_edit = QLineEdit(str(get_value("language", "")), whisper_tab)
        whisper_advanced_layout.addRow(QLabel("Language (optional):"), self.language_edit)

        self.context_seconds_edit = QLineEdit(str(get_value("context_seconds", 10)), whisper_tab)
        whisper_advanced_layout.addRow(QLabel("Context seconds:"), self.context_seconds_edit)

        self.update_interval_edit = QLineEdit(str(get_value("update_interval_seconds", 2)), whisper_tab)
        whisper_advanced_layout.addRow(QLabel("Update interval (s):"), self.update_interval_edit)

        whisper_tab_layout.addWidget(whisper_advanced_group)
        return whisper_tab

    def _build_ollama_tab(self, get_value: Callable[[str, Any], Any]) -> QWidget:
        """Create the "Ollama" tab page and its input widgets."""
        ollama_tab = QWidget(self)
        ollama_tab_layout = QVBoxLayout(ollama_tab)

        ollama_layout = QFormLayout()
        ollama_layout.setLabelAlignment(Qt.AlignmentFlag.AlignLeft)

        self.use_ollama_cleanup_checkbox = QCheckBox(ollama_tab)
        self.use_ollama_cleanup_checkbox.setChecked(bool(get_value("use_ollama_cleanup", True)))
        ollama_layout.addRow(QLabel("LLM subtitle cleanup:"), self.use_ollama_cleanup_checkbox)

        self.ollama_device_combo = QComboBox(ollama_tab)
        self.ollama_device_combo.addItems(["CPU", "GPU"])
        self.ollama_device_combo.setEditable(False)
        _select_combo_text(self.ollama_device_combo, str(get_value("ollama_device", "CPU")))
        ollama_layout.addRow(QLabel("Ollama compute:"), self.ollama_device_combo)

        ollama_tab_layout.addLayout(ollama_layout)

        ollama_advanced_group = QGroupBox("Advanced settings", ollama_tab)
        ollama_advanced_layout = QFormLayout(ollama_advanced_group)
        ollama_advanced_layout.setLabelAlignment(Qt.AlignmentFlag.AlignLeft)

        self.ollama_context_edit = QLineEdit(str(get_value("ollama_context_window", 6)), ollama_tab)
        ollama_advanced_layout.addRow(QLabel("Context window (segments):"), self.ollama_context_edit)

        self.ollama_batch_edit = QLineEdit(str(get_value("ollama_raw_batch_size", 3)), ollama_tab)
        ollama_advanced_layout.addRow(QLabel("Batch size (lines per LLM call):"), self.ollama_batch_edit)

        ollama_tab_layout.addWidget(ollama_advanced_group)
        return ollama_tab

    def _warn(self, title: str, text: str) -> None:
        """Show a warning message box parented to this dialog."""
        QMessageBox.warning(self, title, text)

    def accept(self) -> None:
        """Validate the form; store the result and close only if it is valid.

        On the first invalid field a warning is shown and the dialog stays
        open with ``selected_settings`` unchanged.
        """
        selection = self.device_combo.currentIndex()
        if selection < 0:
            self._warn("Select a device", "Please select an audio input device.")
            return

        model_name = self.model_combo.currentText().strip()
        if not model_name:
            self._warn("Model required", "Please select or enter a model name.")
            return

        # (settings key, input widget, parser, warning title, warning text),
        # checked in this order so the first bad field is the one reported.
        numeric_fields = (
            ("beam_size", self.beam_size_edit, _positive_int,
             "Invalid beam size", "Beam size must be a positive integer."),
            ("context_seconds", self.context_seconds_edit, _positive_float,
             "Invalid context seconds", "Context seconds must be a positive number."),
            ("update_interval_seconds", self.update_interval_edit, _positive_float,
             "Invalid update interval", "Update interval must be a positive number."),
            ("ollama_context_window", self.ollama_context_edit, _positive_int,
             "Invalid context window", "Context window must be a positive integer."),
            ("ollama_raw_batch_size", self.ollama_batch_edit, _positive_int,
             "Invalid batch size", "Batch size must be a positive integer."),
        )
        numbers: Dict[str, Any] = {}
        for key, edit, parse, title, message in numeric_fields:
            try:
                numbers[key] = parse(edit.text())
            except ValueError:
                self._warn(title, message)
                return

        self.selected_settings = {
            "audio_device_name": self.device_names[selection],
            "model_name": model_name,
            "device": self.device_type_combo.currentText().strip() or "cpu",
            "compute_type": self.compute_type_combo.currentText().strip() or "int8",
            "task": self.task_combo.currentText().strip() or "translate",
            "beam_size": numbers["beam_size"],
            "language": self.language_edit.text().strip(),
            "context_seconds": numbers["context_seconds"],
            "update_interval_seconds": numbers["update_interval_seconds"],
            "use_ollama_cleanup": self.use_ollama_cleanup_checkbox.isChecked(),
            "ollama_device": self.ollama_device_combo.currentText(),
            "ollama_context_window": numbers["ollama_context_window"],
            "ollama_raw_batch_size": numbers["ollama_raw_batch_size"],
        }
        super().accept()


def select_settings(
    settings: Dict[str, Any],
    input_devices: List[Tuple[int, Dict[str, Any]]],
    default_settings: Dict[str, Any],
    model_choices: Iterable[str],
    device_choices: Iterable[str],
    compute_choices: Iterable[str],
    task_choices: Iterable[str],
) -> Dict[str, Any]:
    """Show the settings dialog and return the validated settings.

    Raises:
        RuntimeError: if *input_devices* is empty.
        SystemExit: if the dialog is cancelled or yields no settings.
    """
    if not input_devices:
        raise RuntimeError("No audio input devices found.")

    app = _ensure_app()
    app.setFont(QFont("Calibri", 12))

    dialog = _SettingsDialog(
        settings=settings,
        input_devices=input_devices,
        default_settings=default_settings,
        model_choices=model_choices,
        device_choices=device_choices,
        compute_choices=compute_choices,
        task_choices=task_choices,
    )
    result = dialog.exec()

    if result != int(QDialog.DialogCode.Accepted) or not dialog.selected_settings:
        raise SystemExit("No settings selected.")

    return dialog.selected_settings


def prompt_input_sample_rate(
    device_index: int,
    common_rates: Optional[Iterable[int]] = None,
) -> int:
    """Ask for an input sample rate until the device accepts one.

    Args:
        device_index: device passed to ``sd.check_input_settings``.
        common_rates: rates listed as hints in the prompt; a built-in list
            is used when omitted.

    Returns:
        The first entered rate (in Hz) that passes ``sd.check_input_settings``.

    Raises:
        sd.PortAudioError: if the user cancels the prompt.
    """
    rates = list(common_rates) if common_rates is not None else list(_DEFAULT_SAMPLE_RATES)
    # Qt widgets need a QApplication; this function may be called before
    # select_settings() has created one.  Keep a reference while prompting.
    _app = _ensure_app()

    prompt = (
        "Enter an input sample rate in Hz.\n"
        f"Common values: {', '.join(str(r) for r in rates)}"
    )
    while True:
        raw, ok = QInputDialog.getText(None, "Select Sample Rate", prompt)
        if not ok:
            raise sd.PortAudioError("No supported input sample rate found for selected device.")
        raw = raw.strip()
        if not raw:
            continue

        try:
            rate = int(float(raw))
        except (ValueError, OverflowError):
            # OverflowError: int(float("inf")) and friends.
            QMessageBox.warning(None, "Invalid value", "Sample rate must be a number.")
            continue

        try:
            sd.check_input_settings(device=device_index, channels=1, samplerate=rate, dtype="float32")
        except sd.PortAudioError:
            QMessageBox.warning(
                None,
                "Unsupported sample rate",
                f"{rate} Hz is not supported by the selected device.",
            )
        else:
            return rate
