aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--routes.py36
-rw-r--r--server.py22
2 files changed, 40 insertions, 18 deletions
diff --git a/routes.py b/routes.py
new file mode 100644
index 0000000..2a51766
--- /dev/null
+++ b/routes.py
@@ -0,0 +1,36 @@
+from typing import Callable, Generator, Optional
+
+from typing_extensions import Any
+from flask import Blueprint, Response, stream_with_context, Flask
+
+SSEGenerator = Callable[[], Generator[str, None, None]]
+
+
+def make_sse_blueprint(
+ event_stream_func: SSEGenerator,
+ blueprint_name: str = "sse_routes",
+ url_prefix: Optional[str] = None,
+) -> Blueprint:
+ bp = Blueprint(blueprint_name, __name__, url_prefix=url_prefix)
+
+ @bp.get("/events")
+ def events() -> Response:
+ headers = {
+ "Cache-Control": "no-cache",
+ "Connection": "keep-alive",
+ "Access-Control-Allow-Origin": "*",
+ }
+ return Response(stream_with_context(event_stream_func()), mimetype="text/event-stream", headers=headers)
+
+ @bp.get("/health")
+ def health() -> Response:
+ resp = Response("ok", mimetype="text/plain")
+ resp.headers["Access-Control-Allow-Origin"] = "*"
+ return resp
+
+ return bp
+
+
+def register_routes(app: Flask, event_stream_func: Any, url_prefix: Optional[str] = None) -> None:
+ bp = make_sse_blueprint(event_stream_func, url_prefix=url_prefix)
+ app.register_blueprint(bp)
diff --git a/server.py b/server.py
index f3038eb..35916b7 100644
--- a/server.py
+++ b/server.py
@@ -4,12 +4,13 @@ import json
import queue
import os
from typing import Any, Dict, Optional, Set, List, Iterator
-from flask import Flask, Response, stream_with_context
+from flask import Flask
from flask_cors import CORS
import numpy as np
import sounddevice as sd
from faster_whisper import WhisperModel
from gui import select_settings, prompt_input_sample_rate
+from routes import register_routes
TARGET_SAMPLE_RATE: int = 16000
CAPTURE_SAMPLE_RATE: int = 0
@@ -143,24 +144,9 @@ def event_stream() -> Iterator[str]:
clients.discard(client_queue)
-@app.get("/events")
-def events() -> Response:
- headers = {
- "Cache-Control": "no-cache",
- "Connection": "keep-alive",
- "Access-Control-Allow-Origin": "*",
- }
- return Response(stream_with_context(event_stream()), mimetype="text/event-stream", headers=headers)
-
-
-@app.get("/health")
-def health() -> Response:
- response = Response("ok", mimetype="text/plain")
- response.headers["Access-Control-Allow-Origin"] = "*"
- return response
-
-
def start_subtitle_server() -> threading.Thread:
+ register_routes(app, event_stream)
+
thread = threading.Thread(
target=lambda: app.run(
host=SERVER_HOST,
send patches to the email below
yukais@pinapelz.com
include the subject [PATCH repo_name]
pinapelz.com
homepage