GPU Worker SQS Pull Model Implementation Plan
For agentic workers: REQUIRED SUB-SKILL: Use superpowers:subagent-driven-development (recommended) or superpowers:executing-plans to implement this plan task-by-task. Steps use checkbox (- [ ]) syntax for tracking.
Goal: Eliminate the GPU worker's public port 8000 (cubic P0 on PR #548) by replacing Lambda→HTTP calls with an SQS pull model — the GPU worker polls a queue and has zero inbound network surface.
Architecture: Lambdas write the job payload to S3 (claim-check pattern, because base64 frames exceed SQS's 256KB limit), send a small SQS message, and poll S3 for a result object. The GPU worker long-polls SQS, runs inference with the already-loaded models, and writes results back to S3. Worker liveness is signaled by a heartbeat object in S3 (replaces HTTP /health); the client's existing EC2 start/adopt logic stays, only the readiness probe changes. Once cut over, the 0.0.0.0/0 port-8000 ingress rule and the GPU_WORKER_TOKEN machinery are deleted.
Tech Stack: Python 3.11, boto3, SQS, S3 (existing encache-raw-memory bucket), Terraform (main/devops/), SAM (main/server/template.yaml), moto for tests.
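To make the claim-check motivation concrete, the sketch below estimates the JSON size of a caption job if its frames were sent inline and compares it with the 256KB figure quoted above. The frame count and per-frame size are illustrative assumptions, not measurements from this codebase, and payload_size_bytes / needs_claim_check are hypothetical helper names.

```python
import base64
import json

SQS_MAX_BYTES = 256 * 1024  # the SQS message limit quoted in this plan


def payload_size_bytes(frames: list[bytes], transcript: str = "") -> int:
    """Size of the JSON job body if the frames were sent inline as base64."""
    frames_b64 = [base64.b64encode(f).decode("ascii") for f in frames]
    body = json.dumps(
        {"op": "caption", "payload": {"frames": frames_b64, "transcript": transcript}}
    )
    return len(body.encode("utf-8"))


def needs_claim_check(size_bytes: int) -> bool:
    """True when the body is too large to travel inside one SQS message."""
    return size_bytes > SQS_MAX_BYTES


# Assumption for illustration only: 8 frames of roughly 100 KB each.
frames = [bytes(100_000) for _ in range(8)]
inline_size = payload_size_bytes(frames)
print(inline_size, needs_claim_check(inline_size))
```

Base64 inflates each frame by about a third, so even a handful of frames of this size lands well above the limit, while a text-only job would fit. Routing every operation through S3 keeps one code path instead of two.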
Global Constraints
- Never skip pre-commit hooks; never use --no-verify (repo policy).
- Conventional Commits format required for all commits and the PR title.
- Python formatted with ruff format; mypy must pass (pre-commit runs both).
- Do NOT deploy (terraform apply or sam deploy) as part of tasks; those are manual cutover steps listed at the end.
- The existing public interface of VLM2VecClient (encode_video, encode_text, caption, generate: same signatures and return types) must not change; 6 call sites depend on it.
- Bucket name: encache-raw-memory. New S3 prefixes: gpu-jobs/requests/, gpu-jobs/results/, heartbeat at gpu-jobs/heartbeat.json.
- Queue names: encache-gpu-requests, DLQ encache-gpu-requests-dlq.
- SSM parameter for queue URL: /encache/gpu/queue_url. Lambda env var: GPU_QUEUE_URL.
Current State (for context)
- main/server/worldmm/gpu_worker/server.py: FastAPI app on port 8000. Route handlers encode_video(req), encode_text(req), caption(req), generate(req) are plain functions taking pydantic models; token auth is a FastAPI Depends, so direct function calls bypass it. load_models(), _start_watchdog() (idle self-shutdown), _touch_activity() are module-level.
- main/server/worldmm/memory/visual/encoder.py: VLM2VecClient does HTTP POSTs; _ensure_running() starts/adopts the EC2 instance and _wait_for_health() polls HTTP /health.
- main/server/worldmm/gpu_worker/ec2_user_data.sh: at boot copies s3://encache-raw-memory/gpu-worker/server.py to /opt/vlm2vec/ and runs it via systemd unit vlm2vec.service.
- main/devops/main.tf: aws_security_group.gpu_worker opens 8000 to 0.0.0.0/0 (~line 200-218); aws_iam_role.gpu_worker (~line 126) already has S3 + SSM policies.
- Call sites of VLM2VecClient: worldmm/pipeline/ingest_window.py, ingest_session.py, run.py, worldmm/retrieval/semantic_retriever.py, episodic_retriever.py, api/memories/chat/app.py.
Latency & Failure Semantics
- Added latency per call: SQS send (~20ms) + worker long-poll pickup (near-0 when idle) + result poll interval (1s) ≈ 1–2s. Acceptable; the chat path already tolerates multi-minute GPU cold starts.
- Inference errors: worker writes an {"error": ...} result and deletes the message (no retry; deterministic failures shouldn't redrive). Client raises GpuJobError.
- Worker crashes mid-job: message visibility timeout (300s) expires, SQS redelivers; after 3 receives → DLQ.
- Heartbeat starts only AFTER models load, so "fresh heartbeat" means "ready for work" (same semantics as today's /health returning 200 only when models are loaded).
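The crash-and-redeliver timing can be worked through with the numbers this plan uses (300s visibility timeout, maxReceiveCount of 3, 3600s retention, and the per-operation client timeouts from Task 4). This is a rough sketch: the helper names are hypothetical, and it assumes the worker crashes on every attempt and picks the message up as soon as it becomes visible again.

```python
import math

VISIBILITY_TIMEOUT_S = 300  # visibility_timeout_seconds in Task 1
MAX_RECEIVE_COUNT = 3       # redrive_policy.maxReceiveCount in Task 1
RETENTION_S = 3600          # message_retention_seconds in Task 1
CLIENT_TIMEOUTS_S = {"encode_video": 300, "encode_text": 240, "caption": 420, "generate": 300}


def seconds_until_dlq(visibility_timeout_s: int, max_receive_count: int) -> int:
    """Approximate time for a job that crashes the worker every time to reach the DLQ."""
    return visibility_timeout_s * max_receive_count


def redeliveries_before_client_gives_up(client_timeout_s: int, visibility_timeout_s: int) -> int:
    """Redeliveries that can start strictly before the client stops polling."""
    return math.ceil(client_timeout_s / visibility_timeout_s) - 1


dlq_after_s = seconds_until_dlq(VISIBILITY_TIMEOUT_S, MAX_RECEIVE_COUNT)
retries_seen = {
    op: redeliveries_before_client_gives_up(t, VISIBILITY_TIMEOUT_S)
    for op, t in CLIENT_TIMEOUTS_S.items()
}
print(dlq_after_s, retries_seen)
```

Under these assumptions a crashing job reaches the DLQ well inside the retention window, and only the caption timeout outlives one visibility timeout. For the other operations a mid-job crash would most likely surface to the caller as a TimeoutError rather than a late success.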
Task 1: Terraform — queue, DLQ, lifecycle rule, worker IAM, queue-URL SSM param
Files:
- Modify: main/devops/main.tf
Interfaces:
- Produces: SQS queue encache-gpu-requests (URL in SSM /encache/gpu/queue_url), DLQ, S3 lifecycle expiring gpu-jobs/ after 1 day, sqs:ReceiveMessage/DeleteMessage/GetQueueAttributes on the worker role.
- [ ] Step 1: Add SQS + SSM + IAM + lifecycle resources
Add to main/devops/main.tf (near the existing gpu_worker IAM resources, ~line 170):
# ── GPU job queue (SQS pull model: worker has no inbound ports) ──
resource "aws_sqs_queue" "gpu_requests_dlq" {
  name                      = "encache-gpu-requests-dlq"
  message_retention_seconds = 1209600 # 14 days
}

resource "aws_sqs_queue" "gpu_requests" {
  name                       = "encache-gpu-requests"
  visibility_timeout_seconds = 300  # > longest single inference (caption ~60s) with margin
  receive_wait_time_seconds  = 20   # long polling
  message_retention_seconds  = 3600 # clients give up after ≤420s; stale jobs are useless
  redrive_policy = jsonencode({
    deadLetterTargetArn = aws_sqs_queue.gpu_requests_dlq.arn
    maxReceiveCount     = 3
  })
}

resource "aws_ssm_parameter" "gpu_queue_url" {
  name  = "/encache/gpu/queue_url"
  type  = "String"
  value = aws_sqs_queue.gpu_requests.url
}

resource "aws_iam_role_policy" "gpu_worker_sqs" {
  name = "gpu-worker-sqs"
  role = aws_iam_role.gpu_worker.id
  policy = jsonencode({
    Version = "2012-10-17"
    Statement = [{
      Effect   = "Allow"
      Action   = ["sqs:ReceiveMessage", "sqs:DeleteMessage", "sqs:GetQueueAttributes"]
      Resource = aws_sqs_queue.gpu_requests.arn
    }]
  })
}

resource "aws_s3_bucket_lifecycle_configuration" "raw_data_gpu_jobs" {
  bucket = aws_s3_bucket.raw_data.id

  rule {
    id     = "expire-gpu-jobs"
    status = "Enabled"

    filter {
      prefix = "gpu-jobs/"
    }

    expiration {
      days = 1
    }
  }
}
Caution: only one aws_s3_bucket_lifecycle_configuration may exist per bucket. Run grep -n "lifecycle_configuration" main/devops/main.tf first — if one already exists for raw_data, add the rule to it instead of creating a new resource.
- [ ] Step 2: Verify the worker's existing S3 policy covers gpu-jobs/
Read aws_iam_role_policy.gpu_worker_s3 (~line 139). If it is scoped to specific prefixes rather than the whole bucket, add arn:aws:s3:::encache-raw-memory/gpu-jobs/* with s3:GetObject and s3:PutObject. The worker must read gpu-jobs/requests/* and write gpu-jobs/results/* and gpu-jobs/heartbeat.json.
- [ ] Step 3: Validate
Run: cd main/devops && terraform fmt && terraform validate
Expected: Success! The configuration is valid.
- [ ] Step 4: Commit
git add main/devops/main.tf
git commit -m "feat(infra): add GPU job queue, DLQ, and gpu-jobs S3 lifecycle for SQS pull model"
Task 2: Shared job protocol module
Files:
- Create: main/server/worldmm/gpu_worker/protocol.py
- Create: main/server/worldmm/gpu_worker/__init__.py (empty, if missing)
- Test: main/server/tests/unit/test_gpu_protocol.py
Interfaces:
- Produces (used by Tasks 3 and 4):
  - submit_job(sqs, s3, *, queue_url: str, bucket: str, op: str, payload: dict) -> str, returns request_id
  - fetch_job(s3, *, bucket: str, request_id: str) -> dict, returns {"op": str, "payload": dict}
  - write_result(s3, *, bucket: str, request_id: str, result: dict | None = None, error: str | None = None) -> None
  - poll_result(s3, *, bucket: str, request_id: str, timeout_s: float, interval_s: float = 1.0) -> dict, raises GpuJobError on error result, TimeoutError on no result
  - write_heartbeat(s3, *, bucket: str) -> None
  - heartbeat_age_s(s3, *, bucket: str) -> float | None, None when no heartbeat object exists
  - class GpuJobError(RuntimeError)
  - Constants: REQUESTS_PREFIX, RESULTS_PREFIX, HEARTBEAT_KEY
- [ ] Step 1: Write the failing tests
# main/server/tests/unit/test_gpu_protocol.py
"""Unit tests for the GPU job claim-check protocol (SQS + S3)."""

from __future__ import annotations

import json
import sys
from pathlib import Path

import boto3
import pytest
from moto import mock_aws

_SERVER_ROOT = Path(__file__).resolve().parents[2]
if str(_SERVER_ROOT) not in sys.path:
    sys.path.insert(0, str(_SERVER_ROOT))

from worldmm.gpu_worker import protocol  # noqa: E402

BUCKET = "test-gpu-jobs"


@pytest.fixture()
def aws():
    with mock_aws():
        s3 = boto3.client("s3", region_name="us-east-1")
        sqs = boto3.client("sqs", region_name="us-east-1")
        s3.create_bucket(Bucket=BUCKET)
        queue_url = sqs.create_queue(QueueName="gpu-requests")["QueueUrl"]
        yield s3, sqs, queue_url


def test_submit_then_fetch_roundtrip(aws):
    s3, sqs, queue_url = aws
    request_id = protocol.submit_job(
        sqs, s3, queue_url=queue_url, bucket=BUCKET,
        op="caption", payload={"frames": ["abc"], "transcript": "hi"},
    )
    msgs = sqs.receive_message(QueueUrl=queue_url)["Messages"]
    body = json.loads(msgs[0]["Body"])
    assert body["request_id"] == request_id
    assert body["op"] == "caption"
    job = protocol.fetch_job(s3, bucket=BUCKET, request_id=request_id)
    assert job == {"op": "caption", "payload": {"frames": ["abc"], "transcript": "hi"}}


def test_poll_result_returns_written_result(aws):
    s3, _, _ = aws
    protocol.write_result(s3, bucket=BUCKET, request_id="r1", result={"caption": "a dog"})
    out = protocol.poll_result(s3, bucket=BUCKET, request_id="r1", timeout_s=2, interval_s=0.05)
    assert out == {"caption": "a dog"}


def test_poll_result_raises_on_error_result(aws):
    s3, _, _ = aws
    protocol.write_result(s3, bucket=BUCKET, request_id="r2", error="ValueError: bad frames")
    with pytest.raises(protocol.GpuJobError, match="bad frames"):
        protocol.poll_result(s3, bucket=BUCKET, request_id="r2", timeout_s=2, interval_s=0.05)


def test_poll_result_times_out_when_no_result(aws):
    s3, _, _ = aws
    with pytest.raises(TimeoutError):
        protocol.poll_result(s3, bucket=BUCKET, request_id="missing", timeout_s=0.2, interval_s=0.05)


def test_heartbeat_age(aws):
    s3, _, _ = aws
    assert protocol.heartbeat_age_s(s3, bucket=BUCKET) is None
    protocol.write_heartbeat(s3, bucket=BUCKET)
    age = protocol.heartbeat_age_s(s3, bucket=BUCKET)
    assert age is not None and 0 <= age < 30
- [ ] Step 2: Run tests to verify they fail
Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_protocol.py -v
Expected: FAIL with ModuleNotFoundError / ImportError on worldmm.gpu_worker.protocol
- [ ] Step 3: Implement protocol.py
# main/server/worldmm/gpu_worker/protocol.py
"""Claim-check job protocol between Lambdas and the GPU worker.

Payloads live in S3 (base64 frames exceed SQS's 256KB limit); the SQS
message carries only the request_id. Results and the worker heartbeat
are also S3 objects, so the worker needs zero inbound network surface.
"""

from __future__ import annotations

import json
import time
import uuid
from datetime import datetime, timezone

from botocore.exceptions import ClientError

REQUESTS_PREFIX = "gpu-jobs/requests/"
RESULTS_PREFIX = "gpu-jobs/results/"
HEARTBEAT_KEY = "gpu-jobs/heartbeat.json"


class GpuJobError(RuntimeError):
    """The worker processed the job and reported an inference error."""


def submit_job(
    sqs, s3, *, queue_url: str, bucket: str, op: str, payload: dict
) -> str:
    request_id = str(uuid.uuid4())
    # Write the payload first so the object exists before the worker sees the message.
    s3.put_object(
        Bucket=bucket,
        Key=f"{REQUESTS_PREFIX}{request_id}.json",
        Body=json.dumps({"op": op, "payload": payload}).encode(),
    )
    sqs.send_message(
        QueueUrl=queue_url,
        MessageBody=json.dumps({"request_id": request_id, "op": op}),
    )
    return request_id


def fetch_job(s3, *, bucket: str, request_id: str) -> dict:
    obj = s3.get_object(Bucket=bucket, Key=f"{REQUESTS_PREFIX}{request_id}.json")
    return json.loads(obj["Body"].read())


def write_result(
    s3,
    *,
    bucket: str,
    request_id: str,
    result: dict | None = None,
    error: str | None = None,
) -> None:
    body: dict = {"error": error} if error is not None else {"result": result}
    s3.put_object(
        Bucket=bucket,
        Key=f"{RESULTS_PREFIX}{request_id}.json",
        Body=json.dumps(body).encode(),
    )


def poll_result(
    s3, *, bucket: str, request_id: str, timeout_s: float, interval_s: float = 1.0
) -> dict:
    key = f"{RESULTS_PREFIX}{request_id}.json"
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            obj = s3.get_object(Bucket=bucket, Key=key)
        except ClientError as exc:
            # Result not written yet: wait and retry. Anything else is a real failure.
            if exc.response["Error"]["Code"] in ("NoSuchKey", "404"):
                time.sleep(interval_s)
                continue
            raise
        body = json.loads(obj["Body"].read())
        if "error" in body:
            raise GpuJobError(body["error"])
        return body["result"]
    raise TimeoutError(f"GPU job {request_id} produced no result within {timeout_s}s")


def write_heartbeat(s3, *, bucket: str) -> None:
    s3.put_object(Bucket=bucket, Key=HEARTBEAT_KEY, Body=b"{}")


def heartbeat_age_s(s3, *, bucket: str) -> float | None:
    """Seconds since the worker last heartbeat, or None if never."""
    try:
        obj = s3.head_object(Bucket=bucket, Key=HEARTBEAT_KEY)
    except ClientError:
        return None
    return (datetime.now(timezone.utc) - obj["LastModified"]).total_seconds()
Also touch main/server/worldmm/gpu_worker/__init__.py if it does not exist.
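The deadline loop in poll_result can be reasoned about without S3 or real sleeps by injecting the clock. The sketch below is a standalone illustration of that polling pattern, not part of protocol.py: poll_until and FakeClock are hypothetical names, and the fetch callables stand in for the S3 get_object call.

```python
from typing import Callable, Optional


def poll_until(
    fetch: Callable[[], Optional[dict]],
    *,
    timeout_s: float,
    interval_s: float,
    clock: Callable[[], float],
    sleep: Callable[[float], None],
) -> dict:
    """Call fetch() until it returns a value or the deadline passes."""
    deadline = clock() + timeout_s
    while clock() < deadline:
        found = fetch()
        if found is not None:
            return found
        sleep(interval_s)
    raise TimeoutError(f"no result within {timeout_s}s")


class FakeClock:
    """Deterministic clock: sleeping simply advances the current time."""

    def __init__(self) -> None:
        self.now = 0.0

    def time(self) -> float:
        return self.now

    def sleep(self, seconds: float) -> None:
        self.now += seconds


clock = FakeClock()
attempt_times: list[float] = []


def result_ready_on_third_try() -> Optional[dict]:
    attempt_times.append(clock.time())
    return {"caption": "a dog"} if len(attempt_times) >= 3 else None


result = poll_until(
    result_ready_on_third_try,
    timeout_s=5, interval_s=1.0, clock=clock.time, sleep=clock.sleep,
)

timed_out = False
try:
    poll_until(lambda: None, timeout_s=2, interval_s=1.0, clock=clock.time, sleep=clock.sleep)
except TimeoutError:
    timed_out = True
```

With a 1s interval the result becomes visible to the caller up to one interval after the worker writes it, which is the poll-interval term in the latency estimate earlier in this plan.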
- [ ] Step 4: Run tests to verify they pass
Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_protocol.py -v
Expected: 5 PASS
- [ ] Step 5: Commit
git add main/server/worldmm/gpu_worker/protocol.py main/server/worldmm/gpu_worker/__init__.py main/server/tests/unit/test_gpu_protocol.py
git commit -m "feat(gpu): add S3 claim-check job protocol for SQS pull model"
Task 3: Worker SQS poller
Files:
- Create: main/server/worldmm/gpu_worker/sqs_worker.py
- Test: main/server/tests/unit/test_gpu_sqs_worker.py
Interfaces:
- Consumes: protocol.fetch_job, protocol.write_result, protocol.write_heartbeat (Task 2); server.load_models, server._start_watchdog, server._touch_activity, and the four route handler functions from server.py.
- Produces: process_message(msg: dict, s3, *, bucket: str) -> None (unit-testable core), main() -> None entry point run by systemd. Env contract: GPU_QUEUE_URL (required), GPU_JOBS_BUCKET (default encache-raw-memory).
- [ ] Step 1: Write the failing tests
# main/server/tests/unit/test_gpu_sqs_worker.py
"""Unit tests for the GPU worker's SQS message processing.

server.py imports torch at module load, so the sqs_worker module lazily
imports server inside its handlers; tests inject fake handlers instead.
"""

from __future__ import annotations

import json
import sys
from pathlib import Path

import boto3
import pytest
from moto import mock_aws

_SERVER_ROOT = Path(__file__).resolve().parents[2]
if str(_SERVER_ROOT) not in sys.path:
    sys.path.insert(0, str(_SERVER_ROOT))

from worldmm.gpu_worker import protocol, sqs_worker  # noqa: E402

BUCKET = "test-gpu-jobs"


@pytest.fixture()
def s3():
    with mock_aws():
        client = boto3.client("s3", region_name="us-east-1")
        client.create_bucket(Bucket=BUCKET)
        yield client


def _msg(request_id: str, op: str) -> dict:
    return {"Body": json.dumps({"request_id": request_id, "op": op}), "ReceiptHandle": "rh"}


def test_process_message_writes_result(s3, monkeypatch):
    monkeypatch.setattr(
        sqs_worker, "_HANDLERS", {"caption": lambda p: {"caption": f"saw {p['transcript']}"}}
    )
    s3.put_object(
        Bucket=BUCKET,
        Key=f"{protocol.REQUESTS_PREFIX}r1.json",
        Body=json.dumps({"op": "caption", "payload": {"transcript": "a dog"}}).encode(),
    )
    sqs_worker.process_message(_msg("r1", "caption"), s3, bucket=BUCKET)
    out = protocol.poll_result(s3, bucket=BUCKET, request_id="r1", timeout_s=1, interval_s=0.05)
    assert out == {"caption": "saw a dog"}


def test_process_message_writes_error_on_handler_exception(s3, monkeypatch):
    def boom(_payload):
        raise ValueError("bad frames")

    monkeypatch.setattr(sqs_worker, "_HANDLERS", {"caption": boom})
    s3.put_object(
        Bucket=BUCKET,
        Key=f"{protocol.REQUESTS_PREFIX}r2.json",
        Body=json.dumps({"op": "caption", "payload": {}}).encode(),
    )
    sqs_worker.process_message(_msg("r2", "caption"), s3, bucket=BUCKET)
    with pytest.raises(protocol.GpuJobError, match="ValueError: bad frames"):
        protocol.poll_result(s3, bucket=BUCKET, request_id="r2", timeout_s=1, interval_s=0.05)


def test_process_message_writes_error_on_unknown_op(s3):
    s3.put_object(
        Bucket=BUCKET,
        Key=f"{protocol.REQUESTS_PREFIX}r3.json",
        Body=json.dumps({"op": "nope", "payload": {}}).encode(),
    )
    sqs_worker.process_message(_msg("r3", "nope"), s3, bucket=BUCKET)
    with pytest.raises(protocol.GpuJobError, match="KeyError"):
        protocol.poll_result(s3, bucket=BUCKET, request_id="r3", timeout_s=1, interval_s=0.05)
- [ ] Step 2: Run tests to verify they fail
Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_sqs_worker.py -v
Expected: FAIL with ImportError: cannot import name 'sqs_worker'
- [ ] Step 3: Implement sqs_worker.py
# main/server/worldmm/gpu_worker/sqs_worker.py
"""SQS pull-mode entry point for the GPU worker.

Replaces the public FastAPI server: the worker long-polls
encache-gpu-requests, runs inference via the handler functions in
server.py, and writes results to S3. Zero inbound network surface.

On the EC2 box this file sits flat in /opt/vlm2vec next to server.py
and protocol.py, hence the dual import paths below.
"""

from __future__ import annotations

import json
import os
import threading
import time
import traceback

import boto3

try:  # repo layout (tests, Lambda bundles)
    from worldmm.gpu_worker import protocol
except ImportError:  # flat layout on the EC2 box
    import protocol  # type: ignore[no-redef]

HEARTBEAT_INTERVAL_S = 30


def _load_handlers() -> dict:
    """Import server.py (heavy: torch, transformers) only at runtime."""
    try:
        from worldmm.gpu_worker import server
    except ImportError:
        import server  # type: ignore[no-redef]

    def _as_dict(resp) -> dict:
        return resp if isinstance(resp, dict) else resp.model_dump()

    return {
        "encode_video": lambda p: _as_dict(server.encode_video(server.EncodeVideoRequest(**p))),
        "encode_text": lambda p: _as_dict(server.encode_text(server.EncodeTextRequest(**p))),
        "caption": lambda p: _as_dict(server.caption(server.CaptionRequest(**p))),
        "generate": lambda p: _as_dict(server.generate(server.GenerateRequest(**p))),
    }


# Populated in main(); tests monkeypatch this directly.
_HANDLERS: dict = {}


def process_message(msg: dict, s3, *, bucket: str) -> None:
    # Poison message (bad JSON / missing request_id) raises: caller skips
    # delete, SQS redrives, DLQ after 3 receives.
    request_id = json.loads(msg["Body"])["request_id"]
    # fetch_job hits S3; transient get_object failures propagate (redrive),
    # they are not deterministic inference errors.
    job = protocol.fetch_job(s3, bucket=bucket, request_id=request_id)
    # Only handler execution is deterministic. write_result failures are
    # transient S3 errors and must propagate (no delete → redrive).
    try:
        result = _HANDLERS[job["op"]](job["payload"])
    except Exception as exc:  # deterministic failure: report, don't redrive
        traceback.print_exc()
        protocol.write_result(
            s3, bucket=bucket, request_id=request_id,
            error=f"{type(exc).__name__}: {exc}",
        )
    else:
        protocol.write_result(s3, bucket=bucket, request_id=request_id, result=result)


def _heartbeat_loop(s3, bucket: str) -> None:
    while True:
        try:
            protocol.write_heartbeat(s3, bucket=bucket)
        except Exception:
            traceback.print_exc()
        time.sleep(HEARTBEAT_INTERVAL_S)


def main() -> None:
    global _HANDLERS  # skipcq: PYL-W0603
    queue_url = os.environ["GPU_QUEUE_URL"]
    bucket = os.environ.get("GPU_JOBS_BUCKET", "encache-raw-memory")
    try:
        from worldmm.gpu_worker import server
    except ImportError:
        import server  # type: ignore[no-redef]

    s3 = boto3.client("s3")
    sqs = boto3.client("sqs")
    server.load_models()
    server._start_watchdog()  # idle self-shutdown, unchanged
    _HANDLERS = _load_handlers()
    # Heartbeat starts only now: fresh heartbeat == models loaded == ready.
    threading.Thread(target=_heartbeat_loop, args=(s3, bucket), daemon=True).start()
    print(f"SQS worker polling {queue_url}")
    while True:
        resp = sqs.receive_message(
            QueueUrl=queue_url, MaxNumberOfMessages=1, WaitTimeSeconds=20
        )
        for msg in resp.get("Messages", []):
            server._touch_activity()  # reset idle watchdog
            try:
                process_message(msg, s3, bucket=bucket)
            except Exception:
                # Poison message or transient S3 failure: keep the loop (and the
                # loaded models) alive and skip the delete so SQS redelivers.
                traceback.print_exc()
                continue
            sqs.delete_message(QueueUrl=queue_url, ReceiptHandle=msg["ReceiptHandle"])


if __name__ == "__main__":
    main()
Note: server.caption returns CaptionResponse (pydantic) or a dict depending on route style — _as_dict normalizes both. Verify server.py has a GenerateRequest model; if the /generate route takes a differently named model, match it here.
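The normalization described in the note can be illustrated without importing server.py or pydantic. In this sketch FakeCaptionResponse is a hypothetical stand-in that only mimics the model_dump() method of a pydantic v2 model; if server.py turns out to use pydantic v1, the equivalent method would be .dict(), which is an assumption to verify against the real models.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class FakeCaptionResponse:
    """Hypothetical stand-in for a pydantic v2 response model."""

    caption: str

    def model_dump(self) -> dict:
        return asdict(self)


def as_dict(resp) -> dict:
    """Pass dicts through unchanged; convert model-like objects via model_dump()."""
    return resp if isinstance(resp, dict) else resp.model_dump()


from_model = as_dict(FakeCaptionResponse(caption="a sunny park"))
from_dict = as_dict({"caption": "a sunny park"})

# Either form must be JSON-serializable, because write_result stores it with json.dumps.
serialized = json.dumps({"result": from_model})
```

Whichever style a route uses, the value handed to write_result is a plain dict, so the client can index it by key ("embedding", "caption", "text") after poll_result returns.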
- [ ] Step 4: Run tests to verify they pass
Run: cd main/server && .venv/bin/pytest tests/unit/test_gpu_sqs_worker.py -v
Expected: 3 PASS
- [ ] Step 5: Commit
git add main/server/worldmm/gpu_worker/sqs_worker.py main/server/tests/unit/test_gpu_sqs_worker.py
git commit -m "feat(gpu): add SQS pull-mode worker entry point"
Task 4: Client cutover — VLM2VecClient submits jobs instead of HTTP
Files:
- Modify: main/server/worldmm/memory/visual/encoder.py
- Test: main/server/tests/unit/test_vlm2vec_client_sqs.py
Interfaces:
- Consumes: protocol.submit_job, protocol.poll_result, protocol.heartbeat_age_s (Task 2).
- Produces: unchanged public API (signatures below). Constructor keeps base_url param (now ignored) so the 6 call sites compile unchanged. Env contract: GPU_QUEUE_URL, GPU_JOBS_BUCKET (default encache-raw-memory).
  - encode_video(frames_b64: list[str]) -> list[float]
  - encode_text(text: str) -> list[float]
  - caption(frames_b64: list[str], transcript: str = "") -> str
  - generate(messages: list[dict], max_new_tokens: int = 512) -> str
- [ ] Step 1: Write the failing tests
# main/server/tests/unit/test_vlm2vec_client_sqs.py
"""VLM2VecClient submits SQS jobs and polls S3 results (no HTTP)."""

from __future__ import annotations

import json
import sys
import threading
from pathlib import Path

import boto3
import pytest
from moto import mock_aws

_SERVER_ROOT = Path(__file__).resolve().parents[2]
_LAYER_ROOT = _SERVER_ROOT / "layers" / "shared" / "python"
for p in (_LAYER_ROOT, _SERVER_ROOT):
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

from worldmm.gpu_worker import protocol  # noqa: E402
from worldmm.memory.visual.encoder import VLM2VecClient  # noqa: E402

BUCKET = "test-gpu-jobs"


@pytest.fixture()
def aws(monkeypatch):
    with mock_aws():
        s3 = boto3.client("s3", region_name="us-east-1")
        sqs = boto3.client("sqs", region_name="us-east-1")
        s3.create_bucket(Bucket=BUCKET)
        queue_url = sqs.create_queue(QueueName="gpu-requests")["QueueUrl"]
        monkeypatch.setenv("GPU_QUEUE_URL", queue_url)
        monkeypatch.setenv("GPU_JOBS_BUCKET", BUCKET)
        yield s3, sqs, queue_url


def _answer_next_job(s3, sqs, queue_url, result: dict) -> None:
    """Play the worker: consume one message and write the result."""

    def loop():
        for _ in range(100):
            msgs = sqs.receive_message(QueueUrl=queue_url, WaitTimeSeconds=1).get("Messages", [])
            if msgs:
                rid = json.loads(msgs[0]["Body"])["request_id"]
                protocol.write_result(s3, bucket=BUCKET, request_id=rid, result=result)
                return

    threading.Thread(target=loop, daemon=True).start()


def test_encode_text_roundtrip(aws, monkeypatch):
    s3, sqs, queue_url = aws
    client = VLM2VecClient(instance_id="i-fake")
    monkeypatch.setattr(client, "_ensure_running", lambda: None)
    monkeypatch.setattr(client, "_poll_interval_s", 0.05)
    _answer_next_job(s3, sqs, queue_url, {"embedding": [0.1, 0.2]})
    assert client.encode_text("hello") == [0.1, 0.2]
    # Payload made it into the request object
    msgs = s3.list_objects_v2(Bucket=BUCKET, Prefix=protocol.REQUESTS_PREFIX)
    body = json.loads(
        s3.get_object(Bucket=BUCKET, Key=msgs["Contents"][0]["Key"])["Body"].read()
    )
    assert body == {"op": "encode_text", "payload": {"text": "hello"}}


def test_caption_roundtrip(aws, monkeypatch):
    s3, sqs, queue_url = aws
    client = VLM2VecClient(instance_id="i-fake")
    monkeypatch.setattr(client, "_ensure_running", lambda: None)
    monkeypatch.setattr(client, "_poll_interval_s", 0.05)
    _answer_next_job(s3, sqs, queue_url, {"caption": "a sunny park"})
    assert client.caption(["frame1"], transcript="birds") == "a sunny park"


def test_worker_error_raises(aws, monkeypatch):
    s3, sqs, queue_url = aws
    client = VLM2VecClient(instance_id="i-fake")
    monkeypatch.setattr(client, "_ensure_running", lambda: None)
    monkeypatch.setattr(client, "_poll_interval_s", 0.05)
    monkeypatch.setattr(client, "_job_timeout_s", lambda op: 2)

    def answer_with_error():
        for _ in range(100):
            msgs = sqs.receive_message(QueueUrl=queue_url, WaitTimeSeconds=1).get("Messages", [])
            if msgs:
                rid = json.loads(msgs[0]["Body"])["request_id"]
                protocol.write_result(s3, bucket=BUCKET, request_id=rid, error="OOM")
                return

    threading.Thread(target=answer_with_error, daemon=True).start()
    with pytest.raises(protocol.GpuJobError, match="OOM"):
        client.encode_text("boom")
- [ ] Step 2: Run tests to verify they fail
Run: cd main/server && .venv/bin/pytest tests/unit/test_vlm2vec_client_sqs.py -v
Expected: FAIL (client still does HTTP; no _poll_interval_s attribute)
- [ ] Step 3: Rewrite the transport in encoder.py
Keep everything about EC2 lifecycle (_start_or_create, _create_from_template, _find_existing_by_tag, _persist_instance_id_to_ssm) — only the probe and transport change. Replace the class header, health checks, and the four request methods:
# top of encoder.py — replace `import requests` usage for transport
from __future__ import annotations

import logging
import os
import time

import boto3

from worldmm.gpu_worker import protocol

logger = logging.getLogger(__name__)

STARTUP_TIMEOUT_S = 300  # instance boot (~2 min AMI) + model load, with margin
HEARTBEAT_FRESH_S = 90  # worker heartbeats every 30s; 3 missed = dead
JOB_TIMEOUTS_S = {
    "encode_video": 300,
    "encode_text": 240,
    "caption": 420,
    "generate": 300,
}


class VLM2VecClient:
    def __init__(
        self,
        base_url: str | None = None,  # deprecated, ignored (SQS transport)
        instance_id: str | None = None,
        aws_region: str = "us-east-1",
    ) -> None:
        self._instance_id = instance_id
        self._aws_region = aws_region
        self._queue_url = os.environ["GPU_QUEUE_URL"]
        self._bucket = os.environ.get("GPU_JOBS_BUCKET", "encache-raw-memory")
        self._s3 = boto3.client("s3", region_name=aws_region)
        self._sqs = boto3.client("sqs", region_name=aws_region)
        self._poll_interval_s = 1.0

    def _job_timeout_s(self, op: str) -> float:
        return JOB_TIMEOUTS_S[op]

    # ── liveness ────────────────────────────────────────────
    def _is_healthy(self) -> bool:
        age = protocol.heartbeat_age_s(self._s3, bucket=self._bucket)
        return age is not None and age < HEARTBEAT_FRESH_S

    def _wait_for_health(self) -> None:
        deadline = time.monotonic() + STARTUP_TIMEOUT_S
        while time.monotonic() < deadline:
            if self._is_healthy():
                logger.info("GPU worker heartbeat fresh: ready")
                return
            time.sleep(5)
        raise RuntimeError(
            f"GPU worker heartbeat not fresh within {STARTUP_TIMEOUT_S}s"
        )

    # _ensure_running / _start_or_create / _create_from_template /
    # _find_existing_by_tag / _persist_instance_id_to_ssm: UNCHANGED,
    # except: delete _update_base_url and every call to it (the client
    # no longer needs the instance IP).

    # ── transport ───────────────────────────────────────────
    def _submit(self, op: str, payload: dict) -> dict:
        self._ensure_running()
        request_id = protocol.submit_job(
            self._sqs, self._s3,
            queue_url=self._queue_url, bucket=self._bucket,
            op=op, payload=payload,
        )
        return protocol.poll_result(
            self._s3, bucket=self._bucket, request_id=request_id,
            timeout_s=self._job_timeout_s(op), interval_s=self._poll_interval_s,
        )

    def encode_video(self, frames_b64: list[str]) -> list[float]:
        return self._submit("encode_video", {"frames": frames_b64})["embedding"]

    def encode_text(self, text: str) -> list[float]:
        return self._submit("encode_text", {"text": text})["embedding"]

    def caption(self, frames_b64: list[str], transcript: str = "") -> str:
        return self._submit("caption", {"frames": frames_b64, "transcript": transcript})["caption"]

    def generate(self, messages: list[dict], max_new_tokens: int = 512) -> str:
        return self._submit(
            "generate", {"messages": messages, "max_new_tokens": max_new_tokens}
        )["text"]
Delete: _auth_headers, HEALTH_TIMEOUT_S/HEALTH_POLL_INTERVAL_S constants, _update_base_url, the requests import if nothing else uses it, and the old HTTP bodies of the four methods. In _ensure_running, the "probe base_url before bootstrapping" branch becomes a plain heartbeat check (if self._is_healthy(): return).
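The heartbeat readiness rule can be checked in isolation with fixed timestamps. The sketch below mirrors the age-and-threshold logic using plain datetimes instead of an S3 head_object call; the helper names and the example clock time are assumptions for illustration.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional

HEARTBEAT_FRESH_S = 90  # worker heartbeats every 30s; 3 missed = dead


def heartbeat_age_s(last_modified: Optional[datetime], now: datetime) -> Optional[float]:
    """Seconds since the heartbeat object was last written, or None if it never was."""
    if last_modified is None:
        return None
    return (now - last_modified).total_seconds()


def is_healthy(age_s: Optional[float]) -> bool:
    return age_s is not None and age_s < HEARTBEAT_FRESH_S


# Arbitrary fixed "now" so the example is deterministic.
now = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)

fresh = is_healthy(heartbeat_age_s(now - timedelta(seconds=30), now))
stale = is_healthy(heartbeat_age_s(now - timedelta(seconds=90), now))
never = is_healthy(heartbeat_age_s(None, now))
```

One consequence worth keeping in mind: a worker that stopped less than 90 seconds ago still reads as healthy, so a job submitted in that window would wait for its timeout instead of triggering a cold start.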
- [ ] Step 4: Run new tests + existing suites
Run: cd main/server && .venv/bin/pytest tests/unit/test_vlm2vec_client_sqs.py tests/unit/ -v --tb=short
Expected: new tests PASS. Existing tests that construct VLM2VecClient without GPU_QUEUE_URL in env will now raise KeyError; fix those tests by setting monkeypatch.setenv("GPU_QUEUE_URL", "https://sqs.fake/q") (integration tests mock the whole client class, so most are unaffected; _run_ingest in tests/integration/test_ingest_segment_first.py already patches worldmm.pipeline.ingest_window.VLM2VecClient).
- [ ] Step 5: Check the call sites still make sense
grep -n "VLM2VecClient(" main/server -r. All 6 call sites pass base_url=/instance_id= kwargs — they compile unchanged because base_url is still accepted. In ingest_window.py, _resolve_gpu_url() is now only used to build an ignored argument — leave the simplification (deleting _resolve_gpu_url and the base_url= args) as an optional follow-up commit if tests stay green; it is not required for cutover.
- [ ] Step 6: Commit
git add main/server/worldmm/memory/visual/encoder.py main/server/tests/unit/test_vlm2vec_client_sqs.py
git commit -m "feat(gpu): switch VLM2VecClient transport from HTTP to SQS claim-check"
Task 5: Wiring — Lambda env/policies, EC2 user_data, worker file upload
Files:
- Modify: main/server/template.yaml
- Modify: main/server/worldmm/gpu_worker/ec2_user_data.sh
- Modify: the script that uploads server.py to s3://encache-raw-memory/gpu-worker/ (find with grep -rn "gpu-worker/server.py" scripts/ main/; likely scripts/build-gpu-ami.sh or a deploy script)
Interfaces:
- Consumes: SSM /encache/gpu/queue_url (Task 1); sqs_worker.py, protocol.py (Tasks 2–3).
- Produces: Lambdas get GPU_QUEUE_URL env + SQS send / S3 gpu-jobs permissions; EC2 boots sqs_worker.py under systemd.
- [ ] Step 1: template.yaml — global env var
In the global Lambda environment block (where GPU_WORKER_TOKEN currently lives), add the GPU_QUEUE_URL env var, set to the queue URL stored in SSM at /encache/gpu/queue_url.
- [ ] Step 2: template.yaml — per-function policies
Every function that constructs VLM2VecClient (ingest window, ingest session, chat, retriever-hosting functions — cross-check the 6 call-site files against their function definitions) needs:
- SQSSendMessagePolicy:
    QueueName: encache-gpu-requests
- S3CrudPolicy:
    BucketName: encache-raw-memory
Most already have S3CrudPolicy (or broader) on encache-raw-memory for frames — verify rather than duplicate.
- [ ] Step 3: Validate SAM template
Run: cd main/server && sam validate --lint
Expected: template is valid
- [ ] Step 4: ec2_user_data.sh — fetch new files, new env, new ExecStart
Replace the GPU_WORKER_TOKEN SSM fetch (line ~18) with:
export GPU_QUEUE_URL=$(aws ssm get-parameter --name "/encache/gpu/queue_url" --query 'Parameter.Value' --output text --cli-connect-timeout 10 --cli-read-timeout 10)
After the existing server.py copy (line ~69), add:
aws s3 cp s3://encache-raw-memory/gpu-worker/protocol.py /opt/vlm2vec/protocol.py
aws s3 cp s3://encache-raw-memory/gpu-worker/sqs_worker.py /opt/vlm2vec/sqs_worker.py
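In the start.sh heredoc, change the launch command so it runs sqs_worker.py instead of server.py.
In the systemd unit, replace Environment="GPU_WORKER_TOKEN=${GPU_WORKER_TOKEN}" with the equivalent Environment entry for GPU_QUEUE_URL.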
- [ ] Step 5: Update the upload script
Wherever server.py is uploaded to s3://encache-raw-memory/gpu-worker/, also upload protocol.py and sqs_worker.py:
aws s3 cp main/server/worldmm/gpu_worker/protocol.py s3://encache-raw-memory/gpu-worker/protocol.py
aws s3 cp main/server/worldmm/gpu_worker/sqs_worker.py s3://encache-raw-memory/gpu-worker/sqs_worker.py
- [ ] Step 6: Commit
git add main/server/template.yaml main/server/worldmm/gpu_worker/ec2_user_data.sh scripts/
git commit -m "feat(infra): wire GPU_QUEUE_URL through SAM, user_data, and worker upload"
Task 6: Close the port and retire the token
Only merge/apply this after Tasks 1–5 are deployed and one end-to-end ingest has succeeded via SQS (see Rollout below).
Files:
- Modify: main/devops/main.tf
- Modify: main/devops/variables.tf
- Modify: main/server/template.yaml
- Modify: main/server/worldmm/gpu_worker/server.py
- Modify: docs/docs/ (whichever page documents the GPU worker HTTP API; find with grep -rn "8000\|GPU_WORKER_TOKEN" docs/docs/)
- [ ] Step 1: Remove the ingress rule
Delete the ingress block from aws_security_group.gpu_worker in main/devops/main.tf (~lines 205-211) and update the SG description to "GPU worker — egress only (SQS pull model)". Keep the egress block.
- [ ] Step 2: Remove token infrastructure
- main/devops/main.tf: delete aws_ssm_parameter.gpu_worker_token (~line 408).
- main/devops/variables.tf: delete variable "gpu_worker_token".
- main/devops/terraform.tfvars (local, gitignored): delete the gpu_worker_token line.
- main/server/template.yaml: delete the GPU_WORKER_TOKEN global env var.
- main/server/worldmm/gpu_worker/server.py: delete _GPU_WORKER_TOKEN, _check_token, the Depends/Header imports if now unused, and the dependencies=[Depends(_check_token)] arguments on the four routes. The FastAPI app remains only as a local-dev harness; add a module docstring line saying so.
- main/server/worldmm/memory/visual/encoder.py: confirm _auth_headers was already removed in Task 4.
- [ ] Step 3: Validate everything
Run: cd main/devops && terraform validate && cd ../server && sam validate --lint && .venv/bin/pytest tests/unit/ -q
Expected: all green
- [ ] Step 4: Update docs
Update the GPU worker page under docs/docs/ to describe the SQS pull architecture (queue names, S3 prefixes, heartbeat, no inbound ports). Delete references to port 8000 and the worker token.
- [ ] Step 5: Commit
git add main/devops/ main/server/template.yaml main/server/worldmm/gpu_worker/server.py docs/docs/
git commit -m "feat(infra): close GPU worker port 8000 and retire worker token — SQS pull only"
Rollout (manual, in order)
- terraform apply (Task 1 resources): queue exists, port still open, nothing breaks.
- aws_launch_template.gpu_worker in main/devops/main.tf does not manage user_data, so Terraform will never push ec2_user_data.sh changes to the launch template. Create a new launch template version with the updated script embedded, then make it the default, e.g.:
  LT_ID=$(aws ec2 describe-launch-templates --launch-template-names encache-gpu-worker --query 'LaunchTemplates[0].LaunchTemplateId' --output text)
  NEW_VERSION=$(aws ec2 create-launch-template-version \
    --launch-template-id "$LT_ID" \
    --source-version '$Latest' \
    --launch-template-data "{\"UserData\":\"$(base64 -i main/server/worldmm/gpu_worker/ec2_user_data.sh)\"}" \
    --query 'LaunchTemplateVersion.VersionNumber' --output text)
  aws ec2 modify-launch-template --launch-template-id "$LT_ID" --default-version "$NEW_VERSION"
  Without this step, terminating the instance (next step) boots a fresh instance from the old default launch template version: old user_data, old HTTP server, silently wrong. Note that base64 -i is the macOS form; GNU base64 wraps its output at 76 columns by default, so on Linux use base64 -w0 to keep the value on one line.
- Upload protocol.py + sqs_worker.py to s3://encache-raw-memory/gpu-worker/ (Task 5 script). Terminate the current GPU instance so the next boot picks up the new user_data/systemd config (or restart the systemd unit manually after copying files).
- sam deploy (Tasks 4–5 code): Lambdas now submit via SQS. Brief window where an in-flight HTTP-era Lambda container could still try HTTP; acceptable (single-user product, retries via DLQ).
- Trigger one ingest end-to-end; confirm a segment completes and gpu-jobs/results/ objects appear.
- Merge + terraform apply Task 6: port closed, token gone.
- Reply to cubic P0 comment on PR #548 linking this plan / the follow-up PR.
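Building the launch-template payload in Python sidesteps the platform differences in the base64 command line tool. This is a minimal sketch under stated assumptions: launch_template_data is a hypothetical helper, the script bytes are a placeholder rather than the real ec2_user_data.sh, and the resulting string is only meant to be passed as the --launch-template-data argument.

```python
import base64
import json


def launch_template_data(user_data_script: bytes) -> str:
    """JSON for --launch-template-data with UserData base64-encoded on a single line."""
    # base64.b64encode never inserts line breaks, unlike some CLI defaults.
    encoded = base64.b64encode(user_data_script).decode("ascii")
    return json.dumps({"UserData": encoded})


# Placeholder script for illustration; in practice read ec2_user_data.sh from disk.
script = b"#!/bin/bash\necho 'starting sqs_worker'\n"
data = launch_template_data(script)

# Round-trip check: decoding the embedded value yields the original script.
decoded = base64.b64decode(json.loads(data)["UserData"])
```

Using json.dumps also avoids hand-escaping quotes inside the shell string, which is the other fragile part of the one-liner.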
Verification checklist
- [ ] aws ec2 describe-security-groups --group-names gpu-worker-sg shows no ingress rules.
- [ ] Chat question round-trips (exercises encode_text over SQS).
- [ ] Video ingest completes (exercises caption, generate, encode_video).
- [ ] Idle worker still self-stops after 1h (watchdog unchanged); next request cold-starts it via heartbeat-miss → EC2 start.
- [ ] DLQ empty after a day of normal use.
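The first checklist item can be automated by inspecting the describe-security-groups response, whose groups list their inbound rules under IpPermissions. The sketch below works on plain dicts shaped like that response; the sample data is made up for illustration and ingress_rules is a hypothetical helper name.

```python
def ingress_rules(describe_response: dict) -> list[dict]:
    """Collect every inbound rule from a describe-security-groups style response."""
    rules: list[dict] = []
    for group in describe_response.get("SecurityGroups", []):
        rules.extend(group.get("IpPermissions", []))
    return rules


# Illustrative response before Task 6: port 8000 open to the world.
before = {
    "SecurityGroups": [
        {
            "GroupName": "gpu-worker-sg",
            "IpPermissions": [
                {
                    "IpProtocol": "tcp",
                    "FromPort": 8000,
                    "ToPort": 8000,
                    "IpRanges": [{"CidrIp": "0.0.0.0/0"}],
                }
            ],
        }
    ]
}

# Illustrative response after Task 6: egress only.
after = {"SecurityGroups": [{"GroupName": "gpu-worker-sg", "IpPermissions": []}]}

open_before = ingress_rules(before)
open_after = ingress_rules(after)
```

In a real check the dict would come from the EC2 API or the CLI's JSON output, and the assertion of interest is simply that the collected list is empty.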