402 lines
16 KiB
Python
402 lines
16 KiB
Python
import asyncio
import html
import json
import math
import secrets
import threading
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path

import bcrypt
from aiohttp import web

from config import AppConfig
from state import MessageState
|
|
|
|
# Folder containing the dashboard's index.html and static assets. It is
# located next to this module, independent of the process working directory.
DASHBOARD_DIR = Path(__file__).parent.joinpath("dashboard")

# Cookie name under which the dashboard session ID is stored.
SESSION_COOKIE = "kh_session"
|
|
|
|
|
|
# ── Session store ─────────────────────────────────────────────────────────────
|
|
|
|
@dataclass
class _Session:
    """Bookkeeping for a single dashboard login session."""

    # time.monotonic() reading taken at creation and refreshed on every
    # successful validation. A monotonic clock is used so wall-clock changes
    # (e.g. an NTP sync after boot) cannot expire or prolong sessions.
    last_active: float = field(default_factory=time.monotonic)


class SessionStore:
    """Thread-safe in-memory registry of session IDs with an idle timeout.

    A session remains valid while it is validated at least once per
    ``timeout_hours``; each successful ``validate`` restarts its idle clock.
    """

    def __init__(self, timeout_hours: int):
        """
        Args:
            timeout_hours: idle time, in hours, after which a session expires.
        """
        self._sessions: dict[str, _Session] = {}
        self._lock = threading.Lock()
        self._timeout = timeout_hours * 3600.0

    def _is_expired(self, session: _Session, now: float) -> bool:
        """True when *session* has been idle longer than the timeout."""
        return now - session.last_active > self._timeout

    def _purge_expired(self, now: float) -> None:
        """Remove every expired session. Caller must hold ``self._lock``."""
        stale = [sid for sid, s in self._sessions.items() if self._is_expired(s, now)]
        for sid in stale:
            del self._sessions[sid]

    def create(self) -> str:
        """Register a new session and return its unguessable ID."""
        sid = secrets.token_urlsafe(32)
        now = time.monotonic()
        with self._lock:
            # Sessions whose cookie is never presented again (browser closed,
            # cookies cleared) would otherwise stay in memory forever, since
            # validate() only removes the session it is asked about.
            self._purge_expired(now)
            self._sessions[sid] = _Session(last_active=now)
        return sid

    def validate(self, sid: str) -> bool:
        """Return True if *sid* is a live session, refreshing its idle clock.

        An expired session is removed and reported as invalid.
        """
        with self._lock:
            session = self._sessions.get(sid)
            if session is None:
                return False
            now = time.monotonic()
            if self._is_expired(session, now):
                del self._sessions[sid]
                return False
            session.last_active = now
            return True

    def delete(self, sid: str) -> None:
        """Forget *sid*; unknown IDs are ignored."""
        with self._lock:
            self._sessions.pop(sid, None)
|
|
|
|
|
|
# ── Rate limiter ──────────────────────────────────────────────────────────────
|
|
|
|
class RateLimiter:
    """Sliding-window limiter allowing at most ``max_per_minute`` calls per key
    within any 60-second window. Safe to call from multiple threads.
    """

    _WINDOW = 60.0  # seconds

    def __init__(self, max_per_minute: int):
        """
        Args:
            max_per_minute: allowed calls per key per window; a value <= 0
                rejects every call.
        """
        self._max = max_per_minute
        # key -> timestamps (time.monotonic()) of allowed calls, oldest first.
        self._data: dict[str, list[float]] = {}
        self._lock = threading.Lock()
        self._last_sweep = time.monotonic()

    def _sweep(self, cutoff: float) -> None:
        """Drop keys with no timestamp newer than *cutoff*. Caller holds the lock.

        Without this, every distinct IP ever seen would stay in ``_data``
        forever. Timestamps are appended in monotonic order, so ``ts[-1]``
        is the newest one.
        """
        self._data = {key: ts for key, ts in self._data.items() if ts and ts[-1] > cutoff}

    def is_allowed(self, ip: str) -> bool:
        """Record a call from *ip* and return True if it is within the limit.

        Rejected calls are not recorded, so they do not extend the lockout.
        """
        now = time.monotonic()
        cutoff = now - self._WINDOW
        with self._lock:
            if now - self._last_sweep >= self._WINDOW:
                self._sweep(cutoff)
                self._last_sweep = now
            recent = [t for t in self._data.get(ip, ()) if t > cutoff]
            allowed = len(recent) < self._max
            if allowed:
                recent.append(now)
            if recent:
                self._data[ip] = recent
            else:
                self._data.pop(ip, None)
            return allowed
