homebox_ai_fronted/backend/routers/analyze.py

174 lines
6.0 KiB
Python

import base64
import json
from typing import Optional
import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from config import get_settings
# FastAPI router for this module; the POST "/analyze" route below is
# registered on it. Where it is mounted (and under which prefix) is decided
# elsewhere -- presumably the application's main module; confirm there.
router = APIRouter()
class LocationOption(BaseModel):
    """A location option supplied by the client (presumably an existing Homebox location)."""

    id: str  # not read in this module; carried through from the client
    name: str  # listed in the prompt and matched case-insensitively against the model's answer
class LabelOption(BaseModel):
    """A label option supplied by the client (presumably an existing Homebox label)."""

    id: str  # not read in this module; carried through from the client
    name: str  # listed in the prompt and matched case-insensitively against suggested labels
class AnalyzeRequest(BaseModel):
    """Request body for the /analyze route."""

    image: str # Base64 encoded image data (a leading data-URL prefix up to "," is stripped by the handler)
    # Candidate locations for the model; when empty the prompt offers a single
    # "Storage" placeholder instead.
    locations: list[LocationOption] = []
    # Existing labels the model should prefer; when empty the prompt asks the
    # model to suggest new label names.
    labels: list[LabelOption] = []
class AnalyzeResponse(BaseModel):
    """Suggestions returned by the /analyze route."""

    name: str  # suggested item name
    description: str  # suggested item description
    existing_labels: list[str] # Labels that exist in Homebox
    new_labels: list[str] # AI-suggested labels that don't exist yet
    suggested_location: str # Location name from available locations ("" when none were offered)
    raw_response: Optional[str] = None  # the model's unparsed text reply, returned alongside the parsed fields
def build_prompt(locations: list[LocationOption], labels: list[LabelOption]) -> str:
    """Assemble the instruction prompt sent to the vision model.

    The prompt asks for a JSON object (name, description, labels, location),
    lists the permitted location names -- a lone "Storage" placeholder when
    ``locations`` is empty -- and then either lists the existing label names
    to reuse or, when there are none, asks for fresh label suggestions.

    Args:
        locations: Candidate storage locations; only ``.name`` is used.
        labels: Existing labels; only ``.name`` is used.

    Returns:
        The complete prompt text.
    """
    def as_bullets(names: list[str]) -> str:
        return "\n".join(f"- {entry}" for entry in names)

    allowed_locations = ["Storage"] if not locations else [opt.name for opt in locations]
    known_labels = [] if not labels else [opt.name for opt in labels]

    intro = (
        "Look at this image and identify the item for a home inventory system.\n"
        "Respond with ONLY a JSON object (no other text) with these fields:\n"
        '- name: Short descriptive name for the item (e.g. "Red Handled Scissors", "DeWalt Cordless Drill")\n'
        "- description: Brief description - color, brand, condition, notable features\n"
        "- labels: Array of 1-4 relevant category labels for this item\n"
        "- location: Best storage location from the AVAILABLE LOCATIONS below (use exact name)\n"
        "AVAILABLE LOCATIONS (you MUST pick one exactly as written):\n"
    )

    # Each section is separated from the next by one blank line ("\n\n").
    sections = [intro + as_bullets(allowed_locations)]
    if known_labels:
        sections += [
            "EXISTING LABELS (prefer these if they fit, use exact names):\n"
            + as_bullets(known_labels),
            "For labels: Use existing labels above if they match. If the item needs a label that doesn't exist, suggest a new short label name (lowercase, single word or hyphenated like 'power-tools').",
        ]
    else:
        sections.append(
            "No labels exist yet. Suggest 1-3 short label names (lowercase, single word or hyphenated like 'electronics' or 'hand-tools')."
        )
    sections.append("Respond with ONLY the JSON object.")
    return "\n\n".join(sections)
def _clean_text(value: object) -> str:
    """Return ``value`` as a stripped string, treating ``None`` (JSON null) as empty."""
    return "" if value is None else str(value).strip()


def _extract_json_object(text: str) -> dict:
    """Return the first JSON object embedded in ``text``, or ``{}`` if none parses.

    Model replies often wrap the JSON in prose or markdown fences, and that
    prose may itself contain braces. Each ``{`` is therefore tried in turn;
    ``raw_decode`` stops at the end of the object and ignores whatever
    trails it. Replies are short, so retrying per brace is cheap.
    """
    if not text:
        return {}
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # Starting at "{" means a successful decode is always a dict.
            parsed, _ = decoder.raw_decode(text[start:])
            return parsed
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
    return {}


def _match_location(suggested: object, locations: list[LocationOption]) -> str:
    """Map the model's location suggestion onto an available location name.

    Matching is case-insensitive and ignores surrounding whitespace. Falls
    back to the first available location, or ``""`` when none were given.
    """
    by_key = {loc.name.strip().lower(): loc.name for loc in locations}
    match = by_key.get(_clean_text(suggested).lower())
    if match is not None:
        return match
    return locations[0].name if locations else ""


def _label_key(name: str) -> str:
    """Normalise a label for comparison: lowercase, whitespace runs become '-'."""
    return "-".join(name.lower().split())


