174 lines
6.0 KiB
Python
174 lines
6.0 KiB
Python
import base64
|
|
import json
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from config import get_settings
|
|
|
|
# Router that carries the POST /analyze endpoint defined in this module;
# presumably included into the FastAPI app elsewhere — TODO confirm prefix.
router = APIRouter()
|
|
|
|
|
|
class LocationOption(BaseModel):
    """A storage location the client offers as a candidate for the item."""

    # Identifier supplied by the client (presumably the Homebox location ID — TODO confirm).
    id: str
    # Display name; this is what the prompt lists and what the model's choice is
    # matched against (case-insensitively) when parsing the reply.
    name: str
|
|
|
|
|
|
class LabelOption(BaseModel):
    """An existing label the model is encouraged to reuse."""

    # Identifier supplied by the client (presumably the Homebox label ID — TODO confirm).
    id: str
    # Display name; listed in the prompt and matched case-insensitively against
    # the labels the model returns.
    name: str
|
|
|
|
|
|
class AnalyzeRequest(BaseModel):
    """Request body for POST /analyze."""

    # analyze_image drops everything up to the first comma, so a data URL
    # ("data:image/...;base64,<payload>") is accepted as well as bare base64.
    image: str  # Base64 encoded image data
    # Candidate locations; when empty, the prompt offers a single "Storage" option.
    locations: list[LocationOption] = []
    # Existing labels; when empty, the prompt asks the model to invent new ones.
    labels: list[LabelOption] = []
|
|
|
|
|
|
class AnalyzeResponse(BaseModel):
    """Response body for POST /analyze: the model's suggestions after normalization."""

    # Short item name; "Unknown Item" when the model's reply has no usable name.
    name: str
    # Free-text description from the model (may be empty).
    description: str
    existing_labels: list[str]  # Labels that exist in Homebox
    new_labels: list[str]  # AI-suggested labels that don't exist yet
    # Empty string when the request supplied no locations.
    suggested_location: str  # Location name from available locations
    # Unparsed text the model returned, passed through unchanged.
    raw_response: Optional[str] = None
|
|
|
|
|
|
def build_prompt(locations: list[LocationOption], labels: list[LabelOption]) -> str:
    """Compose the instruction text sent to the vision model alongside the image.

    Every supplied location name is listed verbatim; when no locations are
    given, a single "Storage" entry is offered instead. Existing label names
    are listed so the model can reuse them exactly; with no labels, the model
    is asked to propose a few short new ones.

    Args:
        locations: Locations the model must choose from.
        labels: Labels the model should prefer when they fit.

    Returns:
        The complete prompt string.
    """
    location_choices = [option.name for option in (locations or [])] or ["Storage"]
    known_labels = [option.name for option in (labels or [])]

    parts = [
        "Look at this image and identify the item for a home inventory system.\n"
        "\n"
        "Respond with ONLY a JSON object (no other text) with these fields:\n"
        "\n"
        '- name: Short descriptive name for the item (e.g. "Red Handled Scissors", "DeWalt Cordless Drill")\n'
        "- description: Brief description - color, brand, condition, notable features\n"
        "- labels: Array of 1-4 relevant category labels for this item\n"
        "- location: Best storage location from the AVAILABLE LOCATIONS below (use exact name)\n"
        "\n"
        "AVAILABLE LOCATIONS (you MUST pick one exactly as written):\n",
        "\n".join(f"- {choice}" for choice in location_choices),
    ]

    if not known_labels:
        parts.append(
            "\n\nNo labels exist yet. Suggest 1-3 short label names (lowercase, single word or hyphenated like 'electronics' or 'hand-tools')."
        )
    else:
        parts.extend(
            (
                "\n\nEXISTING LABELS (prefer these if they fit, use exact names):\n",
                "\n".join(f"- {choice}" for choice in known_labels),
                "\n\nFor labels: Use existing labels above if they match. If the item needs a label that doesn't exist, suggest a new short label name (lowercase, single word or hyphenated like 'power-tools').",
            )
        )

    parts.append("\n\nRespond with ONLY the JSON object.")
    return "".join(parts)
