Add ask_vlm method for cloud VLM alert verification #442
base: main
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,12 +1,14 @@ | ||
| # pylint: disable=too-many-lines | ||
| import logging | ||
| import os | ||
| import re | ||
| import time | ||
| import warnings | ||
| from functools import partial | ||
| from io import BufferedReader, BytesIO | ||
| from typing import Any, Callable, List, Optional, Tuple, Union | ||
|
|
||
| import requests | ||
| from groundlight_openapi_client import Configuration | ||
| from groundlight_openapi_client.api.detector_groups_api import DetectorGroupsApi | ||
| from groundlight_openapi_client.api.detectors_api import DetectorsApi | ||
|
|
@@ -37,6 +39,7 @@ | |
| PaginatedDetectorList, | ||
| PaginatedImageQueryList, | ||
| ) | ||
| from pydantic import BaseModel | ||
| from urllib3.exceptions import InsecureRequestWarning | ||
| from urllib3.util.retry import Retry | ||
|
|
||
|
|
@@ -73,6 +76,24 @@ class EdgeNotAvailableError(GroundlightClientError): | |
| """Raised when an edge-only method is called against a non-edge endpoint.""" | ||
|
|
||
|
|
||
| MAX_VLM_MEDIA_ITEMS = 8 | ||
|
|
||
|
|
||
| class VLMVerificationResult(BaseModel): | ||
|
Collaborator
Is there a reason we aren't modeling this in Swagger? |
||
| """Result of a VLM-based alert verification via the Groundlight cloud API.""" | ||
|
|
||
| id: str | ||
| query: str | ||
| model_id: str | ||
| verdict: str # "YES" | "NO" | "UNSURE" | ||
|
Contributor
Should we call this label or result instead of verdict? That would match established patterns better. But maybe this is meant to be intentionally different? It feels like we should be making use of existing objects, like
Contributor
Author
Not changed: VLM verification outputs YES/NO/UNSURE, which is intentionally different from image-query binary/multiclass labels. Grouping verdict+confidence into a nested Result felt like premature reuse: the existing Result objects are generated from the detector image-query spec and carry detector-specific semantics. Keeping them flat and top-level on VLMVerificationResult avoids that coupling. Happy to revisit if there is a strong preference for nesting, or once the spec is auto-generated. 🤖 Addressed by Claude Code |
||
| confidence: float # 0.0–1.0 | ||
| reasoning: str | ||
| created_at: str | ||
|
Contributor
I think this doesn't match how we handle created_at elsewhere in this repo. I think we use datetime or something more specific than str. |
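A minimal, standard-library-only sketch of what parsing the timestamp could look like. `parse_created_at` is a hypothetical helper, not part of the SDK, and it assumes the server sends ISO 8601 timestamps with a trailing `Z`, as the mocked response in this PR's tests does.

```python
from datetime import datetime, timezone


def parse_created_at(value: str) -> datetime:
    """Parse an ISO 8601 timestamp into a timezone-aware datetime.

    Hypothetical helper for illustration; not part of the Groundlight SDK.
    """
    # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11 on,
    # so rewrite it as an explicit UTC offset for older interpreters.
    if value.endswith("Z"):
        value = value[:-1] + "+00:00"
    return datetime.fromisoformat(value)


created_at = parse_created_at("2025-06-17T00:00:00Z")
print(created_at.isoformat())
```

If the model field were typed as `datetime`, pydantic would likely handle this conversion itself; the helper only shows the parsing step in isolation.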
||
| input_tokens: Optional[int] = None | ||
| output_tokens: Optional[int] = None | ||
| total_cost_usd: Optional[float] = None | ||
|
Contributor
I've heard it argued that
Contributor
Author
Left as float: total_cost_usd is an informational/display field echoed back from the server; the SDK performs no arithmetic on it, so Decimal's precision guarantees are not needed here. If the server ever returns this value in a way that requires client-side summing or billing math, switching to Decimal would be the right call. 🤖 Addressed by Claude Code |
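For context on the trade-off discussed in this thread, a small standard-library sketch of why `Decimal` tends to be suggested once money values are summed on the client. The amounts are made-up numbers for illustration, not real Groundlight pricing.

```python
from decimal import Decimal

# Made-up per-call costs, for illustration only.
float_costs = [0.1, 0.2]
decimal_costs = [Decimal("0.1"), Decimal("0.2")]

# Binary floats cannot represent 0.1 or 0.2 exactly, so the sum drifts.
float_total = sum(float_costs)

# Decimal keeps the exact base-10 values that were written down.
decimal_total = sum(decimal_costs)

print(float_total == 0.3)
print(decimal_total == Decimal("0.3"))
```

As long as the SDK only echoes the value back, as the reply above describes, this drift never comes into play.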
||
|
|
||
|
|
||
| class Groundlight: # pylint: disable=too-many-instance-attributes,too-many-public-methods | ||
| """ | ||
| Client for accessing the Groundlight cloud service. Provides methods to create visual detectors, | ||
|
|
@@ -1089,6 +1110,139 @@ def ask_async( # noqa: PLR0913 # pylint: disable=too-many-arguments | |
| inspection_id=inspection_id, | ||
| ) | ||
|
|
||
| def ask_vlm( # pylint: disable=too-many-locals | ||
| self, | ||
| media: Union[ | ||
| np.ndarray, | ||
| str, | ||
| bytes, | ||
| Image.Image, | ||
| BytesIO, | ||
| BufferedReader, | ||
| List[Union[np.ndarray, str, bytes, Image.Image, BytesIO, BufferedReader]], | ||
| ], | ||
|
|
||
| query: str, | ||
| model_id: Optional[str] = None, | ||
| timeout: float = 15.0, | ||
| ) -> VLMVerificationResult: | ||
| """Verify one or more images against a natural-language query using a cloud VLM. | ||
|
|
||
| Calls the Groundlight ``POST /v1/vlm-verifications`` endpoint. The VLM runs in the | ||
| Groundlight cloud (AWS Bedrock) — no local inference. | ||
|
|
||
| The server makes no assumptions about what the images are — your ``query`` should | ||
| describe them. Images are presented to the model labeled ``Image 1``, ``Image 2``, | ||
| ... in the order given, so the query can refer to them. | ||
|
|
||
| **Example usage**:: | ||
|
|
||
| gl = Groundlight() | ||
|
|
||
| # Single image | ||
| result = gl.ask_vlm(frame, query="Is there a fire in this image?") | ||
| if result.verdict == "YES": | ||
| emit_alert() | ||
|
|
||
| # Full frame + cropped ROI — describe each in the query | ||
| result = gl.ask_vlm( | ||
| media=[full_frame, roi_crop], | ||
| query="Image 1 is the full camera frame; image 2 is the cropped region " | ||
| "a detector flagged. Is there really a fire?", | ||
| ) | ||
| print(result.confidence, result.reasoning) | ||
|
|
||
| :param media: One image or a list of up to 8 images. Accepted formats per image: | ||
|
|
||
| - filename (string) of a JPEG or PNG file (``".jpg"``, ``".jpeg"``, ``".png"``) | ||
| - raw bytes, BytesIO, or BufferedReader — sent as-is; the server decodes and | ||
| normalises to JPEG regardless of the declared content type, so PNG/WEBP bytes | ||
| all work | ||
| - numpy array (H, W, 3) in BGR order (OpenCV convention) — converted to JPEG | ||
| before sending | ||
| - PIL Image — converted to JPEG before sending | ||
|
|
||
| :param query: Natural-language prompt describing the media and what to verify, | ||
| e.g. ``"Is there a fire visible in the image? Reason step by step."`` | ||
| :param model_id: Friendly alias of the VLM to use. The server is the source | ||
| of truth; passing an unrecognised alias returns HTTP 400. Currently | ||
| supported aliases: | ||
|
|
||
| - ``"gpt-5.4"`` — OpenAI GPT-5.4 via Bedrock Responses API (default) | ||
| - ``"claude-sonnet-4.5"`` — Anthropic Claude Sonnet 4.5 | ||
| - ``"claude-haiku-3"`` — Anthropic Claude Haiku 3 | ||
| - ``"nova-pro"`` — Amazon Nova Pro | ||
| - ``"nova-lite"`` — Amazon Nova Lite | ||
| - ``"llama3.2-90b"`` — Meta Llama 3.2 90B | ||
| - ``"llama3.2-11b"`` — Meta Llama 3.2 11B | ||
|
|
||
| Omit to use the server-configured default (currently ``"gpt-5.4"``). | ||
| :param timeout: Request timeout in seconds (default 15 s). | ||
|
|
||
| :return: :class:`VLMVerificationResult` with ``verdict`` (``"YES"`` / ``"NO"`` / | ||
| ``"UNSURE"``), ``confidence``, ``reasoning``, and token cost fields. | ||
| :raises ValueError: If no images are supplied, or more than ``MAX_VLM_MEDIA_ITEMS`` (8). | ||
| :raises requests.HTTPError: On non-2xx response (400 for invalid model alias | ||
| or undecodable image bytes; 502 if the upstream VLM is unavailable). | ||
| """ | ||
| # Normalise: single image → list | ||
| if not isinstance(media, list): | ||
| media = [media] | ||
| if not media: | ||
| raise ValueError("ask_vlm requires at least one media item.") | ||
| if len(media) > MAX_VLM_MEDIA_ITEMS: | ||
| raise ValueError(f"ask_vlm supports at most {MAX_VLM_MEDIA_ITEMS} media items.") | ||
|
|
||
| # Encode each item. numpy/PIL → JPEG; bytes/BytesIO/BufferedReader → pass through | ||
| # (server calls ensure_jpeg_format and validates by decoding, so any common format works). | ||
| media_files: list[tuple[str, tuple[str, bytes, str]]] = [] | ||
| for i, img in enumerate(media): | ||
| stream = parse_supported_image_types(img) | ||
| jpeg_bytes = stream.read() | ||
| media_files.append(("media", (f"image_{i}.jpg", jpeg_bytes, "image/jpeg"))) | ||
|
|
||
| # query and model_id are sent as multipart form fields (not query-string | ||
| # params): the prompt can be long and must not end up in URLs or access logs. | ||
| form_data: dict[str, str] = {"query": query} | ||
| if model_id: | ||
| form_data["model_id"] = model_id | ||
|
|
||
| headers = { | ||
| "x-api-token": self.api_client.configuration.api_key["ApiToken"], | ||
| "X-Request-Id": f"ask_vlm_{time.time_ns()}", | ||
| "x-sdk-language": "python", | ||
| } | ||
|
|
||
| # sanitize_endpoint_url may produce an endpoint that already ends with a | ||
| # version segment (e.g. ".../v1"). Strip it so we never produce ".../v1/v1/...". | ||
| base = re.sub(r"/v\d+$", "", self.endpoint) | ||
| url = f"{base}/v1/vlm-verifications" | ||
|
|
||
| resp = requests.post( | ||
| url, | ||
| data=form_data, | ||
| files=media_files, | ||
| headers=headers, | ||
| timeout=timeout, | ||
| verify=self.api_client.configuration.verify_ssl, | ||
| ) | ||
| resp.raise_for_status() | ||
| data = resp.json() | ||
|
|
||
| result_block = data.get("result", {}) | ||
| cost_block = data.get("cost", {}) | ||
| return VLMVerificationResult( | ||
| id=data.get("id", ""), | ||
| query=data.get("query", query), | ||
| model_id=data.get("model_id", model_id or ""), | ||
| verdict=result_block.get("verdict", "UNSURE"), | ||
| confidence=float(result_block.get("confidence", 0.0)), | ||
| reasoning=result_block.get("reasoning", ""), | ||
| created_at=data.get("created_at", ""), | ||
| input_tokens=cost_block.get("input_tokens"), | ||
| output_tokens=cost_block.get("output_tokens"), | ||
| total_cost_usd=cost_block.get("total_cost_usd"), | ||
| ) | ||
|
|
||
| def wait_for_confident_result( | ||
| self, | ||
| image_query: Union[ImageQuery, str], | ||
|
|
||
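The endpoint handling in `ask_vlm` above can be tried in isolation. The sketch below copies the `re.sub` call from the diff into a standalone function; `build_vlm_url` is a hypothetical name used only here, and the two endpoints are the ones used in this PR's tests, written as they would look after the trailing slash has been stripped.

```python
import re


def build_vlm_url(endpoint: str) -> str:
    """Return the VLM verification URL for an already-sanitized endpoint.

    Hypothetical standalone helper mirroring the logic in ask_vlm.
    """
    # Drop a trailing version segment such as "/v1" so it is not duplicated.
    base = re.sub(r"/v\d+$", "", endpoint)
    return f"{base}/v1/vlm-verifications"


print(build_vlm_url("http://test-server/device-api"))
print(build_vlm_url("http://test-server/v1"))
```

Two properties of the pattern are worth keeping in mind. It only matches when the version segment is the final path component, so it relies on the trailing slash having been stripped beforehand. It also matches any version number, so an endpoint ending in `/v2` would be rewritten to `/v1`.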
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| """Unit tests for Groundlight.ask_vlm — all HTTP mocked, no live server needed.""" | ||
|
|
||
| from unittest import mock | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| import pytest | ||
| from groundlight import Groundlight, VLMVerificationResult | ||
| from groundlight.client import MAX_VLM_MEDIA_ITEMS | ||
| from groundlight.optional_imports import MISSING_NUMPY, np | ||
|
|
||
| # Minimal valid-looking JPEG bytes for tests that don't exercise image encoding. | ||
| _FAKE_JPEG = b"\xff\xd8\xff\xe0" + b"\x00" * 16 | ||
|
|
||
|
|
||
| @pytest.fixture(name="gl") | ||
| def groundlight_fixture(monkeypatch) -> Groundlight: | ||
| monkeypatch.setenv("GROUNDLIGHT_API_TOKEN", "api_fake_test_token") | ||
| with patch.object(Groundlight, "_verify_connectivity", return_value=None): | ||
| return Groundlight(endpoint="http://test-server/device-api/") | ||
|
|
||
|
|
||
| def _mock_response(verdict="YES", confidence=0.92, reasoning="Flames visible.", model_id="gpt-5.4"): | ||
| resp = MagicMock() | ||
| resp.status_code = 201 | ||
| resp.json.return_value = { | ||
| "id": "vlmv_test123", | ||
| "type": "vlm_verification", | ||
| "created_at": "2025-06-17T00:00:00Z", | ||
| "query": "Is there a fire?", | ||
| "model_id": model_id, | ||
| "result": {"verdict": verdict, "confidence": confidence, "reasoning": reasoning}, | ||
| "cost": {"input_tokens": 400, "output_tokens": 80, "total_cost_usd": 0.0015}, | ||
| } | ||
| resp.raise_for_status = MagicMock() | ||
| return resp | ||
|
|
||
|
|
||
| def test_returns_vlm_verification_result(gl: Groundlight): | ||
| """Result fields are correctly unpacked from the server response JSON.""" | ||
| with mock.patch("groundlight.client.requests") as mock_requests: | ||
| mock_requests.post.return_value = _mock_response() | ||
| result = gl.ask_vlm(media=_FAKE_JPEG, query="Is there a fire?") | ||
|
|
||
| assert isinstance(result, VLMVerificationResult) | ||
| assert result.verdict == "YES" | ||
| assert result.confidence == pytest.approx(0.92) | ||
| assert result.id == "vlmv_test123" | ||
| assert result.total_cost_usd == pytest.approx(0.0015) | ||
|
|
||
|
|
||
| @pytest.mark.skipif(MISSING_NUMPY, reason="Needs numpy") | ||
| def test_numpy_image_encoded_as_jpeg_multipart(gl: Groundlight): | ||
| """A numpy array is converted to JPEG and sent as a multipart 'media' part.""" | ||
| with mock.patch("groundlight.client.requests") as mock_requests: | ||
| mock_requests.post.return_value = _mock_response() | ||
| gl.ask_vlm(media=np.zeros((480, 640, 3), dtype=np.uint8), query="Is there a fire?") | ||
|
|
||
| _, kwargs = mock_requests.post.call_args | ||
| files = kwargs["files"] | ||
| assert len(files) == 1 | ||
| assert files[0][0] == "media" | ||
| _name, data, ctype = files[0][1] | ||
| assert ctype == "image/jpeg" | ||
| assert len(data) > 0 | ||
|
|
||
|
|
||
| def test_query_sent_as_form_field_not_url_param(gl: Groundlight): | ||
| """query and model_id go in the multipart body — never the URL — so the prompt | ||
| doesn't leak into access logs.""" | ||
| with mock.patch("groundlight.client.requests") as mock_requests: | ||
| mock_requests.post.return_value = _mock_response(model_id="nova-pro") | ||
| gl.ask_vlm(media=_FAKE_JPEG, query="Is there a fire?", model_id="nova-pro") | ||
|
|
||
| _, kwargs = mock_requests.post.call_args | ||
| assert kwargs["data"]["query"] == "Is there a fire?" | ||
| assert kwargs["data"]["model_id"] == "nova-pro" | ||
| assert "params" not in kwargs or not kwargs.get("params") | ||
|
|
||
|
|
||
| def test_more_than_max_media_raises(gl: Groundlight): | ||
| """Supplying more than MAX_VLM_MEDIA_ITEMS raises ValueError before any network call.""" | ||
| with pytest.raises(ValueError, match=f"at most {MAX_VLM_MEDIA_ITEMS}"): | ||
| gl.ask_vlm(media=[_FAKE_JPEG] * (MAX_VLM_MEDIA_ITEMS + 1), query="test") | ||
|
|
||
|
|
||
| def test_url_has_correct_path(gl: Groundlight): | ||
| """sanitize_endpoint_url strips the trailing slash from self.endpoint, so the path | ||
| must include a leading '/' — without it the URL becomes '...device-apiv1/...'.""" | ||
| with mock.patch("groundlight.client.requests") as mock_requests: | ||
| mock_requests.post.return_value = _mock_response() | ||
| gl.ask_vlm(media=_FAKE_JPEG, query="test") | ||
|
|
||
| args, _ = mock_requests.post.call_args | ||
| assert "/device-api/v1/vlm-verifications" in args[0] | ||
|
|
||
|
|
||
| def test_url_no_version_duplication_for_versioned_endpoint(monkeypatch): | ||
| """When the endpoint already ends with /v1 the URL must not contain /v1/v1/.""" | ||
| monkeypatch.setenv("GROUNDLIGHT_API_TOKEN", "api_fake_test_token") | ||
| with patch.object(Groundlight, "_verify_connectivity", return_value=None): | ||
| gl_v1 = Groundlight(endpoint="http://test-server/v1") | ||
| with mock.patch("groundlight.client.requests") as mock_requests: | ||
| mock_requests.post.return_value = _mock_response() | ||
| gl_v1.ask_vlm(media=_FAKE_JPEG, query="test") | ||
| args, _ = mock_requests.post.call_args | ||
| url = args[0] | ||
| assert "/v1/v1/" not in url | ||
| assert url.endswith("/v1/vlm-verifications") |
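The tests shown here cover the upper bound on `media` but not the empty-list branch. Below is a standalone sketch of that validation. `normalise_media` and `raises_value_error` are hypothetical stand-ins written for this example, since the real check sits at the top of `ask_vlm` and needs a `Groundlight` client to exercise.

```python
from typing import Any, List

MAX_VLM_MEDIA_ITEMS = 8  # same value as the constant added in this PR


def normalise_media(media: Any) -> List[Any]:
    """Wrap a single item in a list and enforce the 1..MAX_VLM_MEDIA_ITEMS bound.

    Hypothetical stand-in for the validation at the top of ask_vlm.
    """
    if not isinstance(media, list):
        media = [media]
    if not media:
        raise ValueError("ask_vlm requires at least one media item.")
    if len(media) > MAX_VLM_MEDIA_ITEMS:
        raise ValueError(f"ask_vlm supports at most {MAX_VLM_MEDIA_ITEMS} media items.")
    return media


def raises_value_error(media: Any) -> bool:
    """Return True if normalise_media rejects the given input."""
    try:
        normalise_media(media)
    except ValueError:
        return True
    return False


print(raises_value_error([]))           # empty list is rejected
print(raises_value_error([b"x"] * 8))   # exactly the maximum is accepted
```

In the real suite this would presumably become a `pytest.raises(ValueError, match="at least one")` test calling `gl.ask_vlm(media=[], query="test")`, alongside the existing upper-bound test.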
Should probably go in experimental