def _split_labels(raw_labels: object, labels: list[LabelOption]) -> tuple[list[str], list[str]]:
    """Partition suggested labels into existing labels and new suggestions.

    ``raw_labels`` may be a list or a comma-separated string; anything else
    yields no labels. Existing labels are returned with their original
    spelling; new labels are lowercased and hyphenated. Both lists are
    de-duplicated while preserving the model's order.
    """
    if isinstance(raw_labels, str):
        raw_labels = raw_labels.split(",")
    if not isinstance(raw_labels, list):
        return [], []

    exact = {lbl.name.lower(): lbl.name for lbl in labels}
    # Secondary lookup so e.g. "power tools" still matches an existing "power-tools".
    normalised = {_label_key(lbl.name): lbl.name for lbl in labels}

    existing: list[str] = []
    new: list[str] = []
    for item in raw_labels:
        text = _clean_text(item)
        if not text:
            continue
        key = _label_key(text)
        # Exact (case-insensitive) match wins over the normalised one.
        canonical = exact.get(text.lower(), normalised.get(key))
        if canonical is not None:
            if canonical not in existing:
                existing.append(canonical)
        elif key not in new:
            new.append(key)
    return existing, new


def parse_response(response_text: str, locations: list[LocationOption], labels: list[LabelOption]) -> dict:
    """Parse the model's reply into inventory suggestions.

    Malformed or missing model output falls back to defaults rather than
    raising.

    Args:
        response_text: Raw text returned by the model; expected to contain a
            JSON object with ``name``, ``description``, ``labels`` (or
            ``tags``) and ``location`` keys.
        locations: Locations the suggested location must be one of.
        labels: Existing labels used to separate reused labels from new ones.

    Returns:
        A dict with keys ``name`` (``"Unknown Item"`` when missing or empty),
        ``description``, ``existing_labels``, ``new_labels`` and ``location``
        (first available location when the suggestion doesn't match, ``""``
        when no locations were given).
    """
    data = _extract_json_object(response_text)

    raw_labels = data.get("labels")
    if raw_labels is None:
        # Fall back to a "tags" key, presumably because some replies use that name.
        raw_labels = data.get("tags")
    existing_labels, new_labels = _split_labels(raw_labels, labels)

    return {
        "name": _clean_text(data.get("name")) or "Unknown Item",
        "description": _clean_text(data.get("description")),
        "existing_labels": existing_labels,
        "new_labels": new_labels,
        "location": _match_location(data.get("location"), locations),
    }
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_image(request: AnalyzeRequest):
    """Send the uploaded image to Ollama and return inventory suggestions.

    The image may be raw base64 or a data URL. The prompt is built from the
    request's locations/labels, and the model's reply is parsed into an
    ``AnalyzeResponse`` (the unparsed reply is included as ``raw_response``).

    Raises:
        HTTPException: 400 if the image is empty or not valid base64;
            504 if the Ollama request times out; 502 if Ollama cannot be
            reached, answers with an error status, or returns a non-JSON body.
    """
    settings = get_settings()

    # Strip a data-URL prefix ("data:image/png;base64,<payload>") if present.
    image_data = request.image
    if "," in image_data:
        image_data = image_data.split(",", 1)[1]
    # Remove embedded whitespace (e.g. line-wrapped base64) so the strict check
    # below accepts it and Ollama receives a clean payload.
    image_data = "".join(image_data.split())

    # validate=True rejects characters outside the base64 alphabet instead of
    # silently discarding them; an empty payload is rejected as well.
    try:
        decoded = base64.b64decode(image_data, validate=True)
    except ValueError:  # binascii.Error is a ValueError subclass
        decoded = b""
    if not decoded:
        raise HTTPException(status_code=400, detail="Invalid base64 image data")

    prompt = build_prompt(request.locations, request.labels)

    # Tolerate a configured base URL with a trailing slash.
    base_url = str(settings.ollama_url).rstrip("/")
    ollama_url = f"{base_url}/api/generate"
    payload = {
        "model": settings.ollama_model,
        "prompt": prompt,
        "images": [image_data],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(ollama_url, json=payload)
            response.raise_for_status()
            result = response.json()
    except httpx.TimeoutException as e:
        # Must precede RequestError: TimeoutException is a RequestError subclass.
        raise HTTPException(status_code=504, detail="Ollama request timed out") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to Ollama: {e}") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e.response.text}") from e
    except ValueError as e:
        # response.json() raises a ValueError subclass when the body isn't JSON.
        raise HTTPException(status_code=502, detail="Ollama returned a non-JSON response") from e

    # With stream=False the generated text is read from the "response" key;
    # a missing key or unexpected body shape is treated as an empty reply.
    raw_response = result.get("response") if isinstance(result, dict) else None
    if not isinstance(raw_response, str):
        raw_response = ""
    parsed = parse_response(raw_response, request.locations, request.labels)

    return AnalyzeResponse(
        name=parsed["name"],
        description=parsed["description"],
        existing_labels=parsed["existing_labels"],
        new_labels=parsed["new_labels"],
        suggested_location=parsed["location"],
        raw_response=raw_response,
    )