diff --git a/src/groundlight/__init__.py b/src/groundlight/__init__.py index 805fdd33..baf66fd3 100644 --- a/src/groundlight/__init__.py +++ b/src/groundlight/__init__.py @@ -7,7 +7,7 @@ # Imports from our code from .client import Groundlight -from .client import GroundlightClientError, ApiTokenError, EdgeNotAvailableError, NotFoundError +from .client import GroundlightClientError, ApiTokenError, EdgeNotAvailableError, NotFoundError, VLMVerificationResult from .experimental_api import ExperimentalApi from .binary_labels import Label from .version import get_version diff --git a/src/groundlight/cli.py b/src/groundlight/cli.py index 82fa5fa9..874037e0 100644 --- a/src/groundlight/cli.py +++ b/src/groundlight/cli.py @@ -180,6 +180,7 @@ def wrapper(*args, **kwargs): "Image Queries", "ML Pipelines & Priming", "Notes", + "VLM Verification", "Utilities", ] @@ -232,6 +233,8 @@ def wrapper(*args, **kwargs): "create_priming_group": "ML Pipelines & Priming", "get_priming_group": "ML Pipelines & Priming", "delete_priming_group": "ML Pipelines & Priming", + # VLM Verification + "ask_vlm": "VLM Verification", # Utilities "edge_base_url": "Utilities", "get_raw_headers": "Utilities", diff --git a/src/groundlight/client.py b/src/groundlight/client.py index edcb8771..1d5946e3 100644 --- a/src/groundlight/client.py +++ b/src/groundlight/client.py @@ -1,12 +1,14 @@ # pylint: disable=too-many-lines import logging import os +import re import time import warnings from functools import partial from io import BufferedReader, BytesIO from typing import Any, Callable, List, Optional, Tuple, Union +import requests from groundlight_openapi_client import Configuration from groundlight_openapi_client.api.detector_groups_api import DetectorGroupsApi from groundlight_openapi_client.api.detectors_api import DetectorsApi @@ -37,6 +39,7 @@ PaginatedDetectorList, PaginatedImageQueryList, ) +from pydantic import BaseModel from urllib3.exceptions import InsecureRequestWarning from urllib3.util.retry 
MAX_VLM_MEDIA_ITEMS = 8  # Upper bound on the number of images accepted by one ask_vlm call.


class VLMVerificationResult(BaseModel):
    """Result of a VLM-based alert verification via the Groundlight cloud API."""

    id: str
    query: str
    model_id: str
    verdict: str  # "YES" | "NO" | "UNSURE"
    confidence: float  # 0.0–1.0
    reasoning: str
    created_at: str
    # Token/cost accounting; left as None when the response carries no "cost" block.
    input_tokens: Optional[int] = None
    output_tokens: Optional[int] = None
    total_cost_usd: Optional[float] = None


