homebox_ai_fronted/backend/routers/analyze.py

224 lines
7.5 KiB
Python

import base64
import json
import math
from typing import Optional

import httpx
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from config import get_settings
# Router holding this module's endpoints (POST /analyze below); the
# application must include it elsewhere for the route to be served.
router = APIRouter()
class LocationOption(BaseModel):
    # A candidate storage location sent by the client with an analyze request.
    # Only `name` is read in this module: it is listed in the prompt and used
    # to match the model's suggested location.
    id: str  # not read in this module; presumably the Homebox location id — confirm
    name: str
class LabelOption(BaseModel):
    # An existing label sent by the client with an analyze request.
    # Only `name` is read in this module: it is listed in the prompt and used
    # to split the model's labels into existing vs. new.
    id: str  # not read in this module; presumably the Homebox label id — confirm
    name: str
class AnalyzeRequest(BaseModel):
    # Request body for POST /analyze.
    image: str  # Base64 encoded image data (a "data:...;base64," prefix is stripped by the endpoint)
    locations: list[LocationOption] = []  # locations the model may choose from
    labels: list[LabelOption] = []  # existing labels the model should reuse when they fit
class AnalyzeResponse(BaseModel):
    # Response body for POST /analyze: the model's suggestions for one item.
    name: str
    description: str
    existing_labels: list[str]  # Labels that exist in Homebox
    new_labels: list[str]  # AI-suggested labels that don't exist yet
    suggested_location: str  # Location name from available locations ("" if none were supplied)
    # Additional fields
    manufacturer: str = ""
    model_number: str = ""
    serial_number: str = ""
    quantity: int = 1
    purchase_price: Optional[float] = None  # None when no usable price was extracted
    notes: str = ""
    raw_response: Optional[str] = None  # the unparsed text the model returned
def build_prompt(locations: list[LocationOption], labels: list[LabelOption]) -> str:
    """Return the vision prompt sent to the model for a single image.

    The prompt asks for one JSON object describing the item and lists the
    allowed location names (a lone "Storage" placeholder when *locations* is
    empty). When *labels* is non-empty it lists their names for reuse;
    otherwise it asks the model to invent 1-3 short hyphenated labels.
    """

    def as_bullets(names: list[str]) -> str:
        # One "- name" line per entry, newline-separated, no trailing newline.
        return "\n".join(f"- {entry}" for entry in names)

    allowed_locations = [option.name for option in locations or []] or ["Storage"]
    known_labels = [option.name for option in labels or []]

    sections = [
        """Analyze this image for a home inventory system. Identify the item and extract as much information as possible.
Respond with ONLY a JSON object containing these fields (use null for unknown values):
{
"name": "Descriptive item name (include brand if visible)",
"description": "Brief description - color, size, condition, features",
"manufacturer": "Brand/manufacturer name if visible",
"model": "Model number if visible",
"serial": "Serial number if visible",
"quantity": 1,
"price": null,
"labels": ["category1", "category2"],
"location": "storage location",
"notes": "Any other relevant details"
}
AVAILABLE LOCATIONS (pick one exactly as written):
""",
        as_bullets(allowed_locations),
    ]

    if not known_labels:
        sections.append(
            "\n\nNo labels exist yet. Suggest 1-3 short category labels (lowercase, hyphenated like 'electronics' or 'hand-tools')."
        )
    else:
        sections += [
            "\n\nEXISTING LABELS (use these exact names if they match):\n",
            as_bullets(known_labels),
            "\n\nFor labels: Pick from existing labels if they fit. If no existing label matches, create a new short label name (lowercase, hyphenated like 'power-tools' or 'kitchen-appliances').",
        ]

    sections.append("\n\nRespond with ONLY the JSON object, no other text.")
    return "".join(sections)
