File size: 3,870 Bytes
7ba2f95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7271127
7ba2f95
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7271127
7ba2f95
 
 
 
 
 
 
 
 
7271127
7ba2f95
 
 
 
 
 
 
 
 
 
 
 
 
7271127
 
7ba2f95
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
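"""Optional Supabase persistence for predictions.

The API works fine without credentials. All functions here degrade gracefully
in three cases: ``SUPABASE_URL`` / ``SUPABASE_KEY`` are missing, the
``supabase`` package is unavailable, or the client cannot be created. In those
cases writes become no-ops and reads return empty lists instead of raising.
"""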

from __future__ import annotations

import os
from collections.abc import Mapping
from functools import lru_cache
from typing import Any

from src.utils.logger import get_logger

logger = get_logger(__name__)

try:
    from supabase import Client, create_client
except ImportError:  # pragma: no cover - dep listed in pyproject
    Client = None  # type: ignore[assignment,misc]
    create_client = None  # type: ignore[assignment]


# Name of the Supabase table that predictions are inserted into
# (save_prediction) and read back from (list_predictions).
_TABLE = "predictions"


@lru_cache(maxsize=1)
def get_client() -> "Client | None":
    """Build the Supabase client once and memoise it.

    ``SUPABASE_URL`` and ``SUPABASE_KEY`` are read from the environment with
    surrounding whitespace stripped. ``None`` is returned when:

    * either variable is unset or blank;
    * the ``supabase`` package could not be imported;
    * constructing the client raises.

    The package and construction cases log a warning. Because the result is
    cached, a ``None`` persists until ``get_client.cache_clear()`` is called.
    """
    url, key = (
        os.getenv(var_name, "").strip()
        for var_name in ("SUPABASE_URL", "SUPABASE_KEY")
    )
    if not (url and key):
        return None

    if create_client is None:
        logger.warning("supabase package not available; persistence disabled")
        return None

    try:
        supabase_client = create_client(url, key)
        logger.info("Supabase client initialized")
    except Exception as exc:  # pragma: no cover - network/config errors
        logger.warning("Failed to initialize Supabase client: %s", exc)
        return None
    return supabase_client


def _to_native(value: Any) -> Any:
    """Unwrap array-library scalars (e.g. NumPy, PyTorch) into plain Python values.

    Any value exposing a callable ``item()`` is unwrapped; everything else,
    including ``None``, is returned unchanged. Such scalars (e.g.
    ``numpy.float32``) are typically not serializable by the stdlib JSON
    encoder. Without this step the insert payload could fail to encode and
    the row would be dropped.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def _result_to_dict(result: Any) -> dict[str, Any]:
    """Extract the prediction fields from ``result`` as a plain dict.

    Accepts a Pydantic model (anything exposing ``model_dump()``), any
    ``Mapping``, or an arbitrary object whose attributes are read directly.
    """
    if hasattr(result, "model_dump"):
        return result.model_dump()
    if isinstance(result, Mapping):
        # Accept any mapping type, not only ``dict``, so its fields are not lost.
        return dict(result)
    return {
        "probability": getattr(result, "probability", None),
        "is_toxic": getattr(result, "is_toxic", None),
        "labels": getattr(result, "labels", []),
        "model_used": getattr(result, "model_used", ""),
        "latency_ms": getattr(result, "latency_ms", None),
    }


def save_prediction(
    text: str,
    result: Any,
    source: str,
    video_id: str | None = None,
    video_url: str | None = None,
    threshold: float | None = None,
    latency_ms: float | None = None,
    author: str | None = None,
) -> None:
    """Persist a single prediction, silently no-op when DB is not configured.

    ``result`` may be a Pydantic ``PredictResponse``, any mapping, or an object
    with the same fields (``probability``, ``is_toxic``, ``labels``,
    ``model_used``, ``latency_ms``).

    Args:
        text: The input text that was classified.
        result: The prediction to store (see above).
        source: Free-form origin tag stored in the ``source`` column.
        video_id: Optional video identifier stored with the row.
        video_url: Optional video URL stored with the row.
        threshold: Optional decision threshold stored with the row.
        latency_ms: Overrides ``result``'s ``latency_ms`` when not ``None``.
        author: Optional author stored with the row.

    Persistence is best-effort. Any failure while building or inserting the
    row is logged as a warning and never raised to the caller.
    """
    client = get_client()
    if client is None:
        return

    try:
        data = _result_to_dict(result)
        if latency_ms is None:
            latency_ms = data.get("latency_ms")

        row = {
            "text": text,
            "video_id": video_id,
            "video_url": video_url,
            "probability": _to_native(data.get("probability")),
            "is_toxic": _to_native(data.get("is_toxic")),
            "labels": data.get("labels", []) or [],
            "model_used": data.get("model_used", ""),
            "threshold": _to_native(threshold),
            "latency_ms": _to_native(latency_ms),
            "source": source,
            "author": author,
        }
        client.table(_TABLE).insert(row).execute()
    except Exception as exc:
        logger.warning("save_prediction failed (non-critical): %s", exc)


def list_predictions(
    video_id: str | None = None,
    limit: int = 50,
    source: str | None = None,
) -> list[dict]:
    """Fetch the most recent predictions, newest ``created_at`` first.

    Args:
        video_id: When truthy, only rows with this ``video_id`` are returned.
        limit: Requested row count, clamped into the range ``[1, 200]``.
        source: When truthy, only rows with this ``source`` are returned.

    Returns:
        The matching rows. ``[]`` is returned when no client is configured,
        or when the query fails (the failure is logged as a warning).
    """
    client = get_client()
    if client is None:
        return []

    try:
        builder = client.table(_TABLE).select("*").order("created_at", desc=True)
        # Equality filters are applied in a fixed order; falsy values are skipped.
        for column, wanted in (("video_id", video_id), ("source", source)):
            if wanted:
                builder = builder.eq(column, wanted)
        capped = max(1, min(limit, 200))
        response = builder.limit(capped).execute()
        rows = getattr(response, "data", []) or []
        return list(rows)
    except Exception as exc:
        logger.warning("list_predictions failed: %s", exc)
        return []