|
|
|
|
|
|
def _extract_json_object(text: str) -> dict:
    """Return the first JSON object embedded in ``text``, or ``{}`` if none parses.

    Models often wrap the object in prose or Markdown fences and may append
    commentary that itself contains braces, so slicing from the first ``{`` to
    the last ``}`` is fragile. Instead each ``{`` is tried in turn with
    ``JSONDecoder.raw_decode``, which stops at the end of the first complete
    value and ignores whatever follows.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
            continue
        # A value that begins with "{" always decodes to a dict.
        return obj
    return {}


def _clean_field(value: object, default: str) -> str:
    """Coerce a model-supplied scalar to a stripped string; ``default`` if missing, null or blank."""
    if value is None:
        return default
    text = str(value).strip()
    return text or default


def _normalize_label(label: str) -> str:
    """Lowercase ``label`` and hyphenate its words, e.g. ``"Power  Tools"`` -> ``"power-tools"``."""
    return "-".join(label.lower().split())


def parse_response(response_text: str, locations: list[LocationOption], labels: list[LabelOption]) -> dict:
    """Parse the model's reply and reconcile it with the available locations and labels.

    Args:
        response_text: Raw text returned by the model. It is expected to contain
            a JSON object with ``name``, ``description``, ``labels`` (or
            ``tags``) and ``location`` keys, possibly surrounded by other text.
        locations: Locations the model was asked to choose from.
        labels: Labels that already exist.

    Returns:
        A dict with keys:

        - ``name``: the suggested name, or "Unknown Item".
        - ``description``: the suggested description, possibly empty.
        - ``existing_labels``: exact names of matched existing labels, deduplicated.
        - ``new_labels``: new suggestions, lowercased and hyphenated, deduplicated.
        - ``location``: an exact available location name; the first available
          location if the suggestion does not match; ``""`` when no locations
          were given.

        An unparseable reply yields these defaults rather than raising.
    """
    data = _extract_json_object(response_text or "")

    # Null/blank values fall back to defaults instead of becoming "None".
    name = _clean_field(data.get("name"), "Unknown Item")
    description = _clean_field(data.get("description"), "")

    # Validate the location against the available options, case-insensitively.
    location_by_lower = {loc.name.lower(): loc.name for loc in locations}
    suggested_loc = _clean_field(data.get("location"), "").lower()
    if suggested_loc in location_by_lower:
        location = location_by_lower[suggested_loc]
    elif locations:
        location = locations[0].name
    else:
        location = ""

    # Existing labels are matched first by case-insensitive name, then by their
    # normalized form, so "power tools" still matches an existing "Power-Tools".
    label_by_lower = {lbl.name.lower(): lbl.name for lbl in labels}
    label_by_normalized = {_normalize_label(lbl.name): lbl.name for lbl in labels}

    raw_labels = data.get("labels", data.get("tags", []))
    if isinstance(raw_labels, str):
        # Some models return "tools, garden" instead of a JSON array.
        raw_labels = raw_labels.split(",")

    existing_labels: list[str] = []
    new_labels: list[str] = []
    seen: set[str] = set()  # Final label strings already emitted (dedupes both lists).
    if isinstance(raw_labels, list):
        for raw in raw_labels:
            if raw is None:
                continue
            text = str(raw).strip()
            if not text:
                continue
            lowered = text.lower()
            normalized = _normalize_label(text)
            if lowered in label_by_lower:
                label, bucket = label_by_lower[lowered], existing_labels
            elif normalized in label_by_normalized:
                label, bucket = label_by_normalized[normalized], existing_labels
            else:
                label, bucket = normalized, new_labels
            if label not in seen:
                seen.add(label)
                bucket.append(label)

    return {
        "name": name,
        "description": description,
        "existing_labels": existing_labels,
        "new_labels": new_labels,
        "location": location,
    }
|
|
|
|
|
|
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_image(request: AnalyzeRequest):
    """Send the uploaded image to Ollama for analysis and return structured suggestions.

    The image is sent to ``{ollama_url}/api/generate`` together with a prompt
    built from the request's locations and labels. The model's text reply is
    then normalized by ``parse_response``.

    Raises:
        HTTPException: 400 if the image is empty or not valid base64.
        HTTPException: 504 if the Ollama request times out.
        HTTPException: 502 if Ollama is unreachable, returns an error status,
            or returns a body that is not a JSON object.
    """
    settings = get_settings()

    # Strip a data-URL prefix ("data:image/png;base64,") if present; the base64
    # alphabet never contains a comma, so everything after the first one is payload.
    image_data = request.image
    if "," in image_data:
        image_data = image_data.split(",", 1)[1]
    # Drop embedded whitespace/newlines (e.g. line-wrapped base64) so the strict
    # validation below does not reject otherwise well-formed payloads.
    image_data = "".join(image_data.split())

    if not image_data:
        raise HTTPException(status_code=400, detail="Image data is empty")
    try:
        # validate=True rejects non-alphabet characters instead of silently
        # discarding them, so garbage input is not forwarded to Ollama.
        base64.b64decode(image_data, validate=True)
    except ValueError as e:  # binascii.Error is a subclass of ValueError
        raise HTTPException(status_code=400, detail="Invalid base64 image data") from e

    # Build prompt with available options
    prompt = build_prompt(request.locations, request.labels)

    # Tolerate a trailing slash in the configured base URL (str() in case the
    # setting is a URL object rather than a plain string).
    ollama_url = f"{str(settings.ollama_url).rstrip('/')}/api/generate"
    payload = {
        "model": settings.ollama_model,
        "prompt": prompt,
        "images": [image_data],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(ollama_url, json=payload)
            response.raise_for_status()
            result = response.json()
    # TimeoutException subclasses RequestError, so it must be caught first.
    except httpx.TimeoutException as e:
        raise HTTPException(status_code=504, detail="Ollama request timed out") from e
    except httpx.RequestError as e:
        raise HTTPException(status_code=502, detail=f"Failed to connect to Ollama: {e}") from e
    except httpx.HTTPStatusError as e:
        raise HTTPException(status_code=502, detail=f"Ollama error: {e.response.text}") from e
    except ValueError as e:
        # response.json() raises a ValueError subclass on a non-JSON body.
        raise HTTPException(status_code=502, detail="Ollama returned invalid JSON") from e

    if not isinstance(result, dict):
        raise HTTPException(status_code=502, detail="Ollama returned an unexpected response")

    # Coerce to str so parsing never fails on a null or non-string field.
    raw_response = str(result.get("response") or "")
    parsed = parse_response(raw_response, request.locations, request.labels)

    return AnalyzeResponse(
        name=parsed["name"],
        description=parsed["description"],
        existing_labels=parsed["existing_labels"],
        new_labels=parsed["new_labels"],
        suggested_location=parsed["location"],
        raw_response=raw_response,
    )
|