# NOTE: ``ask_vlm`` is a method of ``Groundlight``. It is shown dedented here and must be
# indented one level when placed in the class body.
def ask_vlm(  # pylint: disable=too-many-locals
    self,
    media: Union[
        np.ndarray,
        str,
        bytes,
        Image.Image,
        BytesIO,
        BufferedReader,
        List[Union[np.ndarray, str, bytes, Image.Image, BytesIO, BufferedReader]],
        Tuple[Union[np.ndarray, str, bytes, Image.Image, BytesIO, BufferedReader], ...],
    ],
    query: str,
    model_id: Optional[str] = None,
    timeout: float = 15.0,
) -> VLMVerificationResult:
    """Verify one or more images against a natural-language query using a cloud VLM.

    Calls the Groundlight ``POST /v1/vlm-verifications`` endpoint. The VLM runs in the
    Groundlight cloud (AWS Bedrock) — no local inference.

    The server makes no assumptions about what the images are — your ``query`` should
    describe them. Images are presented to the model labeled ``Image 1``, ``Image 2``,
    ... in the order given, so the query can refer to them.

    **Example usage**::

        gl = Groundlight()

        # Single image
        result = gl.ask_vlm(frame, query="Is there a fire in this image?")
        if result.verdict == "YES":
            emit_alert()

        # Full frame + cropped ROI — describe each in the query
        result = gl.ask_vlm(
            media=[full_frame, roi_crop],
            query="Image 1 is the full camera frame; image 2 is the cropped region "
                  "a detector flagged. Is there really a fire?",
        )
        print(result.confidence, result.reasoning)

    :param media: One image, or a list/tuple of up to ``MAX_VLM_MEDIA_ITEMS`` (8) images.
        Accepted formats per image:

        - filename (string) of a JPEG or PNG file (``".jpg"``, ``".jpeg"``, ``".png"``)
        - raw bytes, BytesIO, or BufferedReader — sent as-is; the server decodes and
          normalises to JPEG regardless of the declared content type, so PNG/WEBP bytes
          all work
        - numpy array (H, W, 3) in BGR order (OpenCV convention) — converted to JPEG
          before sending
        - PIL Image — converted to JPEG before sending

    :param query: Natural-language prompt describing the media and what to verify,
        e.g. ``"Is there a fire visible in the image? Reason step by step."``
    :param model_id: Friendly alias of the VLM to use. The server is the source
        of truth; passing an unrecognised alias returns HTTP 400. Currently
        supported aliases:

        - ``"gpt-5.4"`` — OpenAI GPT-5.4 via Bedrock Responses API (default)
        - ``"claude-sonnet-4.5"`` — Anthropic Claude Sonnet 4.5
        - ``"claude-haiku-3"`` — Anthropic Claude Haiku 3
        - ``"nova-pro"`` — Amazon Nova Pro
        - ``"nova-lite"`` — Amazon Nova Lite
        - ``"llama3.2-90b"`` — Meta Llama 3.2 90B
        - ``"llama3.2-11b"`` — Meta Llama 3.2 11B

        Omit to use the server-configured default (currently ``"gpt-5.4"``).
    :param timeout: Request timeout in seconds (default 15 s).

    :return: :class:`VLMVerificationResult` with ``verdict`` (``"YES"`` / ``"NO"`` /
        ``"UNSURE"``), ``confidence``, ``reasoning``, and token cost fields. Fields that
        are missing or ``null`` in the response fall back to neutral defaults
        (``"UNSURE"``, ``0.0``, empty strings, ``None`` for the cost fields).

    :raises ValueError: If zero or more than ``MAX_VLM_MEDIA_ITEMS`` (8) images are supplied,
        or if the response body is valid JSON but not a JSON object.
    :raises requests.HTTPError: On non-2xx response (400 for invalid model alias
        or undecodable image bytes; 502 if the upstream VLM is unavailable).
    """
    # Normalise to a fresh list without rebinding the caller-visible parameter. Tuples are
    # treated as sequences of images rather than being wrapped as one (unparseable) item.
    items = list(media) if isinstance(media, (list, tuple)) else [media]
    if not items:
        raise ValueError("ask_vlm requires at least one media item.")
    if len(items) > MAX_VLM_MEDIA_ITEMS:
        raise ValueError(f"ask_vlm supports at most {MAX_VLM_MEDIA_ITEMS} media items.")

    # Encode each item. numpy/PIL → JPEG; bytes/BytesIO/BufferedReader → pass through
    # (server calls ensure_jpeg_format and validates by decoding, so any common format works).
    # The part is always labelled image/jpeg even for pass-through bytes; see the docstring.
    # NOTE(review): the stream returned by parse_supported_image_types is never closed here;
    # confirm whether filename inputs leave an open file handle behind.
    media_files: List[Tuple[str, Tuple[str, bytes, str]]] = [
        ("media", (f"image_{index}.jpg", parse_supported_image_types(item).read(), "image/jpeg"))
        for index, item in enumerate(items)
    ]

    # query and model_id are sent as multipart form fields (not query-string
    # params): the prompt can be long and must not end up in URLs or access logs.
    form_data = {"query": query}
    if model_id:
        form_data["model_id"] = model_id

    headers = {
        "x-api-token": self.api_client.configuration.api_key["ApiToken"],
        "X-Request-Id": f"ask_vlm_{time.time_ns()}",
        "x-sdk-language": "python",
    }

    # The endpoint may already end with a version segment (e.g. ".../v1"). Strip it so we
    # never produce ".../v1/v1/...". Trailing slashes are removed first so ".../v1/" is
    # handled too and the path never contains "//".
    base = re.sub(r"/v\d+$", "", self.endpoint.rstrip("/"))
    url = f"{base}/v1/vlm-verifications"

    resp = requests.post(
        url,
        data=form_data,
        files=media_files,
        headers=headers,
        timeout=timeout,
        verify=self.api_client.configuration.verify_ssl,
    )
    resp.raise_for_status()

    payload = resp.json()
    if not isinstance(payload, dict):
        raise ValueError(
            f"Unexpected ask_vlm response body of type {type(payload).__name__}; expected a JSON object."
        )

    # `or {}` (rather than a .get default) also covers keys that are present but null.
    result_block = payload.get("result") or {}
    cost_block = payload.get("cost") or {}
    confidence = result_block.get("confidence")
    return VLMVerificationResult(
        id=payload.get("id") or "",
        query=payload.get("query") or query,
        model_id=payload.get("model_id") or model_id or "",
        verdict=result_block.get("verdict") or "UNSURE",
        confidence=0.0 if confidence is None else float(confidence),
        reasoning=result_block.get("reasoning") or "",
        created_at=payload.get("created_at") or "",
        input_tokens=cost_block.get("input_tokens"),
        output_tokens=cost_block.get("output_tokens"),
        total_cost_usd=cost_block.get("total_cost_usd"),
    )