def _extract_json_object(text: str) -> dict:
    """Return the first JSON object embedded in *text*, or ``{}`` if none parses.

    Model replies often wrap the JSON in Markdown fences or add commentary
    before/after it, and that commentary may itself contain braces. Decoding
    at each "{" with ``raw_decode`` tolerates trailing text, unlike slicing
    between the first "{" and the last "}".
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            obj = None
        if isinstance(obj, dict):
            return obj
        start = text.find("{", start + 1)
    return {}


def _first_text(data: dict, *keys: str, default: str = "") -> str:
    """Return the first non-blank value among *keys* as a stripped string.

    JSON null and blank strings are skipped, so a model following the
    prompt's "use null for unknown values" rule never yields the text "None".
    """
    for key in keys:
        value = data.get(key)
        if value is None:
            continue
        text = str(value).strip()
        if text:
            return text
    return default


def _slugify(label: str) -> str:
    """Lowercase *label* and join its whitespace-separated words with hyphens."""
    return "-".join(label.lower().split())


def _parse_quantity(value) -> int:
    """Coerce *value* to an int of at least 1, falling back to 1."""
    try:
        quantity = int(value)
    except (ValueError, TypeError, OverflowError):  # OverflowError: int(float("inf"))
        return 1
    return max(quantity, 1)


def _parse_price(value) -> Optional[float]:
    """Coerce *value* to a finite, non-negative price, or return None.

    Strings may carry a leading/trailing currency symbol (e.g. "$19.99").
    Booleans are rejected even though ``float(True)`` would succeed.
    """
    if value is None or isinstance(value, bool):
        return None
    if isinstance(value, str):
        value = value.strip().strip("$€£¥").strip()
    try:
        price = float(value)
    except (ValueError, TypeError, OverflowError):
        return None
    # json.loads accepts NaN/Infinity; neither is a meaningful price.
    if not math.isfinite(price) or price < 0:
        return None
    return price


def parse_response(response_text: str, locations: list[LocationOption], labels: list[LabelOption]) -> dict:
    """Parse the model's reply into the fields used to build an AnalyzeResponse.

    Args:
        response_text: Raw text returned by the model; expected to contain a
            JSON object, possibly surrounded by other text.
        locations: Locations offered to the model; the suggested location is
            matched case-insensitively against their names.
        labels: Existing labels; a suggested label that matches one of these
            case-insensitively, or after hyphenating spaces, is reported under
            the existing label's spelling.

    Returns:
        A dict with keys name, description, existing_labels, new_labels,
        location, manufacturer, model_number, serial_number, quantity,
        purchase_price and notes. Missing or null values fall back to
        defaults: "Unknown Item" for name, "" for other text, 1 for quantity,
        None for price. An unrecognised location falls back to the first
        available location, or "" if none were supplied.
    """
    data = _extract_json_object(response_text or "")

    # Resolve the location case-insensitively against the offered options.
    location_by_lower = {loc.name.lower(): loc.name for loc in locations}
    suggested_location = _first_text(data, "location").lower()
    if suggested_location in location_by_lower:
        location = location_by_lower[suggested_location]
    elif locations:
        location = locations[0].name
    else:
        location = ""

    # Accept a few alternative key spellings; the first usable price wins.
    purchase_price = None
    for key in ("price", "purchase_price", "purchasePrice"):
        purchase_price = _parse_price(data.get(key))
        if purchase_price is not None:
            break

    # Index existing labels by lowercase name, then by hyphenated form, so
    # "Power Tools" is recognised whether the model says "power tools" or
    # "power-tools". Exact (lowercase) matches take precedence.
    existing_by_key: dict[str, str] = {}
    for lbl in labels:
        existing_by_key.setdefault(lbl.name.lower(), lbl.name)
    for lbl in labels:
        existing_by_key.setdefault(_slugify(lbl.name), lbl.name)

    raw_labels = data.get("labels") or data.get("tags") or []
    if isinstance(raw_labels, str):
        # Some models answer with a comma-separated string instead of a list.
        raw_labels = raw_labels.split(",")
    elif not isinstance(raw_labels, list):
        raw_labels = []

    existing_labels: list[str] = []
    new_labels: list[str] = []
    seen_existing: set[str] = set()
    seen_new: set[str] = set()
    for raw in raw_labels:
        if raw is None:
            continue
        text = str(raw).strip()
        if not text:
            continue
        slug = _slugify(text)
        match = existing_by_key.get(text.lower())
        if match is None:
            match = existing_by_key.get(slug)
        if match is not None:
            if match not in seen_existing:
                seen_existing.add(match)
                existing_labels.append(match)
        elif slug not in seen_new:
            seen_new.add(slug)
            new_labels.append(slug)

    return {
        "name": _first_text(data, "name", default="Unknown Item"),
        "description": _first_text(data, "description"),
        "existing_labels": existing_labels,
        "new_labels": new_labels,
        "location": location,
        "manufacturer": _first_text(data, "manufacturer", "brand"),
        "model_number": _first_text(data, "model", "model_number", "modelNumber"),
        "serial_number": _first_text(data, "serial", "serial_number", "serialNumber"),
        "quantity": _parse_quantity(data.get("quantity", 1)),
        "purchase_price": purchase_price,
        "notes": _first_text(data, "notes"),
    }
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_image(request: AnalyzeRequest):
    """Send an image to Ollama for analysis and return item suggestions.

    The image may be raw base64 or a data URL ("data:image/...;base64,...").
    The configured Ollama model receives a prompt listing the caller's
    locations and labels, and its reply is parsed into an AnalyzeResponse.

    Raises:
        HTTPException(400): the image is empty or not valid base64.
        HTTPException(504): Ollama did not answer within the timeout.
        HTTPException(502): Ollama was unreachable, returned an error status,
            or returned a body that is not a JSON object.
    """
    settings = get_settings()

    # Strip a data-URL prefix if present; the base64 alphabet has no comma.
    image_data = request.image
    if "," in image_data:
        image_data = image_data.split(",", 1)[1]
    # Drop embedded whitespace/line breaks that some encoders insert; strict
    # validation below would otherwise reject them.
    image_data = "".join(image_data.split())

    # validate=True rejects characters outside the base64 alphabet instead of
    # silently discarding them, so arbitrary text cannot pass as an image.
    if not image_data:
        raise HTTPException(status_code=400, detail="Invalid base64 image data")
    try:
        base64.b64decode(image_data, validate=True)
    except ValueError as e:  # binascii.Error subclasses ValueError
        raise HTTPException(status_code=400, detail="Invalid base64 image data") from e

    prompt = build_prompt(request.locations, request.labels)

    # rstrip avoids "//api/generate" when the configured URL ends with "/".
    ollama_url = f"{settings.ollama_url.rstrip('/')}/api/generate"
    payload = {
        "model": settings.ollama_model,
        "prompt": prompt,
        "images": [image_data],
        "stream": False,
    }
    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(ollama_url, json=payload)
            response.raise_for_status()
            result = response.json()
    except httpx.TimeoutException as e:  # must precede RequestError (its base class)
        raise HTTPException(status_code=504, detail="Ollama request timed out") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to Ollama: {e}") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e.response.text}") from e
    except ValueError as e:  # response body was not valid JSON
        raise HTTPException(status_code=502, detail="Ollama returned an invalid JSON response") from e

    if not isinstance(result, dict):
        raise HTTPException(status_code=502, detail="Ollama returned an unexpected response")
    # A null or missing "response" is treated as an empty reply.
    raw_response = str(result.get("response") or "")

    parsed = parse_response(raw_response, request.locations, request.labels)
    return AnalyzeResponse(
        name=parsed["name"],
        description=parsed["description"],
        existing_labels=parsed["existing_labels"],
        new_labels=parsed["new_labels"],
        suggested_location=parsed["location"],
        manufacturer=parsed["manufacturer"],
        model_number=parsed["model_number"],
        serial_number=parsed["serial_number"],
        quantity=parsed["quantity"],
        purchase_price=parsed["purchase_price"],
        notes=parsed["notes"],
        raw_response=raw_response,
    )