
GPU Worker SQS Pull Model Implementation Plan

For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.

Goal: Eliminate the GPU worker's public port 8000 (cubic P0 on PR #548) by replacing Lambda→HTTP calls with an SQS pull model — the GPU worker polls a queue and has zero inbound network surface.

Architecture: Lambdas write the job payload to S3 (claim-check pattern, because base64 frames exceed SQS's 256KB limit), send a small SQS message, and poll S3 for a result object. The GPU worker long-polls SQS, runs inference with the already-loaded models, and writes results back to S3. Worker liveness is signaled by a heartbeat object in S3 (replaces HTTP /health); the client's existing EC2 start/adopt logic stays, only the readiness probe changes. Once cut over, the 0.0.0.0/0 port-8000 ingress rule and the GPU_WORKER_TOKEN machinery are deleted.
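The claim-check idea can be shown in a toy, in-memory sketch. Here a plain dict stands in for the S3 bucket and a list for the SQS queue. `claim_check_submit` is a hypothetical illustration, not the real client code (that is `protocol.submit_job` in Task 2). The sketch shows that the queue message stays tiny even when the stored payload is far over the SQS limit:

```python
from __future__ import annotations

import base64
import json
import os
import uuid

SQS_MAX_MESSAGE_BYTES = 256 * 1024  # SQS per-message size limit (256 KiB)


def claim_check_submit(
    objects: dict[str, bytes], queue: list[str], op: str, payload: dict
) -> str:
    """Toy claim-check: park the payload in an object store, enqueue a pointer."""
    request_id = str(uuid.uuid4())
    # The bulky payload goes to the (here: in-memory) bucket under its own key...
    objects[f"gpu-jobs/requests/{request_id}.json"] = json.dumps(
        {"op": op, "payload": payload}
    ).encode()
    # ...and the queue message carries only the id needed to find it again.
    queue.append(json.dumps({"request_id": request_id, "op": op}))
    return request_id


# Eight fake 64 KiB frames, base64-encoded: together far above the SQS limit.
frames = [base64.b64encode(os.urandom(64 * 1024)).decode() for _ in range(8)]
objects: dict[str, bytes] = {}
queue: list[str] = []
rid = claim_check_submit(objects, queue, "encode_video", {"frames": frames})

stored = objects[f"gpu-jobs/requests/{rid}.json"]
print(len(stored) > SQS_MAX_MESSAGE_BYTES)  # True: the payload has to live in S3
print(len(queue[0].encode()) < 200)         # True: the message is a tiny pointer
```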

Tech Stack: Python 3.11, boto3, SQS, S3 (existing encache-raw-memory bucket), Terraform (main/devops/), SAM (main/server/template.yaml), moto for tests.

Global Constraints

  • Never skip pre-commit hooks; never use --no-verify (repo policy).
  • Conventional Commits format required for all commits and the PR title.
  • Python formatted with ruff format; mypy must pass (pre-commit runs both).
  • Do NOT deploy terraform apply or sam deploy as part of tasks — those are manual cutover steps listed at the end.
  • The existing public interface of VLM2VecClient (encode_video, encode_text, caption, generate — same signatures and return types) must not change; 6 call sites depend on it.
  • Bucket name: encache-raw-memory. New S3 prefixes: gpu-jobs/requests/, gpu-jobs/results/, heartbeat at gpu-jobs/heartbeat.json.
  • Queue names: encache-gpu-requests, DLQ encache-gpu-requests-dlq.
  • SSM parameter for queue URL: /encache/gpu/queue_url. Lambda env var: GPU_QUEUE_URL.
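The VLM2VecClient signature constraint is easy to break silently during the transport rewrite. One way to pin it in a unit test is to compare parameter names and defaults with `inspect.signature`. This is a sketch: `signature_drift` and the stub classes are hypothetical, and in the real test you would pass `VLM2VecClient` itself.

```python
from __future__ import annotations

import inspect

_EMPTY = inspect.Parameter.empty

# Parameter names and defaults the call sites rely on (from the constraint
# above). Annotations are not compared: their rendering changes under
# `from __future__ import annotations`.
EXPECTED_PARAMS = {
    "encode_video": [("frames_b64", _EMPTY)],
    "encode_text": [("text", _EMPTY)],
    "caption": [("frames_b64", _EMPTY), ("transcript", "")],
    "generate": [("messages", _EMPTY), ("max_new_tokens", 512)],
}


def signature_drift(cls: type) -> list[str]:
    """Describe every public method whose parameters no longer match."""
    problems = []
    for name, expected in EXPECTED_PARAMS.items():
        params = list(inspect.signature(getattr(cls, name)).parameters.values())
        actual = [(p.name, p.default) for p in params[1:]]  # skip `self`
        if actual != expected:
            problems.append(f"{name}: expected {expected}, got {actual}")
    return problems


class StubClient:  # hypothetical stand-in; pass VLM2VecClient in the real test
    def encode_video(self, frames_b64: list[str]) -> list[float]:
        return []

    def encode_text(self, text: str) -> list[float]:
        return []

    def caption(self, frames_b64: list[str], transcript: str = "") -> str:
        return ""

    def generate(self, messages: list[dict], max_new_tokens: int = 512) -> str:
        return ""


class DriftedClient(StubClient):  # renamed kwarg: would break callers
    def generate(self, messages: list[dict], max_tokens: int = 512) -> str:
        return ""


print(signature_drift(StubClient))          # []
print(len(signature_drift(DriftedClient)))  # 1
```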

Current State (for context)

  • main/server/worldmm/gpu_worker/server.py — FastAPI app on port 8000. Route handlers encode_video(req), encode_text(req), caption(req), generate(req) are plain functions taking pydantic models; token auth is a FastAPI Depends so direct function calls bypass it. load_models(), _start_watchdog() (idle self-shutdown), _touch_activity() are module-level.
  • main/server/worldmm/memory/visual/encoder.py: VLM2VecClient does HTTP POSTs; _ensure_running() starts/adopts the EC2 instance and _wait_for_health() polls HTTP /health.
  • main/server/worldmm/gpu_worker/ec2_user_data.sh — at boot copies s3://encache-raw-memory/gpu-worker/server.py to /opt/vlm2vec/ and runs it via systemd unit vlm2vec.service.
  • main/devops/main.tf: aws_security_group.gpu_worker opens 8000 to 0.0.0.0/0 (~lines 200-218); aws_iam_role.gpu_worker (~line 126) already has S3 + SSM policies.
  • Call sites of VLM2VecClient: worldmm/pipeline/ingest_window.py, ingest_session.py, run.py, worldmm/retrieval/semantic_retriever.py, episodic_retriever.py, api/memories/chat/app.py.

Latency & Failure Semantics

  • Added latency per call: SQS send (~20ms) + worker long-poll pickup (near-0 when idle) + result poll interval (1s) ≈ 1–2s. Acceptable; the chat path already tolerates multi-minute GPU cold starts.
  • Inference errors: worker writes {"error": ...} result and deletes the message (no retry — deterministic failures shouldn't redrive). Client raises GpuJobError.
  • Worker crashes mid-job: message visibility timeout (300s) expires, SQS redelivers; after 3 receives → DLQ.
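  • Heartbeat starts only AFTER models load, so "fresh heartbeat" means "ready for work" (same semantics as today's /health returning 200 only when models are loaded). The reverse lags: after the worker stops (e.g. idle self-shutdown), its last heartbeat still counts as fresh for up to HEARTBEAT_FRESH_S (90s), so a client that submits in that window waits out its full job timeout. If that window matters, have the watchdog delete gpu-jobs/heartbeat.json before shutting down (needs s3:DeleteObject on the worker role) or have _ensure_running also check the instance state.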
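The redrive rules above can be illustrated with a deliberately simplified, in-memory model. `ToyQueue` and `run_worker` are hypothetical, and real SQS neither orders messages nor redrives synchronously like this. The model shows that a message whose worker keeps crashing lands in the DLQ after 3 receives, while a deterministic inference error is deleted after its error result is written:

```python
from __future__ import annotations

from dataclasses import dataclass, field

MAX_RECEIVE_COUNT = 3  # maxReceiveCount in the Terraform redrive_policy (Task 1)


@dataclass
class ToyQueue:
    """Very simplified model of SQS redrive; real SQS does not promise ordering."""

    receive_counts: dict[str, int] = field(default_factory=dict)
    dlq: list[str] = field(default_factory=list)

    def send(self, msg_id: str) -> None:
        self.receive_counts[msg_id] = 0

    def receive(self) -> str | None:
        for msg_id, count in list(self.receive_counts.items()):
            if count >= MAX_RECEIVE_COUNT:
                # Received maxReceiveCount times and never deleted: redrive it
                # to the DLQ instead of delivering it a fourth time.
                del self.receive_counts[msg_id]
                self.dlq.append(msg_id)
                continue
            self.receive_counts[msg_id] = count + 1
            return msg_id
        return None

    def delete(self, msg_id: str) -> None:
        self.receive_counts.pop(msg_id, None)


def run_worker(queue: ToyQueue, crashes_on: set[str], max_cycles: int = 10) -> list[str]:
    """Each cycle stands for one receive; returns the ids that were deleted."""
    finished = []
    for _ in range(max_cycles):
        msg_id = queue.receive()
        if msg_id is None:
            break
        if msg_id in crashes_on:
            continue  # worker died mid-job: no delete, so the message comes back
        # Success AND deterministic inference errors both write a result and
        # delete the message, so neither is redriven.
        queue.delete(msg_id)
        finished.append(msg_id)
    return finished


q = ToyQueue()
q.send("crashes-worker")
q.send("bad-frames")  # handler raises, worker writes {"error": ...}
print(run_worker(q, crashes_on={"crashes-worker"}))  # ['bad-frames']
print(q.dlq)                                          # ['crashes-worker']
```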

Task 1: Terraform — queue, DLQ, lifecycle rule, worker IAM, queue-URL SSM param

Files: - Modify: main/devops/main.tf

Interfaces:
  • Produces: SQS queue encache-gpu-requests (URL in SSM /encache/gpu/queue_url), DLQ encache-gpu-requests-dlq, an S3 lifecycle rule expiring gpu-jobs/ after 1 day, and sqs:ReceiveMessage/DeleteMessage/GetQueueAttributes on the worker role.

  • [ ] Step 1: Add SQS + SSM + IAM + lifecycle resources

Add to main/devops/main.tf (near the existing gpu_worker IAM resources, ~line 170):

# ── GPU job queue (SQS pull model — worker has no inbound ports) ──

resource "aws_sqs_queue" "gpu_requests_dlq" {
  name                      = "encache-gpu-requests-dlq"
  message_retention_seconds = 1209600 # 14 days
}

resource "aws_sqs_queue" "gpu_requests" {
  name                       = "encache-gpu-requests"
  visibility_timeout_seconds = 300  # > longest single inference (caption ~60s) with margin
  receive_wait_time_seconds  = 20   # long polling
  message_retention_seconds  = 3600 # clients give up after ≤420s; stale jobs are useless
  redrive_policy = jsonencode({
    deadLetterTargetArn = aws_sqs_queue.gpu_requests_dlq.arn
    maxReceiveCount     = 3
  })
}

resource "aws_ssm_parameter" "gpu_queue_url" {
  name  = "/encache/gpu/queue_url"
  type  = "String"
  value = aws_sqs_queue.gpu_requests.url
}

resource "aws_iam_role_policy" "gpu_worker_sqs" {
  name = "gpu-worker-sqs"
  role = aws_iam_role.gpu_worker.id
  policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Effect   = "Allow"
      Action   = ["sqs:ReceiveMessage", "sqs:DeleteMessage", "sqs:GetQueueAttributes"]
      Resource = aws_sqs_queue.gpu_requests.arn
    }]
  })
}

resource "aws_s3_bucket_lifecycle_configuration" "raw_data_gpu_jobs" {
  bucket = aws_s3_bucket.raw_data.id
  rule {
    id     = "expire-gpu-jobs"
    status = "Enabled"
    filter {
      prefix = "gpu-jobs/"
    }
    expiration {
      days = 1
    }
  }
}

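Caution: a bucket supports only one lifecycle configuration, and two aws_s3_bucket_lifecycle_configuration resources on the same bucket produce a perpetual diff. Run grep -n "lifecycle" main/devops/main.tf first: if raw_data already has a lifecycle configuration (a separate resource or an inline lifecycle_rule block), add this rule to it instead of creating a new resource.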
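The queue timeouts only hang together if a few inequalities hold. A small hypothetical check makes them explicit. The names and values are copied from this plan (queue settings from Task 1, client timeouts and heartbeat constants from Tasks 3-4), and nothing here reads the Terraform. The check is a sketch for reviewing future edits, not part of the deliverable:

```python
from __future__ import annotations

# Values copied from this plan; LONGEST_INFERENCE_S is the ~60s caption
# estimate from the visibility_timeout_seconds comment.
VISIBILITY_TIMEOUT_S = 300
MESSAGE_RETENTION_S = 3600
LONGEST_INFERENCE_S = 60
JOB_TIMEOUTS_S = {"encode_video": 300, "encode_text": 240, "caption": 420, "generate": 300}
HEARTBEAT_INTERVAL_S = 30
HEARTBEAT_FRESH_S = 90


def timing_violations(
    visibility_timeout_s: int = VISIBILITY_TIMEOUT_S,
    retention_s: int = MESSAGE_RETENTION_S,
    longest_inference_s: int = LONGEST_INFERENCE_S,
    job_timeouts_s: dict[str, int] | None = None,
    heartbeat_interval_s: int = HEARTBEAT_INTERVAL_S,
    heartbeat_fresh_s: int = HEARTBEAT_FRESH_S,
) -> list[str]:
    """Return the timing assumptions that a given configuration breaks."""
    timeouts = JOB_TIMEOUTS_S if job_timeouts_s is None else job_timeouts_s
    problems = []
    if visibility_timeout_s <= longest_inference_s:
        # The message would reappear while the worker is still running it.
        problems.append("visibility timeout does not exceed the longest inference")
    if retention_s <= max(timeouts.values()):
        # A queued job could expire while its client is still polling.
        problems.append("retention does not exceed the longest client timeout")
    if heartbeat_fresh_s < 2 * heartbeat_interval_s:
        # A single late heartbeat would already mark the worker dead.
        problems.append("heartbeat freshness window tolerates no missed beat")
    return problems


print(timing_violations())                         # []
print(timing_violations(visibility_timeout_s=45))  # one violation
```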

  • [ ] Step 2: Verify the worker's existing S3 policy covers gpu-jobs/

Read aws_iam_role_policy.gpu_worker_s3 (~line 139). If it is scoped to specific prefixes rather than the whole bucket, add arn:aws:s3:::encache-raw-memory/gpu-jobs/* with s3:GetObject and s3:PutObject. The worker must read gpu-jobs/requests/* and write gpu-jobs/results/* and gpu-jobs/heartbeat.json.
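If the policy turns out to be prefix-scoped, a rough coverage check against its JSON helps. For example, use the PolicyDocument returned by aws iam get-role-policy. `missing_grants` and the example policies are hypothetical, and the wildcard matching only approximates IAM evaluation, as noted in the docstring:

```python
from __future__ import annotations

from fnmatch import fnmatchcase

BUCKET_ARN = "arn:aws:s3:::encache-raw-memory"
# (action, example object ARN) pairs the worker needs; file names are examples.
REQUIRED = [
    ("s3:GetObject", f"{BUCKET_ARN}/gpu-jobs/requests/example.json"),
    ("s3:PutObject", f"{BUCKET_ARN}/gpu-jobs/results/example.json"),
    ("s3:PutObject", f"{BUCKET_ARN}/gpu-jobs/heartbeat.json"),
]


def _as_list(value: str | list[str]) -> list[str]:
    return [value] if isinstance(value, str) else list(value)


def missing_grants(policy: dict) -> list[tuple[str, str]]:
    """Return the REQUIRED pairs that no Allow statement covers.

    Approximation only: fnmatchcase stands in for IAM's * and ? wildcards
    (it also accepts [...] classes, which IAM does not), and Deny
    statements, Conditions, NotAction and NotResource are ignored.
    """
    statements = policy["Statement"]
    if isinstance(statements, dict):  # IAM allows a single statement object
        statements = [statements]
    allows = [s for s in statements if s.get("Effect") == "Allow"]
    missing = []
    for action, resource in REQUIRED:
        covered = any(
            # Action names are case-insensitive in IAM; resource ARNs are not.
            any(fnmatchcase(action.lower(), a.lower()) for a in _as_list(s.get("Action", [])))
            and any(fnmatchcase(resource, r) for r in _as_list(s.get("Resource", [])))
            for s in allows
        )
        if not covered:
            missing.append((action, resource))
    return missing


whole_bucket = {
    "Statement": [{"Effect": "Allow", "Action": ["s3:GetObject", "s3:PutObject"],
                   "Resource": f"{BUCKET_ARN}/*"}]
}
frames_only = {  # hypothetical prefix-scoped policy
    "Statement": {"Effect": "Allow", "Action": "s3:*", "Resource": f"{BUCKET_ARN}/frames/*"}
}
print(missing_grants(whole_bucket))      # []
print(len(missing_grants(frames_only)))  # 3
```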

  • [ ] Step 3: Validate

Run: cd main/devops && terraform fmt && terraform validate
Expected: Success! The configuration is valid.

  • [ ] Step 4: Commit
git add main/devops/main.tf
git commit -m "feat(infra): add GPU job queue, DLQ, and gpu-jobs S3 lifecycle for SQS pull model"

Task 2: Shared job protocol module

Files:
  • Create: main/server/worldmm/gpu_worker/protocol.py
  • Create: main/server/worldmm/gpu_worker/__init__.py (empty, if missing)
  • Test: main/server/tests/unit/test_gpu_protocol.py

Interfaces:
  • Produces (used by Tasks 3 and 4):
    ◦ submit_job(sqs, s3, *, queue_url: str, bucket: str, op: str, payload: dict) -> str; returns the request_id
    ◦ fetch_job(s3, *, bucket: str, request_id: str) -> dict; returns {"op": str, "payload": dict}
    ◦ write_result(s3, *, bucket: str, request_id: str, result: dict | None = None, error: str | None = None) -> None
    ◦ poll_result(s3, *, bucket: str, request_id: str, timeout_s: float, interval_s: float = 1.0) -> dict; raises GpuJobError on an error result, TimeoutError if no result arrives in time
    ◦ write_heartbeat(s3, *, bucket: str) -> None
    ◦ heartbeat_age_s(s3, *, bucket: str) -> float | None; returns None when no heartbeat object exists
    ◦ class GpuJobError(RuntimeError)
    ◦ Constants: REQUESTS_PREFIX, RESULTS_PREFIX, HEARTBEAT_KEY

  • [ ] Step 1: Write the failing tests
# main/server/tests/unit/test_gpu_protocol.py
"""Unit tests for the GPU job claim-check protocol (SQS + S3)."""

from __future__ import annotations

import json
import sys
from pathlib import Path

import boto3
import pytest
from moto import mock_aws

_SERVER_ROOT = Path(__file__).resolve().parents[2]
if str(_SERVER_ROOT) not in sys.path:
    sys.path.insert(0, str(_SERVER_ROOT))

from worldmm.gpu_worker import protocol  # noqa: E402

BUCKET = "test-gpu-jobs"


@pytest.fixture()
def aws():
    with mock_aws():
        s3 = boto3.client("s3", region_name="us-east-1")
        sqs = boto3.client("sqs", region_name="us-east-1")
        s3.create_bucket(Bucket=BUCKET)
        queue_url = sqs.create_queue(QueueName="gpu-requests")["QueueUrl"]
        yield s3, sqs, queue_url


def test_submit_then_fetch_roundtrip(aws):
    s3, sqs, queue_url = aws
    request_id = protocol.submit_job(
        sqs, s3, queue_url=queue_url, bucket=BUCKET,
        op="caption", payload={"frames": ["abc"], "transcript": "hi"},
    )

    msgs = sqs.receive_message(QueueUrl=queue_url)["Messages"]
    body = json.loads(msgs[0]["Body"])
    assert body["request_id"] == request_id
    assert body["op"] == "caption"

    job = protocol.fetch_job(s3, bucket=BUCKET, request_id=request_id)
    assert job == {"op": "caption", "payload": {"frames": ["abc"], "transcript": "hi"}}


def test_poll_result_returns_written_result(aws):
    s3, _, _ = aws
    protocol.write_result(s3, bucket=BUCKET, request_id="r1", result={"caption": "a dog"})
    out = protocol.poll_result(s3, bucket=BUCKET, request_id="r1", timeout_s=2, interval_s=0.05)
    assert out == {"caption": "a dog"}


def test_poll_result_raises_on_error_result(aws):
    s3, _, _ = aws
    protocol.write_result(s3, bucket=BUCKET, request_id="r2", error="ValueError: bad frames")
    with pytest.raises(protocol.GpuJobError, match="bad frames"):
        protocol.poll_result(s3, bucket=BUCKET, request_id="r2", timeout_s=2, interval_s=0.05)


def test_poll_result_times_out_when_no_result(aws):
    s3, _, _ = aws
    with pytest.raises(TimeoutError):
        protocol.poll_result(s3, bucket=BUCKET, request_id="missing", timeout_s=0.2, interval_s=0.05)


def test_heartbeat_age(aws):
    s3, _, _ = aws
    assert protocol.heartbeat_age_s(s3, bucket=BUCKET) is None
    protocol.write_heartbeat(s3, bucket=BUCKET)
    age = protocol.heartbeat_age_s(s3, bucket=BUCKET)
    assert age is not None and 0 <= age < 30

  • [ ] Step 2: Run tests to verify they fail

Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_protocol.py -v
Expected: FAIL with ModuleNotFoundError / ImportError on worldmm.gpu_worker.protocol

  • [ ] Step 3: Implement protocol.py
# main/server/worldmm/gpu_worker/protocol.py
"""Claim-check job protocol between Lambdas and the GPU worker.

Payloads live in S3 (base64 frames exceed SQS's 256KB limit); the SQS
message carries only the request_id. Results and the worker heartbeat
are also S3 objects, so the worker needs zero inbound network surface.
"""

from __future__ import annotations

import json
import time
import uuid
from datetime import datetime, timezone

from botocore.exceptions import ClientError

REQUESTS_PREFIX = "gpu-jobs/requests/"
RESULTS_PREFIX = "gpu-jobs/results/"
HEARTBEAT_KEY = "gpu-jobs/heartbeat.json"


class GpuJobError(RuntimeError):
    """The worker processed the job and reported an inference error."""


def submit_job(
    sqs, s3, *, queue_url: str, bucket: str, op: str, payload: dict
) -> str:
    request_id = str(uuid.uuid4())
    s3.put_object(
        Bucket=bucket,
        Key=f"{REQUESTS_PREFIX}{request_id}.json",
        Body=json.dumps({"op": op, "payload": payload}).encode(),
    )
    sqs.send_message(
        QueueUrl=queue_url,
        MessageBody=json.dumps({"request_id": request_id, "op": op}),
    )
    return request_id


def fetch_job(s3, *, bucket: str, request_id: str) -> dict:
    obj = s3.get_object(Bucket=bucket, Key=f"{REQUESTS_PREFIX}{request_id}.json")
    return json.loads(obj["Body"].read())


def write_result(
    s3,
    *,
    bucket: str,
    request_id: str,
    result: dict | None = None,
    error: str | None = None,
) -> None:
    body: dict = {"error": error} if error is not None else {"result": result}
    s3.put_object(
        Bucket=bucket,
        Key=f"{RESULTS_PREFIX}{request_id}.json",
        Body=json.dumps(body).encode(),
    )


def poll_result(
    s3, *, bucket: str, request_id: str, timeout_s: float, interval_s: float = 1.0
) -> dict:
    key = f"{RESULTS_PREFIX}{request_id}.json"
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            obj = s3.get_object(Bucket=bucket, Key=key)
        except ClientError as exc:
            if exc.response["Error"]["Code"] in ("NoSuchKey", "404"):
                time.sleep(interval_s)
                continue
            raise
        body = json.loads(obj["Body"].read())
        if "error" in body:
            raise GpuJobError(body["error"])
        return body["result"]
    raise TimeoutError(f"GPU job {request_id} produced no result within {timeout_s}s")


def write_heartbeat(s3, *, bucket: str) -> None:
    s3.put_object(Bucket=bucket, Key=HEARTBEAT_KEY, Body=b"{}")


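def heartbeat_age_s(s3, *, bucket: str) -> float | None:
    """Seconds since the worker's last heartbeat, or None if it never wrote one."""
    try:
        obj = s3.head_object(Bucket=bucket, Key=HEARTBEAT_KEY)
    except ClientError as exc:
        # HEAD errors carry only the HTTP status as the code. Treat "not
        # found" as "no heartbeat"; surface anything else (e.g. 403 from a
        # missing IAM grant) instead of misreporting a dead worker.
        if exc.response["Error"]["Code"] in ("404", "NoSuchKey"):
            return None
        raise
    return (datetime.now(timezone.utc) - obj["LastModified"]).total_seconds()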

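Also create main/server/worldmm/gpu_worker/__init__.py if it does not exist. If it already exists, make sure it does not import server: importing worldmm.gpu_worker.protocol executes the package __init__.py, and server.py pulls in torch, which the Lambda-side import in Task 4 must not drag in.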

  • [ ] Step 4: Run tests to verify they pass

Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_protocol.py -v
Expected: 5 PASS

  • [ ] Step 5: Commit
git add main/server/worldmm/gpu_worker/protocol.py main/server/worldmm/gpu_worker/__init__.py main/server/tests/unit/test_gpu_protocol.py
git commit -m "feat(gpu): add S3 claim-check job protocol for SQS pull model"

Task 3: Worker SQS poller

Files:
  • Create: main/server/worldmm/gpu_worker/sqs_worker.py
  • Test: main/server/tests/unit/test_gpu_sqs_worker.py

Interfaces:
  • Consumes: protocol.fetch_job, protocol.write_result, protocol.write_heartbeat (Task 2); server.load_models, server._start_watchdog, server._touch_activity, and the four route handler functions from server.py.
  • Produces: process_message(msg: dict, s3, *, bucket: str) -> None (the unit-testable core) and main() -> None, the entry point run by systemd.
  • Env contract: GPU_QUEUE_URL (required), GPU_JOBS_BUCKET (default encache-raw-memory).

  • [ ] Step 1: Write the failing tests
# main/server/tests/unit/test_gpu_sqs_worker.py
"""Unit tests for the GPU worker's SQS message processing.

server.py imports torch at module load, so the sqs_worker module lazily
imports server inside its handlers; tests inject fake handlers instead.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

import boto3
import pytest
from moto import mock_aws

_SERVER_ROOT = Path(__file__).resolve().parents[2]
if str(_SERVER_ROOT) not in sys.path:
    sys.path.insert(0, str(_SERVER_ROOT))

from worldmm.gpu_worker import protocol, sqs_worker  # noqa: E402

BUCKET = "test-gpu-jobs"


@pytest.fixture()
def s3():
    with mock_aws():
        client = boto3.client("s3", region_name="us-east-1")
        client.create_bucket(Bucket=BUCKET)
        yield client


def _msg(request_id: str, op: str) -> dict:
    return {"Body": json.dumps({"request_id": request_id, "op": op}), "ReceiptHandle": "rh"}


def test_process_message_writes_result(s3, monkeypatch):
    monkeypatch.setattr(
        sqs_worker, "_HANDLERS", {"caption": lambda p: {"caption": f"saw {p['transcript']}"}}
    )
    s3.put_object(
        Bucket=BUCKET,
        Key=f"{protocol.REQUESTS_PREFIX}r1.json",
        Body=json.dumps({"op": "caption", "payload": {"transcript": "a dog"}}).encode(),
    )
    sqs_worker.process_message(_msg("r1", "caption"), s3, bucket=BUCKET)
    out = protocol.poll_result(s3, bucket=BUCKET, request_id="r1", timeout_s=1, interval_s=0.05)
    assert out == {"caption": "saw a dog"}


def test_process_message_writes_error_on_handler_exception(s3, monkeypatch):
    def boom(_payload):
        raise ValueError("bad frames")

    monkeypatch.setattr(sqs_worker, "_HANDLERS", {"caption": boom})
    s3.put_object(
        Bucket=BUCKET,
        Key=f"{protocol.REQUESTS_PREFIX}r2.json",
        Body=json.dumps({"op": "caption", "payload": {}}).encode(),
    )
    sqs_worker.process_message(_msg("r2", "caption"), s3, bucket=BUCKET)
    with pytest.raises(protocol.GpuJobError, match="ValueError: bad frames"):
        protocol.poll_result(s3, bucket=BUCKET, request_id="r2", timeout_s=1, interval_s=0.05)


def test_process_message_writes_error_on_unknown_op(s3):
    s3.put_object(
        Bucket=BUCKET,
        Key=f"{protocol.REQUESTS_PREFIX}r3.json",
        Body=json.dumps({"op": "nope", "payload": {}}).encode(),
    )
    sqs_worker.process_message(_msg("r3", "nope"), s3, bucket=BUCKET)
    with pytest.raises(protocol.GpuJobError, match="KeyError"):
        protocol.poll_result(s3, bucket=BUCKET, request_id="r3", timeout_s=1, interval_s=0.05)

  • [ ] Step 2: Run tests to verify they fail

Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_sqs_worker.py -v
Expected: FAIL with ImportError: cannot import name 'sqs_worker'

  • [ ] Step 3: Implement sqs_worker.py
# main/server/worldmm/gpu_worker/sqs_worker.py
"""SQS pull-mode entry point for the GPU worker.

Replaces the public FastAPI server: the worker long-polls
encache-gpu-requests, runs inference via the handler functions in
server.py, and writes results to S3. Zero inbound network surface.

On the EC2 box this file sits flat in /opt/vlm2vec next to server.py
and protocol.py, hence the dual import paths below.
"""

from __future__ import annotations

import json
import os
import threading
import time
import traceback

import boto3

try:  # repo layout (tests, Lambda bundles)
    from worldmm.gpu_worker import protocol
except ImportError:  # flat layout on the EC2 box
    import protocol  # type: ignore[no-redef]

HEARTBEAT_INTERVAL_S = 30


def _load_handlers() -> dict:
    """Import server.py (heavy: torch, transformers) only at runtime."""
    try:
        from worldmm.gpu_worker import server
    except ImportError:
        import server  # type: ignore[no-redef]

    def _as_dict(resp) -> dict:
        return resp if isinstance(resp, dict) else resp.model_dump()

    return {
        "encode_video": lambda p: _as_dict(server.encode_video(server.EncodeVideoRequest(**p))),
        "encode_text": lambda p: _as_dict(server.encode_text(server.EncodeTextRequest(**p))),
        "caption": lambda p: _as_dict(server.caption(server.CaptionRequest(**p))),
        "generate": lambda p: _as_dict(server.generate(server.GenerateRequest(**p))),
    }


# Populated in main(); tests monkeypatch this directly.
_HANDLERS: dict = {}


def process_message(msg: dict, s3, *, bucket: str) -> None:
    # Poison message (bad JSON / missing request_id) raises: caller skips
    # delete, SQS redrives, DLQ after 3 receives.
    request_id = json.loads(msg["Body"])["request_id"]

    # fetch_job hits S3; transient get_object failures propagate (redrive),
    # they are not deterministic inference errors.
    job = protocol.fetch_job(s3, bucket=bucket, request_id=request_id)

    # Only handler execution is deterministic. write_result failures are
    # transient S3 errors and must propagate (no delete → redrive).
    try:
        result = _HANDLERS[job["op"]](job["payload"])
    except Exception as exc:  # deterministic failure — report, don't redrive
        traceback.print_exc()
        protocol.write_result(
            s3, bucket=bucket, request_id=request_id,
            error=f"{type(exc).__name__}: {exc}",
        )
    else:
        protocol.write_result(s3, bucket=bucket, request_id=request_id, result=result)


def _heartbeat_loop(s3, bucket: str) -> None:
    while True:
        try:
            protocol.write_heartbeat(s3, bucket=bucket)
        except Exception:
            traceback.print_exc()
        time.sleep(HEARTBEAT_INTERVAL_S)


def main() -> None:
    global _HANDLERS  # skipcq: PYL-W0603
    queue_url = os.environ["GPU_QUEUE_URL"]
    bucket = os.environ.get("GPU_JOBS_BUCKET", "encache-raw-memory")

    try:
        from worldmm.gpu_worker import server
    except ImportError:
        import server  # type: ignore[no-redef]

    s3 = boto3.client("s3")
    sqs = boto3.client("sqs")

    server.load_models()
    server._start_watchdog()  # idle self-shutdown, unchanged
    _HANDLERS = _load_handlers()
    # Heartbeat starts only now: fresh heartbeat == models loaded == ready.
    threading.Thread(target=_heartbeat_loop, args=(s3, bucket), daemon=True).start()
    print(f"SQS worker polling {queue_url}")

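    while True:
        try:
            resp = sqs.receive_message(
                QueueUrl=queue_url, MaxNumberOfMessages=1, WaitTimeSeconds=20
            )
        except Exception:  # transient SQS/network error: back off, keep polling
            traceback.print_exc()
            time.sleep(5)
            continue
        for msg in resp.get("Messages", []):
            server._touch_activity()  # reset idle watchdog
            try:
                process_message(msg, s3, bucket=bucket)
            except Exception:
                # Poison message or transient S3 error: skip the delete so SQS
                # redelivers after the visibility timeout (DLQ after 3 receives).
                # Letting it escape would kill the worker and force a model reload.
                traceback.print_exc()
                continue
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=msg["ReceiptHandle"])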


if __name__ == "__main__":
    main()

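Note: depending on route style, the server.py handlers return a pydantic response model (e.g. CaptionResponse) or a plain dict; _as_dict normalizes both. Verify the request model names used above (EncodeVideoRequest, EncodeTextRequest, CaptionRequest, GenerateRequest) against server.py and match them if they differ. If any handler is declared async def, calling it returns a coroutine rather than a response, so it must be run to completion (e.g. with asyncio.run) before normalizing.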
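If server.py turns out to mix sync and async handlers, _as_dict could be generalized along these lines. This is a sketch: `call_handler` and the fake response class are hypothetical, and it assumes pydantic v2 models (`model_dump()`, as _as_dict already does). Running the coroutine with `asyncio.run` is safe here only because the SQS loop is synchronous, so no event loop is already running.

```python
from __future__ import annotations

import asyncio
import inspect
from typing import Any, Callable


def call_handler(handler: Callable[[Any], Any], request: Any) -> dict:
    """Call a route handler directly and return its response as a dict."""
    resp = handler(request)
    if inspect.iscoroutine(resp):
        # `async def` route: drive the coroutine on a fresh event loop. Safe
        # here because the SQS loop is synchronous (no loop is running).
        resp = asyncio.run(resp)
    # Plain dicts pass through; pydantic v2 models expose model_dump().
    return resp if isinstance(resp, dict) else resp.model_dump()


class FakeCaptionResponse:  # hypothetical stand-in for a pydantic response model
    def __init__(self, caption: str) -> None:
        self.caption = caption

    def model_dump(self) -> dict:
        return {"caption": self.caption}


def sync_caption(req: dict) -> FakeCaptionResponse:
    return FakeCaptionResponse(f"saw {req['transcript']}")


async def async_caption(req: dict) -> dict:
    return {"caption": f"heard {req['transcript']}"}


print(call_handler(sync_caption, {"transcript": "a dog"}))  # {'caption': 'saw a dog'}
print(call_handler(async_caption, {"transcript": "birds"}))  # {'caption': 'heard birds'}
```

In _load_handlers this would replace the `_as_dict(server.caption(...))` pattern with `call_handler(server.caption, server.CaptionRequest(**p))`.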

  • [ ] Step 4: Run tests to verify they pass

Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_sqs_worker.py -v
Expected: 3 PASS

  • [ ] Step 5: Commit
git add main/server/worldmm/gpu_worker/sqs_worker.py main/server/tests/unit/test_gpu_sqs_worker.py
git commit -m "feat(gpu): add SQS pull-mode worker entry point"

Task 4: Client cutover — VLM2VecClient submits jobs instead of HTTP

Files:
  • Modify: main/server/worldmm/memory/visual/encoder.py
  • Test: main/server/tests/unit/test_vlm2vec_client_sqs.py

Interfaces:
  • Consumes: protocol.submit_job, protocol.poll_result, protocol.heartbeat_age_s (Task 2).
  • Produces: unchanged public API: encode_video(frames_b64: list[str]) -> list[float], encode_text(text: str) -> list[float], caption(frames_b64: list[str], transcript: str = "") -> str, generate(messages: list[dict], max_new_tokens: int = 512) -> str. The constructor keeps the base_url parameter (now ignored) so the 6 call sites keep working unchanged.
  • Env contract: GPU_QUEUE_URL (required), GPU_JOBS_BUCKET (default encache-raw-memory).

  • [ ] Step 1: Write the failing tests
# main/server/tests/unit/test_vlm2vec_client_sqs.py
"""VLM2VecClient submits SQS jobs and polls S3 results (no HTTP)."""

from __future__ import annotations

import json
import sys
import threading
from pathlib import Path

import boto3
import pytest
from moto import mock_aws

_SERVER_ROOT = Path(__file__).resolve().parents[2]
_LAYER_ROOT = _SERVER_ROOT / "layers" / "shared" / "python"
for p in (_LAYER_ROOT, _SERVER_ROOT):
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

from worldmm.gpu_worker import protocol  # noqa: E402
from worldmm.memory.visual.encoder import VLM2VecClient  # noqa: E402

BUCKET = "test-gpu-jobs"


@pytest.fixture()
def aws(monkeypatch):
    with mock_aws():
        s3 = boto3.client("s3", region_name="us-east-1")
        sqs = boto3.client("sqs", region_name="us-east-1")
        s3.create_bucket(Bucket=BUCKET)
        queue_url = sqs.create_queue(QueueName="gpu-requests")["QueueUrl"]
        monkeypatch.setenv("GPU_QUEUE_URL", queue_url)
        monkeypatch.setenv("GPU_JOBS_BUCKET", BUCKET)
        yield s3, sqs, queue_url


def _answer_next_job(s3, sqs, queue_url, result: dict) -> None:
    """Play the worker: consume one message and write the result."""
    def loop():
        for _ in range(100):
            msgs = sqs.receive_message(QueueUrl=queue_url, WaitTimeSeconds=1).get("Messages", [])
            if msgs:
                rid = json.loads(msgs[0]["Body"])["request_id"]
                protocol.write_result(s3, bucket=BUCKET, request_id=rid, result=result)
                return
    threading.Thread(target=loop, daemon=True).start()


def test_encode_text_roundtrip(aws, monkeypatch):
    s3, sqs, queue_url = aws
    client = VLM2VecClient(instance_id="i-fake")
    monkeypatch.setattr(client, "_ensure_running", lambda: None)
    monkeypatch.setattr(client, "_poll_interval_s", 0.05)

    _answer_next_job(s3, sqs, queue_url, {"embedding": [0.1, 0.2]})
    assert client.encode_text("hello") == [0.1, 0.2]

    # Payload made it into the request object
    msgs = s3.list_objects_v2(Bucket=BUCKET, Prefix=protocol.REQUESTS_PREFIX)
    body = json.loads(
        s3.get_object(Bucket=BUCKET, Key=msgs["Contents"][0]["Key"])["Body"].read()
    )
    assert body == {"op": "encode_text", "payload": {"text": "hello"}}


def test_caption_roundtrip(aws, monkeypatch):
    s3, sqs, queue_url = aws
    client = VLM2VecClient(instance_id="i-fake")
    monkeypatch.setattr(client, "_ensure_running", lambda: None)
    monkeypatch.setattr(client, "_poll_interval_s", 0.05)

    _answer_next_job(s3, sqs, queue_url, {"caption": "a sunny park"})
    assert client.caption(["frame1"], transcript="birds") == "a sunny park"


def test_worker_error_raises(aws, monkeypatch):
    s3, sqs, queue_url = aws
    client = VLM2VecClient(instance_id="i-fake")
    monkeypatch.setattr(client, "_ensure_running", lambda: None)
    monkeypatch.setattr(client, "_poll_interval_s", 0.05)
    monkeypatch.setattr(client, "_job_timeout_s", lambda op: 2)

    def answer_with_error():
        for _ in range(100):
            msgs = sqs.receive_message(QueueUrl=queue_url, WaitTimeSeconds=1).get("Messages", [])
            if msgs:
                rid = json.loads(msgs[0]["Body"])["request_id"]
                protocol.write_result(s3, bucket=BUCKET, request_id=rid, error="OOM")
                return
    threading.Thread(target=answer_with_error, daemon=True).start()

    with pytest.raises(protocol.GpuJobError, match="OOM"):
        client.encode_text("boom")

  • [ ] Step 2: Run tests to verify they fail

Run: cd main/server && .venv/bin/pytest tests/unit/test_vlm2vec_client_sqs.py -v
Expected: FAIL (client still does HTTP; no _poll_interval_s attribute)

  • [ ] Step 3: Rewrite the transport in encoder.py

Keep everything about EC2 lifecycle (_start_or_create, _create_from_template, _find_existing_by_tag, _persist_instance_id_to_ssm) — only the probe and transport change. Replace the class header, health checks, and the four request methods:

# top of encoder.py — replace `import requests` usage for transport
from __future__ import annotations

import logging
import os
import time

import boto3

from worldmm.gpu_worker import protocol

logger = logging.getLogger(__name__)

STARTUP_TIMEOUT_S = 300  # instance boot (~2 min AMI) + model load, with margin
HEARTBEAT_FRESH_S = 90   # worker heartbeats every 30s; 3 missed = dead
JOB_TIMEOUTS_S = {
    "encode_video": 300,
    "encode_text": 240,
    "caption": 420,
    "generate": 300,
}


class VLM2VecClient:
    def __init__(
        self,
        base_url: str | None = None,  # deprecated, ignored (SQS transport)
        instance_id: str | None = None,
        aws_region: str = "us-east-1",
    ) -> None:
        self._instance_id = instance_id
        self._aws_region = aws_region
        self._queue_url = os.environ["GPU_QUEUE_URL"]
        self._bucket = os.environ.get("GPU_JOBS_BUCKET", "encache-raw-memory")
        self._s3 = boto3.client("s3", region_name=aws_region)
        self._sqs = boto3.client("sqs", region_name=aws_region)
        self._poll_interval_s = 1.0

    def _job_timeout_s(self, op: str) -> float:
        return JOB_TIMEOUTS_S[op]

    # ── liveness ────────────────────────────────────────────

    def _is_healthy(self) -> bool:
        age = protocol.heartbeat_age_s(self._s3, bucket=self._bucket)
        return age is not None and age < HEARTBEAT_FRESH_S

    def _wait_for_health(self) -> None:
        deadline = time.monotonic() + STARTUP_TIMEOUT_S
        while time.monotonic() < deadline:
            if self._is_healthy():
                logger.info("GPU worker heartbeat fresh — ready")
                return
            time.sleep(5)
        raise RuntimeError(
            f"GPU worker heartbeat not fresh within {STARTUP_TIMEOUT_S}s"
        )

    # _ensure_running / _start_or_create / _create_from_template /
    # _find_existing_by_tag / _persist_instance_id_to_ssm: UNCHANGED,
    # except: delete _update_base_url and every call to it (the client
    # no longer needs the instance IP).

    # ── transport ───────────────────────────────────────────

    def _submit(self, op: str, payload: dict) -> dict:
        self._ensure_running()
        request_id = protocol.submit_job(
            self._sqs, self._s3,
            queue_url=self._queue_url, bucket=self._bucket,
            op=op, payload=payload,
        )
        return protocol.poll_result(
            self._s3, bucket=self._bucket, request_id=request_id,
            timeout_s=self._job_timeout_s(op), interval_s=self._poll_interval_s,
        )

    def encode_video(self, frames_b64: list[str]) -> list[float]:
        return self._submit("encode_video", {"frames": frames_b64})["embedding"]

    def encode_text(self, text: str) -> list[float]:
        return self._submit("encode_text", {"text": text})["embedding"]

    def caption(self, frames_b64: list[str], transcript: str = "") -> str:
        return self._submit("caption", {"frames": frames_b64, "transcript": transcript})["caption"]

    def generate(self, messages: list[dict], max_new_tokens: int = 512) -> str:
        return self._submit(
            "generate", {"messages": messages, "max_new_tokens": max_new_tokens}
        )["text"]

Delete: _auth_headers, HEALTH_TIMEOUT_S/HEALTH_POLL_INTERVAL_S constants, _update_base_url, the requests import if nothing else uses it, and the old HTTP bodies of the four methods. In _ensure_running, the "probe base_url before bootstrapping" branch becomes a plain heartbeat check (if self._is_healthy(): return).
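After these deletions, the readiness path in _ensure_running reduces to "trust a fresh heartbeat, otherwise boot and wait for one". A minimal sketch of that control flow, assuming hypothetical callables `is_healthy` and `start_or_create` standing in for _is_healthy and the unchanged EC2 start/adopt methods. Injecting them, along with the clock and sleep, is only there to make the sketch testable without AWS:

```python
from __future__ import annotations

import time
from typing import Callable


def ensure_running(
    is_healthy: Callable[[], bool],
    start_or_create: Callable[[], None],
    *,
    startup_timeout_s: float = 300,
    poll_s: float = 5,
    sleep: Callable[[float], None] = time.sleep,
    clock: Callable[[], float] = time.monotonic,
) -> None:
    """Return once the worker heartbeat is fresh, booting the instance if needed."""
    if is_healthy():  # fresh heartbeat: models loaded, worker is polling SQS
        return
    start_or_create()  # the unchanged EC2 start/adopt path
    deadline = clock() + startup_timeout_s
    while clock() < deadline:  # same loop as _wait_for_health above
        if is_healthy():
            return
        sleep(poll_s)
    raise RuntimeError(f"GPU worker heartbeat not fresh within {startup_timeout_s}s")


class FakeClock:
    """Deterministic time source so the example does not actually sleep."""

    def __init__(self) -> None:
        self.t = 0.0

    def now(self) -> float:
        return self.t

    def sleep(self, seconds: float) -> None:
        self.t += seconds


starts: list[str] = []
checks = iter([False, False, False, True])  # heartbeat turns fresh on the 4th probe
clk = FakeClock()
ensure_running(lambda: next(checks), lambda: starts.append("start"),
               sleep=clk.sleep, clock=clk.now)
print(starts, clk.t)  # ['start'] 10.0
```

Because a heartbeat stays fresh for up to HEARTBEAT_FRESH_S after the worker stops, a stricter variant could also require the instance to be in the running state before returning early.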

  • [ ] Step 4: Run new tests + existing suites

Run: cd main/server && .venv/bin/pytest tests/unit/ -v --tb=short
Expected: the new tests PASS. Existing tests that construct VLM2VecClient without GPU_QUEUE_URL in the environment will now fail with KeyError; fix them with monkeypatch.setenv("GPU_QUEUE_URL", "https://sqs.fake/q"). Integration tests mock the whole client class, so most are unaffected (e.g. _run_ingest in tests/integration/test_ingest_segment_first.py already patches worldmm.pipeline.ingest_window.VLM2VecClient).

  • [ ] Step 5: Check the call sites still make sense

Run grep -rn "VLM2VecClient(" main/server. All 6 call sites pass base_url=/instance_id= kwargs; they keep working unchanged because base_url is still accepted. In ingest_window.py, _resolve_gpu_url() now only builds an ignored argument. Deleting _resolve_gpu_url and the base_url= args is an optional follow-up commit (if tests stay green); it is not required for cutover.

  • [ ] Step 6: Commit
git add main/server/worldmm/memory/visual/encoder.py main/server/tests/unit/test_vlm2vec_client_sqs.py
git commit -m "feat(gpu): switch VLM2VecClient transport from HTTP to SQS claim-check"

Task 5: Wiring — Lambda env/policies, EC2 user_data, worker file upload

Files:
  • Modify: main/server/template.yaml
  • Modify: main/server/worldmm/gpu_worker/ec2_user_data.sh
  • Modify: the script that uploads server.py to s3://encache-raw-memory/gpu-worker/ (find it with grep -rn "gpu-worker/server.py" scripts/ main/; likely scripts/build-gpu-ami.sh or a deploy script)

Interfaces:
  • Consumes: SSM /encache/gpu/queue_url (Task 1); sqs_worker.py, protocol.py (Tasks 2–3).
  • Produces: Lambdas get the GPU_QUEUE_URL env var plus SQS send and S3 gpu-jobs permissions; the EC2 instance boots sqs_worker.py under systemd.

  • [ ] Step 1: template.yaml — global env var

In the global Lambda environment block (where GPU_WORKER_TOKEN currently lives), add:

GPU_QUEUE_URL: !Sub "{{resolve:ssm:/encache/gpu/queue_url}}"

  • [ ] Step 2: template.yaml — per-function policies

Every function that constructs VLM2VecClient (ingest window, ingest session, chat, retriever-hosting functions — cross-check the 6 call-site files against their function definitions) needs:

- SQSSendMessagePolicy:
    QueueName: encache-gpu-requests
- S3CrudPolicy:
    BucketName: encache-raw-memory

Most already have S3CrudPolicy (or broader) on encache-raw-memory for frames — verify rather than duplicate.

  • [ ] Step 3: Validate SAM template

Run: cd main/server && sam validate --lint Expected: template is valid

  • [ ] Step 4: ec2_user_data.sh — fetch new files, new env, new ExecStart

Replace the GPU_WORKER_TOKEN SSM fetch (line ~18) with:

export GPU_QUEUE_URL=$(aws ssm get-parameter --name "/encache/gpu/queue_url" --query 'Parameter.Value' --output text --cli-connect-timeout 10 --cli-read-timeout 10)

After the existing server.py copy (line ~69), add:

aws s3 cp s3://encache-raw-memory/gpu-worker/protocol.py /opt/vlm2vec/protocol.py
aws s3 cp s3://encache-raw-memory/gpu-worker/sqs_worker.py /opt/vlm2vec/sqs_worker.py

In the start.sh heredoc, change the exec line to:

exec $PYTHON /opt/vlm2vec/sqs_worker.py

In the systemd unit, replace Environment="GPU_WORKER_TOKEN=${GPU_WORKER_TOKEN}" with:

Environment="GPU_QUEUE_URL=${GPU_QUEUE_URL}"
Environment="GPU_JOBS_BUCKET=encache-raw-memory"
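On the worker side, the env vars in the systemd unit are the whole configuration contract. Below is a minimal sketch of how sqs_worker.py might read them, using a hypothetical WorkerConfig helper (the actual Task 2–3 code may differ). It rejects empty values as well as missing ones: if the aws ssm get-parameter call in user_data fails, the export still succeeds, so the unit would likely be written with GPU_QUEUE_URL set to an empty string.

```python
import os
from dataclasses import dataclass
from typing import Mapping

REQUIRED = ("GPU_QUEUE_URL", "GPU_JOBS_BUCKET")


@dataclass(frozen=True)
class WorkerConfig:
    """Hypothetical config object for the SQS worker process."""

    queue_url: str
    bucket: str

    @classmethod
    def from_env(cls, env: Mapping[str, str] = os.environ) -> "WorkerConfig":
        # Treat "" the same as unset: a failed $(aws ssm ...) yields "".
        missing = [name for name in REQUIRED if not env.get(name)]
        if missing:
            raise RuntimeError(f"missing required env vars: {', '.join(missing)}")
        return cls(queue_url=env["GPU_QUEUE_URL"], bucket=env["GPU_JOBS_BUCKET"])
```

Failing at startup puts one clear error in journalctl rather than a stream of errors from the polling loop.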
  • [ ] Step 5: Update the upload script

Wherever server.py is uploaded to s3://encache-raw-memory/gpu-worker/, also upload protocol.py and sqs_worker.py:

aws s3 cp main/server/worldmm/gpu_worker/protocol.py s3://encache-raw-memory/gpu-worker/protocol.py
aws s3 cp main/server/worldmm/gpu_worker/sqs_worker.py s3://encache-raw-memory/gpu-worker/sqs_worker.py
  • [ ] Step 6: Commit
git add main/server/template.yaml main/server/worldmm/gpu_worker/ec2_user_data.sh scripts/
git commit -m "feat(infra): wire GPU_QUEUE_URL through SAM, user_data, and worker upload"

Task 6: Close the port and retire the token

Only merge/apply this after Tasks 1–5 are deployed and one end-to-end ingest has succeeded via SQS (see Rollout below).

Files:
  - Modify: main/devops/main.tf
  - Modify: main/devops/variables.tf
  - Modify: main/server/template.yaml
  - Modify: main/server/worldmm/gpu_worker/server.py
  - Modify: docs/docs/ (whichever page documents the GPU worker HTTP API; find it with grep -rn "8000\|GPU_WORKER_TOKEN" docs/docs/)

  • [ ] Step 1: Remove the ingress rule

Delete the ingress block from aws_security_group.gpu_worker in main/devops/main.tf (~lines 205-211) and update the SG description to "GPU worker — egress only (SQS pull model)". Keep the egress block.

  • [ ] Step 2: Remove token infrastructure

  • main/devops/main.tf: delete aws_ssm_parameter.gpu_worker_token (~line 408).
  • main/devops/variables.tf: delete variable "gpu_worker_token".
  • main/devops/terraform.tfvars (local, gitignored): delete the gpu_worker_token line.
  • main/server/template.yaml: delete the GPU_WORKER_TOKEN global env var.
  • main/server/worldmm/gpu_worker/server.py: delete _GPU_WORKER_TOKEN, _check_token, the Depends/Header imports if now unused, and the dependencies=[Depends(_check_token)] arguments on the four routes. The FastAPI app remains only as a local-dev harness; add a module docstring line saying so.
  • main/server/worldmm/memory/visual/encoder.py: confirm _auth_headers was already removed in Task 4.

  • [ ] Step 3: Validate everything

Run: cd main/devops && terraform validate && cd ../server && sam validate --lint && .venv/bin/pytest tests/unit/ -q Expected: all green

  • [ ] Step 4: Update docs

Update the GPU worker page under docs/docs/ to describe the SQS pull architecture (queue names, S3 prefixes, heartbeat, no inbound ports). Delete references to port 8000 and the worker token.
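The docs page can summarize the wire contract in a few lines. Here is a sketch of the claim-check layout using the bucket and prefixes from the global constraints. The helper names, the .json suffix, and the message fields (job_id, op, bucket, request_key) are assumptions for illustration; protocol.py from Task 2 remains the source of truth.

```python
import json
import uuid

BUCKET = "encache-raw-memory"
REQUEST_PREFIX = "gpu-jobs/requests/"
RESULT_PREFIX = "gpu-jobs/results/"
HEARTBEAT_KEY = "gpu-jobs/heartbeat.json"


def new_job_id() -> str:
    """Random, collision-resistant job id (hypothetical format)."""
    return uuid.uuid4().hex


def request_key(job_id: str) -> str:
    # Full payload (including base64 frames) lives here, written by the Lambda.
    return f"{REQUEST_PREFIX}{job_id}.json"


def result_key(job_id: str) -> str:
    # The worker writes here; the Lambda polls for this object.
    return f"{RESULT_PREFIX}{job_id}.json"


def make_message(job_id: str, op: str) -> str:
    """SQS body: only a pointer to the S3 request object, never the payload."""
    return json.dumps(
        {"job_id": job_id, "op": op, "bucket": BUCKET, "request_key": request_key(job_id)}
    )
```

Because only the pointer crosses SQS, the message stays far below the 256KB limit no matter how many frames the request object holds.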

  • [ ] Step 5: Commit
git add main/devops/ main/server/template.yaml main/server/worldmm/gpu_worker/server.py docs/docs/
git commit -m "feat(infra): close GPU worker port 8000 and retire worker token — SQS pull only"

Rollout (manual, in order)

  1. terraform apply (Task 1 resources) — queue exists, port still open, nothing breaks.
  2. aws_launch_template.gpu_worker in main/devops/main.tf does not manage user_data — Terraform will never push ec2_user_data.sh changes to the launch template. Create a new launch template version with the updated script embedded, then make it the default, e.g.:
    LT_ID=$(aws ec2 describe-launch-templates --launch-template-names encache-gpu-worker --query 'LaunchTemplates[0].LaunchTemplateId' --output text)
    # base64 piped through tr: GNU base64 wraps at 76 columns (breaking the JSON), and -i means "input file" only in BSD/macOS base64.
    NEW_VERSION=$(aws ec2 create-launch-template-version \
      --launch-template-id "$LT_ID" \
      --source-version '$Latest' \
      --launch-template-data "{\"UserData\":\"$(base64 < main/server/worldmm/gpu_worker/ec2_user_data.sh | tr -d '\n')\"}" \
      --query 'LaunchTemplateVersion.VersionNumber' --output text)
    aws ec2 modify-launch-template --launch-template-id "$LT_ID" --default-version "$NEW_VERSION"
    
    Without this step, terminating the instance (next step) boots a fresh instance from the old default launch template version — old user_data, old HTTP server, silently wrong.
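  3. Upload protocol.py + sqs_worker.py to s3://encache-raw-memory/gpu-worker/ (Task 5 script). Terminate the current GPU instance so the next boot picks up the new user_data/systemd config. Alternatively, on the running instance, copy the files, apply the new start.sh and systemd unit by hand, run systemctl daemon-reload, and restart the unit; restarting alone keeps the old start.sh and the old unit environment, which has no GPU_QUEUE_URL.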
  4. sam deploy (Tasks 4–5 code) — Lambdas now submit via SQS. Brief window where an in-flight HTTP-era Lambda container could still try HTTP; acceptable (single-user product, retries via DLQ).
  5. Trigger one ingest end-to-end; confirm a segment completes and gpu-jobs/results/ objects appear.
  6. Merge + terraform apply Task 6 — port closed, token gone.
  7. Reply to cubic P0 comment on PR #548 linking this plan / the follow-up PR.
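Rollout step 2 can also be scripted from Python. This sketch assumes a boto3 EC2 client (describe_launch_templates, create_launch_template_version, modify_launch_template) and the encache-gpu-worker template name from the CLI example; the helper names are hypothetical. Launch template UserData must already be base64-encoded, and base64.b64encode never inserts line breaks, so the result does not depend on which base64 CLI is installed.

```python
import base64
from pathlib import Path
from typing import Any


def encode_user_data(user_data_path: str) -> str:
    """Return the script as a single unwrapped base64 string."""
    return base64.b64encode(Path(user_data_path).read_bytes()).decode("ascii")


def publish_user_data(ec2: Any, template_name: str, user_data_path: str) -> int:
    """Create a launch template version with new UserData and make it the default.

    `ec2` is a boto3 EC2 client (or any object exposing the same three methods).
    Returns the new version number.
    """
    templates = ec2.describe_launch_templates(LaunchTemplateNames=[template_name])
    lt_id = templates["LaunchTemplates"][0]["LaunchTemplateId"]
    resp = ec2.create_launch_template_version(
        LaunchTemplateId=lt_id,
        # Inherit every other setting (AMI, instance type, SG, ...) from the latest version.
        SourceVersion="$Latest",
        LaunchTemplateData={"UserData": encode_user_data(user_data_path)},
    )
    version = int(resp["LaunchTemplateVersion"]["VersionNumber"])
    # Without this, instances keep booting from the old default version.
    ec2.modify_launch_template(LaunchTemplateId=lt_id, DefaultVersion=str(version))
    return version
```

Usage: publish_user_data(boto3.client("ec2"), "encache-gpu-worker", "main/server/worldmm/gpu_worker/ec2_user_data.sh").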

Verification checklist

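  • [ ] aws ec2 describe-security-groups --filters Name=group-name,Values=gpu-worker-sg --query 'SecurityGroups[].IpPermissions' shows no ingress rules (an empty list for the group). --group-names only resolves groups in the default VPC, so filter by name instead.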
  • [ ] Chat question round-trips (exercises encode_text over SQS).
  • [ ] Video ingest completes (exercises caption, generate, encode_video).
  • [ ] Idle worker still self-stops after 1h (watchdog unchanged); next request cold-starts it via heartbeat-miss → EC2 start.
  • [ ] DLQ empty after a day of normal use.
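Several checklist items can be checked mechanically. Below is a sketch of pure helpers over boto3-shaped responses, so they can run in a script or a unit test. The helper names and the heartbeat max_age threshold are assumptions; IpPermissions and the ApproximateNumberOfMessages / ApproximateNumberOfMessagesNotVisible attributes are the field names returned by describe_security_groups and get_queue_attributes.

```python
from datetime import datetime, timedelta, timezone
from typing import Any, Mapping, Optional


def sg_has_no_ingress(resp: Mapping[str, Any]) -> bool:
    """True only if at least one group matched and none has ingress rules."""
    groups = resp.get("SecurityGroups", [])
    return bool(groups) and all(not g.get("IpPermissions") for g in groups)


def dlq_is_empty(resp: Mapping[str, Any]) -> bool:
    """Check get_queue_attributes output; SQS returns attribute values as strings."""
    attrs = resp["Attributes"]
    visible = int(attrs["ApproximateNumberOfMessages"])
    in_flight = int(attrs["ApproximateNumberOfMessagesNotVisible"])
    return visible == 0 and in_flight == 0


def heartbeat_is_stale(
    last_modified: datetime,
    max_age: timedelta,
    now: Optional[datetime] = None,
) -> bool:
    """last_modified: timezone-aware LastModified of gpu-jobs/heartbeat.json."""
    current = now if now is not None else datetime.now(timezone.utc)
    return current - last_modified > max_age
```

For example, sg_has_no_ingress(ec2.describe_security_groups(Filters=[{"Name": "group-name", "Values": ["gpu-worker-sg"]}])). Returning False when no group matches means a mistyped name cannot pass vacuously.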