_FAKE_JPEG = b"\xff\xd8\xff\xe0" + b"\x00" * 16


def _make_client(monkeypatch, endpoint: str) -> Groundlight:
    """Build a Groundlight client for *endpoint* with the connectivity check patched out."""
    monkeypatch.setenv("GROUNDLIGHT_API_TOKEN", "api_fake_test_token")
    with patch.object(Groundlight, "_verify_connectivity", return_value=None):
        return Groundlight(endpoint=endpoint)


@pytest.fixture(name="gl")
def groundlight_fixture(monkeypatch) -> Groundlight:
    """Client pointed at a fake, unversioned endpoint with a trailing slash."""
    return _make_client(monkeypatch, "http://test-server/device-api/")


def _mock_response(verdict="YES", confidence=0.92, reasoning="Flames visible.", model_id="gpt-5.4"):
    """Return a MagicMock standing in for a successful (201) ``requests`` response."""
    resp = MagicMock()
    resp.status_code = 201
    resp.json.return_value = {
        "id": "vlmv_test123",
        "type": "vlm_verification",
        "created_at": "2025-06-17T00:00:00Z",
        "query": "Is there a fire?",
        "model_id": model_id,
        "result": {"verdict": verdict, "confidence": confidence, "reasoning": reasoning},
        "cost": {"input_tokens": 400, "output_tokens": 80, "total_cost_usd": 0.0015},
    }
    resp.raise_for_status = MagicMock()
    return resp


def test_returns_vlm_verification_result(gl: Groundlight):
    """Result fields are correctly unpacked from the server response JSON."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        result = gl.ask_vlm(media=_FAKE_JPEG, query="Is there a fire?")

    assert isinstance(result, VLMVerificationResult)
    assert result.verdict == "YES"
    assert result.confidence == pytest.approx(0.92)
    assert result.id == "vlmv_test123"
    assert result.total_cost_usd == pytest.approx(0.0015)


@pytest.mark.skipif(MISSING_NUMPY, reason="Needs numpy")
def test_numpy_image_encoded_as_jpeg_multipart(gl: Groundlight):
    """A numpy array is converted to JPEG and sent as a multipart 'media' part."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        gl.ask_vlm(media=np.zeros((480, 640, 3), dtype=np.uint8), query="Is there a fire?")

    _, kwargs = mock_requests.post.call_args
    files = kwargs["files"]
    assert len(files) == 1
    assert files[0][0] == "media"
    _name, data, ctype = files[0][1]
    assert ctype == "image/jpeg"
    assert len(data) > 0


def test_query_sent_as_form_field_not_url_param(gl: Groundlight):
    """query and model_id go in the multipart body — never the URL — so the prompt
    doesn't leak into access logs."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response(model_id="nova-pro")
        gl.ask_vlm(media=_FAKE_JPEG, query="Is there a fire?", model_id="nova-pro")

    _, kwargs = mock_requests.post.call_args
    assert kwargs["data"]["query"] == "Is there a fire?"
    assert kwargs["data"]["model_id"] == "nova-pro"
    assert not kwargs.get("params")


def test_more_than_max_media_raises(gl: Groundlight):
    """Supplying more than MAX_VLM_MEDIA_ITEMS raises ValueError before any network call."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        with pytest.raises(ValueError, match=f"at most {MAX_VLM_MEDIA_ITEMS}"):
            gl.ask_vlm(media=[_FAKE_JPEG] * (MAX_VLM_MEDIA_ITEMS + 1), query="test")

    # Verify the "before any network call" half of the contract, not just the exception.
    mock_requests.post.assert_not_called()


def test_empty_media_list_raises(gl: Groundlight):
    """An empty media list is rejected with ValueError and no request is sent."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        with pytest.raises(ValueError, match="at least one media item"):
            gl.ask_vlm(media=[], query="test")

    mock_requests.post.assert_not_called()


def test_url_has_correct_path(gl: Groundlight):
    """sanitize_endpoint_url strips the trailing slash from self.endpoint, so the path
    must include a leading '/' — without it the URL becomes '...device-apiv1/...'."""
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        gl.ask_vlm(media=_FAKE_JPEG, query="test")

    args, _ = mock_requests.post.call_args
    assert "/device-api/v1/vlm-verifications" in args[0]


def test_url_no_version_duplication_for_versioned_endpoint(monkeypatch):
    """When the endpoint already ends with /v1 the URL must not contain /v1/v1/."""
    gl_v1 = _make_client(monkeypatch, "http://test-server/v1")
    with mock.patch("groundlight.client.requests") as mock_requests:
        mock_requests.post.return_value = _mock_response()
        gl_v1.ask_vlm(media=_FAKE_JPEG, query="test")

    args, _ = mock_requests.post.call_args
    url = args[0]
    assert "/v1/v1/" not in url
    assert url.endswith("/v1/vlm-verifications")