|
|
|
|
|
|
# ── Inline HTML pages ─────────────────────────────────────────────────────────
|
|
|
|
_PAGE_STYLE = """
|
|
* { box-sizing: border-box; margin: 0; padding: 0; }
|
|
body {
|
|
background: #111; color: #f0f0f0;
|
|
font-family: system-ui, -apple-system, sans-serif;
|
|
display: flex; justify-content: center; align-items: center;
|
|
min-height: 100vh; padding: 1rem;
|
|
}
|
|
.card {
|
|
background: #1a1a1a; border: 1px solid #333; border-radius: 10px;
|
|
padding: 2rem; width: 100%; max-width: 420px;
|
|
}
|
|
h1 { font-size: 1.5rem; margin-bottom: 0.5rem; text-align: center; }
|
|
.subtitle { color: #888; text-align: center; margin-bottom: 1.75rem; font-size: 0.9rem; }
|
|
label { display: block; margin-bottom: 0.4rem; color: #aaa; font-size: 0.9rem; }
|
|
input[type=password] {
|
|
width: 100%; padding: 0.75rem; border-radius: 6px;
|
|
border: 1px solid #444; background: #242424; color: #f0f0f0;
|
|
font-size: 1.05rem; margin-bottom: 1.1rem;
|
|
}
|
|
input[type=password]:focus { outline: none; border-color: #2563eb; }
|
|
.btn {
|
|
width: 100%; padding: 0.875rem; border: none; border-radius: 8px;
|
|
font-size: 1.05rem; font-weight: 600; cursor: pointer;
|
|
}
|
|
.btn-blue { background: #2563eb; color: #fff; }
|
|
.btn-blue:hover { background: #1d4ed8; }
|
|
.btn-green { background: #16a34a; color: #fff; }
|
|
.btn-green:hover { background: #15803d; }
|
|
.error { color: #f87171; margin-bottom: 1rem; text-align: center; font-size: 0.95rem; }
|
|
"""
|
|
|
|
|
|
def _login_page(error: str = "") -> str:
    """Render the sign-in page as a complete HTML document.

    Args:
        error: optional message shown above the form. It is HTML-escaped, so
            no caller can inject markup through it, even with text derived
            from a request.

    Returns:
        The HTML page; the form POSTs the ``password`` field to ``/login``.
    """
    err = f'<p class="error">{html.escape(error)}</p>' if error else ""
    return f"""<!doctype html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>KH Clock — Sign In</title>
  <style>{_PAGE_STYLE}</style>
</head>
<body>
  <div class="card">
    <h1>KH Clock</h1>
    <p class="subtitle">Sign in to access the dashboard.</p>
    {err}
    <form method="post" action="/login">
      <label for="pw">Password</label>
      <input type="password" id="pw" name="password" autofocus autocomplete="current-password">
      <button type="submit" class="btn btn-blue">Sign In</button>
    </form>
  </div>
</body>
</html>"""
|
|
|
|
|
|
def _setup_page(error: str = "") -> str:
    """Render the first-run page on which the user creates the dashboard password.

    Args:
        error: optional message shown above the form, HTML-escaped.

    Returns:
        The HTML page; the form POSTs ``password`` and ``confirm`` to ``/setup``.
    """
    err = f'<p class="error">{html.escape(error)}</p>' if error else ""
    return f"""<!doctype html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <title>KH Clock — First Run Setup</title>
  <style>{_PAGE_STYLE}</style>
</head>
<body>
  <div class="card">
    <h1>First Run Setup</h1>
    <p class="subtitle">Create a password to protect the dashboard.</p>
    {err}
    <form method="post" action="/setup">
      <label for="pw">Password (6+ characters)</label>
      <input type="password" id="pw" name="password" autofocus minlength="6" autocomplete="new-password">
      <label for="pw2">Confirm Password</label>
      <input type="password" id="pw2" name="confirm" autocomplete="new-password">
      <button type="submit" class="btn btn-green">Set Password &amp; Continue</button>
    </form>
  </div>
</body>
</html>"""
|
|
|
|
|
|
# ── Message body parsing ──────────────────────────────────────────────────────
|
|
|
|
def _parse_message_body(body: dict, default_duration: float) -> tuple[str, float | None]:
|
|
"""
|
|
Returns (text, duration_seconds).
|
|
duration=None means persistent.
|
|
"""
|
|
text = body.get("text", "").strip()
|
|
if body.get("persist") or body.get("duration") == 0:
|
|
duration: float | None = None
|
|
elif "duration" in body:
|
|
duration = float(body["duration"])
|
|
else:
|
|
duration = default_duration
|
|
return text, duration
|
|
|
|
|
|
# ── Server ────────────────────────────────────────────────────────────────────
|
|
|
|
class ClockServer:
    """aiohttp application serving the KH Clock dashboard and API.

    Route groups:
      * /setup, /login, /logout — first-run password creation and sign-in;
      * /, /dashboard/*         — dashboard, protected by a session cookie;
      * /api/*                  — API, protected by a bearer token
                                  (set/clear are also rate limited per IP).
    Message changes are forwarded to the shared ``MessageState``.
    """

    def __init__(self, state: MessageState, config: AppConfig):
        self.state = state
        self.config = config
        self.sessions = SessionStore(config.session_timeout_hours)
        self.rate_limiter = RateLimiter(config.rate_limit)
        # Separate bucket so API traffic from a host cannot lock its user out
        # of signing in, and sign-in attempts cannot eat the API budget.
        self.login_limiter = RateLimiter(config.rate_limit)
        self.app = self._build_app()

    def _build_app(self) -> web.Application:
        app = web.Application()

        # Static dashboard assets (CSS, JS) — no auth required
        app.router.add_static("/static", DASHBOARD_DIR, show_index=False)

        # First-run setup
        app.router.add_get("/setup", self._handle_setup_get)
        app.router.add_post("/setup", self._handle_setup_post)

        # Authentication
        app.router.add_get("/login", self._handle_login_get)
        app.router.add_post("/login", self._handle_login_post)
        app.router.add_post("/logout", self._handle_logout)

        # Dashboard (session-protected)
        app.router.add_get("/", self._handle_root)
        app.router.add_post("/dashboard/message", self._handle_dashboard_set)
        app.router.add_delete("/dashboard/message", self._handle_dashboard_clear)
        app.router.add_get("/dashboard/status", self._handle_dashboard_status)

        # API (bearer token + rate limit)
        app.router.add_post("/api/message", self._handle_api_set)
        app.router.add_delete("/api/message", self._handle_api_clear)
        app.router.add_get("/api/status", self._handle_api_status)

        return app

    # ── Response helpers ──────────────────────────────────────────────────────

    @staticmethod
    def _redirect(location: str) -> web.Response:
        """Return a plain 302 redirect to *location*.

        Returning ``web.HTTPFound`` from a handler is deprecated in aiohttp; a
        regular Response still supports set_cookie/del_cookie, which the
        login and logout handlers need.
        """
        return web.Response(status=302, headers={"Location": location})

    @staticmethod
    def _unauthorized() -> web.Response:
        return web.Response(status=401, text="Unauthorized")

    @staticmethod
    def _json(payload: object) -> web.Response:
        return web.Response(content_type="application/json", text=json.dumps(payload))

    @staticmethod
    def _form_str(data, name: str) -> str:
        """Fetch a text form field from ``request.post()`` data ("" if absent).

        Multipart submissions can yield a file field instead of a str; such
        values are treated as empty rather than crashing on ``.encode()``.
        """
        value = data.get(name, "")
        return value if isinstance(value, str) else ""

    # ── Auth helpers ──────────────────────────────────────────────────────────

    def _setup_needed(self) -> bool:
        return not self.config.password_hash

    def _check_session(self, request: web.Request) -> bool:
        sid = request.cookies.get(SESSION_COOKIE)
        return bool(sid and self.sessions.validate(sid))

    def _session_redirect(self) -> web.Response:
        """Redirect to setup or login depending on configuration state."""
        return self._redirect("/setup" if self._setup_needed() else "/login")

    def _set_session_cookie(self, response: web.Response, sid: str) -> None:
        response.set_cookie(
            SESSION_COOKIE,
            sid,
            httponly=True,
            samesite="Lax",
            max_age=self.config.session_timeout_hours * 3600,
        )

    def _require_bearer(self, request: web.Request) -> web.Response | None:
        """Return 401 if bearer token is missing or wrong, else None."""
        scheme, _, token = request.headers.get("Authorization", "").partition(" ")
        token = token.strip()
        expected = self.config.api_token or ""
        # The auth scheme is case-insensitive (RFC 7235). An empty presented or
        # configured token must never authenticate.
        if scheme.lower() != "bearer" or not token or not expected:
            return self._unauthorized()
        # Compare bytes: compare_digest raises TypeError on non-ASCII str.
        # Header text may carry surrogate escapes, so encode with that handler.
        presented = token.encode("utf-8", "surrogateescape")
        if not secrets.compare_digest(presented, expected.encode("utf-8", "surrogateescape")):
            return self._unauthorized()
        return None

    def _require_rate_limit(self, request: web.Request) -> web.Response | None:
        """Return 429 if this IP has exceeded the rate limit, else None."""
        ip = request.remote or "unknown"
        if not self.rate_limiter.is_allowed(ip):
            return web.Response(status=429, text="Too Many Requests")
        return None

    @staticmethod
    async def _hash_password(password: str) -> str:
        """bcrypt-hash *password* in a worker thread.

        bcrypt is deliberately slow; running it inline would stall the event
        loop and every other request.
        """
        def work() -> str:
            return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()

        return await asyncio.get_running_loop().run_in_executor(None, work)

    async def _password_matches(self, password: str) -> bool:
        """Check *password* against the stored hash in a worker thread."""
        stored = self.config.password_hash

        def work() -> bool:
            try:
                return bcrypt.checkpw(password.encode(), stored.encode())
            except ValueError:
                # Malformed stored hash, or a password bcrypt refuses (some
                # versions reject inputs over 72 bytes): a failed login, not a 500.
                return False

        return await asyncio.get_running_loop().run_in_executor(None, work)

    # ── Setup handlers ────────────────────────────────────────────────────────

    async def _handle_setup_get(self, request: web.Request) -> web.Response:
        if not self._setup_needed():
            return self._redirect("/")
        return web.Response(content_type="text/html", text=_setup_page())

    async def _handle_setup_post(self, request: web.Request) -> web.Response:
        if not self._setup_needed():
            return self._redirect("/")

        data = await request.post()
        password = self._form_str(data, "password")
        confirm = self._form_str(data, "confirm")

        if len(password) < 6:
            error = "Password must be at least 6 characters."
        elif len(password.encode()) > 72:
            # bcrypt only uses the first 72 bytes (newer releases raise instead).
            error = "Password is too long (maximum 72 bytes)."
        elif password != confirm:
            error = "Passwords do not match."
        else:
            error = ""
        if error:
            return web.Response(content_type="text/html", text=_setup_page(error))

        hashed = await self._hash_password(password)
        # Another request may have completed setup while we were hashing;
        # never overwrite a password that is already in place.
        if not self._setup_needed():
            return self._redirect("/")
        self.config.save_password_hash(hashed)
        return self._redirect("/login")

    # ── Login / logout handlers ───────────────────────────────────────────────

    async def _handle_login_get(self, request: web.Request) -> web.Response:
        if self._setup_needed():
            return self._redirect("/setup")
        return web.Response(content_type="text/html", text=_login_page())

    async def _handle_login_post(self, request: web.Request) -> web.Response:
        if self._setup_needed():
            return self._redirect("/setup")

        # Throttle password guessing per client IP.
        if not self.login_limiter.is_allowed(request.remote or "unknown"):
            return web.Response(
                status=429,
                content_type="text/html",
                text=_login_page("Too many sign-in attempts. Please wait a minute and try again."),
            )

        data = await request.post()
        password = self._form_str(data, "password")

        if not await self._password_matches(password):
            return web.Response(
                content_type="text/html",
                text=_login_page("Incorrect password."),
            )

        sid = self.sessions.create()
        response = self._redirect("/")
        self._set_session_cookie(response, sid)
        return response

    async def _handle_logout(self, request: web.Request) -> web.Response:
        sid = request.cookies.get(SESSION_COOKIE)
        if sid:
            self.sessions.delete(sid)
        response = self._redirect("/login")
        response.del_cookie(SESSION_COOKIE)
        return response

    # ── Message operations shared by dashboard and API ────────────────────────

    async def _set_message(self, request: web.Request) -> web.Response:
        """Parse a JSON message request and apply it to the shared state."""
        try:
            body = await request.json()
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            # Other errors (e.g. aiohttp's 413 for oversized bodies) propagate.
            return web.Response(status=400, text="Invalid JSON")
        if not isinstance(body, dict):
            return web.Response(status=400, text="JSON body must be an object")
        try:
            text, duration = _parse_message_body(body, self.config.default_duration)
        except (TypeError, ValueError) as exc:
            return web.Response(status=400, text=f"Invalid message: {exc}")
        if not text:
            return web.Response(status=400, text="Missing or empty 'text' field")
        self.state.set(text, duration)
        return self._json({"ok": True})

    def _clear_message(self) -> web.Response:
        self.state.clear()
        return self._json({"ok": True})

    def _status(self) -> web.Response:
        return self._json(self.state.to_dict())

    # ── Dashboard handlers ────────────────────────────────────────────────────

    async def _handle_root(self, request: web.Request) -> web.Response:
        if not self._check_session(request):
            return self._session_redirect()
        return web.FileResponse(DASHBOARD_DIR / "index.html")

    async def _handle_dashboard_set(self, request: web.Request) -> web.Response:
        if not self._check_session(request):
            return self._unauthorized()
        return await self._set_message(request)

    async def _handle_dashboard_clear(self, request: web.Request) -> web.Response:
        if not self._check_session(request):
            return self._unauthorized()
        return self._clear_message()

    async def _handle_dashboard_status(self, request: web.Request) -> web.Response:
        if not self._check_session(request):
            return self._unauthorized()
        return self._status()

    # ── API handlers ──────────────────────────────────────────────────────────

    async def _handle_api_set(self, request: web.Request) -> web.Response:
        if err := self._require_bearer(request):
            return err
        if err := self._require_rate_limit(request):
            return err
        return await self._set_message(request)

    async def _handle_api_clear(self, request: web.Request) -> web.Response:
        if err := self._require_bearer(request):
            return err
        if err := self._require_rate_limit(request):
            return err
        return self._clear_message()

    async def _handle_api_status(self, request: web.Request) -> web.Response:
        # Token-checked but, unlike set/clear, not rate limited.
        if err := self._require_bearer(request):
            return err
        return self._status()
|
|
|
|
|
|
# ── Entry point ───────────────────────────────────────────────────────────────
|
|
|
|
async def run_server(state: MessageState, config: AppConfig, host: str = "0.0.0.0") -> None:
    """Serve the dashboard and API on ``host:config.port`` until cancelled.

    Args:
        state: shared message state the handlers read and update.
        config: application configuration (port, auth settings, limits).
        host: interface to bind; the default listens on all IPv4 interfaces.

    The runner is always cleaned up, even on cancellation or error, so the
    listening socket and any open connections are released.
    """
    server = ClockServer(state, config)
    runner = web.AppRunner(server.app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, host, config.port)
        await site.start()

        print(f"[KH-clock] Dashboard: http://{host}:{config.port}")
        print(f"[KH-clock] API:       http://{host}:{config.port}/api/")

        await asyncio.Event().wait()  # run forever (until cancelled)
    finally:
        await runner.cleanup()
|