224 lines
7.5 KiB
Python
224 lines
7.5 KiB
Python
import base64
|
|
import json
|
|
from typing import Optional
|
|
|
|
import httpx
|
|
from fastapi import APIRouter, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
from config import get_settings
|
|
|
|
router = APIRouter()
|
|
|
|
|
|
class LocationOption(BaseModel):
    # One storage location the caller offers as a possible answer.
    # `name` is the text listed in the prompt and matched (case-insensitively)
    # against the model's reply; `id` is carried along but not read by the
    # functions in this module.
    id: str
    name: str
|
|
|
|
|
|
class LabelOption(BaseModel):
    # One label the caller reports as already existing.
    # `name` is the text listed in the prompt and used to decide whether a
    # label returned by the model is an existing one or a new suggestion;
    # `id` is carried along but not read by the functions in this module.
    id: str
    name: str
|
|
|
|
|
|
class AnalyzeRequest(BaseModel):
    # Request body of the /analyze endpoint.

    # Base64 encoded image data; a leading data-URL prefix ("...,") is
    # stripped by the endpoint before use.
    image: str
    # Locations the model may pick from (may be empty).
    locations: list[LocationOption] = []
    # Labels that already exist (may be empty).
    labels: list[LabelOption] = []
|
|
|
|
|
|
class AnalyzeResponse(BaseModel):
    # Response body of the /analyze endpoint: the fields extracted from the
    # model's answer plus the unparsed answer itself.

    # --- Core fields ---
    name: str
    description: str
    # Labels that exist in Homebox
    existing_labels: list[str]
    # AI-suggested labels that don't exist yet
    new_labels: list[str]
    # Location name from available locations
    suggested_location: str

    # --- Additional fields (defaults apply when the model gave no value) ---
    manufacturer: str = ""
    model_number: str = ""
    serial_number: str = ""
    quantity: int = 1
    purchase_price: Optional[float] = None
    notes: str = ""
    # Raw text returned by the model, before parsing
    raw_response: Optional[str] = None
|
|
|
|
|
|
def build_prompt(locations: list[LocationOption], labels: list[LabelOption]) -> str:
    """Compose the prompt sent to the vision model for one inventory photo.

    The prompt first describes the JSON object the model must return, then
    lists the location names it may choose from and, when any exist, the
    label names it should reuse.

    Args:
        locations: Locations offered to the model. When empty, the single
            placeholder name "Storage" is listed instead.
        labels: Labels that already exist. When empty, the model is asked to
            suggest 1-3 new category labels.

    Returns:
        The complete prompt text.
    """
    offered_locations = ["Storage"]
    if locations:
        offered_locations = [entry.name for entry in locations]
    known_labels = [entry.name for entry in labels] if labels else []

    # Collect the prompt piece by piece and join once at the end.
    sections = [
        """Analyze this image for a home inventory system. Identify the item and extract as much information as possible.

Respond with ONLY a JSON object containing these fields (use null for unknown values):

{
    "name": "Descriptive item name (include brand if visible)",
    "description": "Brief description - color, size, condition, features",
    "manufacturer": "Brand/manufacturer name if visible",
    "model": "Model number if visible",
    "serial": "Serial number if visible",
    "quantity": 1,
    "price": null,
    "labels": ["category1", "category2"],
    "location": "storage location",
    "notes": "Any other relevant details"
}

AVAILABLE LOCATIONS (pick one exactly as written):
""",
        "\n".join(f"- {name}" for name in offered_locations),
    ]

    if known_labels:
        sections.append("\n\nEXISTING LABELS (use these exact names if they match):\n")
        sections.append("\n".join(f"- {name}" for name in known_labels))
        sections.append(
            "\n\nFor labels: Pick from existing labels if they fit. If no existing label matches, create a new short label name (lowercase, hyphenated like 'power-tools' or 'kitchen-appliances')."
        )
    else:
        sections.append(
            "\n\nNo labels exist yet. Suggest 1-3 short category labels (lowercase, hyphenated like 'electronics' or 'hand-tools')."
        )

    sections.append("\n\nRespond with ONLY the JSON object, no other text.")
    return "".join(sections)
|
|
|
|
|
|
def _extract_json_object(response_text: str) -> dict:
    """Decode the outermost ``{...}`` span of *response_text*; return {} on failure."""
    if not response_text:
        return {}
    start = response_text.find("{")
    end = response_text.rfind("}")
    if start == -1 or end < start:
        return {}
    try:
        decoded = json.loads(response_text[start:end + 1])
    except json.JSONDecodeError:
        # Deliberate best effort: an unparseable answer yields default fields.
        return {}
    return decoded if isinstance(decoded, dict) else {}


def _first_text(data: dict, *keys: str, default: str = "") -> str:
    """Return the first usable value stored under *keys* as stripped text.

    Missing keys, JSON nulls and other falsy values are skipped, so a model
    that answers ``null`` (which the prompt explicitly allows) never ends up
    as the literal string "None".
    """
    for key in keys:
        value = data.get(key)
        if not value:
            continue
        text = str(value).strip()
        if text:
            return text
    return default


def _label_slug(text: str) -> str:
    """Normalize a label to lowercase words joined by single hyphens."""
    return "-".join(text.lower().split())


def parse_response(response_text: str, locations: list[LocationOption], labels: list[LabelOption]) -> dict:
    """Parse the model's answer and extract every inventory field.

    The first ``{`` through the last ``}`` of *response_text* is decoded as
    JSON; anything around it (chatter, code fences) is ignored. If no valid
    JSON object is found, every field falls back to its default.

    Args:
        response_text: Raw text returned by the model (``None``/empty allowed).
        locations: Locations the model was allowed to choose from.
        labels: Labels that already exist.

    Returns:
        A dict with the keys ``name``, ``description``, ``existing_labels``,
        ``new_labels``, ``location``, ``manufacturer``, ``model_number``,
        ``serial_number``, ``quantity``, ``purchase_price`` and ``notes``.
        ``location`` is the matching available location, else the first
        available one, else ``""``. ``existing_labels`` holds the canonical
        names of matched labels; ``new_labels`` holds hyphenated lowercase
        suggestions. Both lists are free of duplicates.
    """
    data = _extract_json_object(response_text)

    # Text fields (alternate key spellings are accepted for some of them).
    name = _first_text(data, "name", default="Unknown Item")
    description = _first_text(data, "description")
    manufacturer = _first_text(data, "manufacturer", "brand")
    model_number = _first_text(data, "model", "model_number", "modelNumber")
    serial_number = _first_text(data, "serial", "serial_number", "serialNumber")
    notes = _first_text(data, "notes")

    # Quantity: at least 1. OverflowError covers int(Infinity), which
    # json.loads happily produces from a bare ``Infinity`` token.
    try:
        quantity = max(int(data.get("quantity", 1)), 1)
    except (ValueError, TypeError, OverflowError):
        quantity = 1

    # Price: first truthy candidate, kept only if it is a finite number.
    purchase_price = None
    price_val = data.get("price") or data.get("purchase_price") or data.get("purchasePrice")
    if price_val is not None:
        try:
            candidate = float(price_val)
        except (ValueError, TypeError, OverflowError):
            candidate = None
        # The chained comparison is False for NaN and for +/-Infinity.
        if candidate is not None and float("-inf") < candidate < float("inf"):
            purchase_price = candidate

    # Validate location against available options (case/whitespace-insensitive).
    location_by_key = {loc.name.strip().lower(): loc.name for loc in locations}
    wanted_location = _first_text(data, "location").lower()
    if wanted_location in location_by_key:
        location = location_by_key[wanted_location]
    elif locations:
        location = locations[0].name
    else:
        location = ""

    # Separate existing labels from new suggestions. An exact (lowercased)
    # match wins; otherwise the hyphenated form is compared so "power tools"
    # and "power-tools" resolve to the same existing label instead of
    # producing a near-duplicate new one.
    existing_exact = {lbl.name.strip().lower(): lbl.name for lbl in labels}
    existing_by_slug = {_label_slug(lbl.name): lbl.name for lbl in labels}
    existing_labels: list[str] = []
    new_labels: list[str] = []
    seen: set[str] = set()

    raw_labels = data.get("labels")
    if raw_labels is None:
        raw_labels = data.get("tags")
    if isinstance(raw_labels, list):
        for raw in raw_labels:
            if raw is None:
                continue
            text = str(raw).strip()
            if not text:
                continue
            lowered = text.lower()
            slug = _label_slug(text)
            if lowered in existing_exact:
                resolved, bucket = existing_exact[lowered], existing_labels
            elif slug in existing_by_slug:
                resolved, bucket = existing_by_slug[slug], existing_labels
            else:
                resolved, bucket = slug, new_labels
            if resolved and resolved not in seen:
                seen.add(resolved)
                bucket.append(resolved)

    return {
        "name": name,
        "description": description,
        "existing_labels": existing_labels,
        "new_labels": new_labels,
        "location": location,
        "manufacturer": manufacturer,
        "model_number": model_number,
        "serial_number": serial_number,
        "quantity": quantity,
        "purchase_price": purchase_price,
        "notes": notes,
    }
|
|
|
|
|
|
@router.post("/analyze", response_model=AnalyzeResponse)
async def analyze_image(request: AnalyzeRequest):
    """Send image to Ollama for analysis and return suggestions.

    Args:
        request: The base64 image plus the locations and labels the model
            may choose from.

    Returns:
        An ``AnalyzeResponse`` built from the parsed model answer, with the
        unparsed answer in ``raw_response``.

    Raises:
        HTTPException: 400 if the image is empty or not valid base64;
            504 if the Ollama request times out; 502 if Ollama cannot be
            reached, answers with an error status, or returns a body that
            is not JSON.
    """
    settings = get_settings()

    # Strip data URL prefix if present (everything up to the first comma).
    image_data = request.image
    if "," in image_data:
        image_data = image_data.split(",", 1)[1]
    # Remove line wrapping/whitespace so strict validation below does not
    # reject an otherwise valid payload.
    image_data = "".join(image_data.split())

    # Validate base64. validate=True is required: without it b64decode
    # silently discards characters outside the alphabet, so most garbage
    # would pass. binascii.Error is a ValueError subclass.
    try:
        decoded_image = base64.b64decode(image_data, validate=True)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail="Invalid base64 image data") from exc
    if not decoded_image:
        raise HTTPException(status_code=400, detail="Invalid base64 image data")

    # Build prompt with available options
    prompt = build_prompt(request.locations, request.labels)

    # Call Ollama API
    ollama_url = f"{settings.ollama_url}/api/generate"
    payload = {
        "model": settings.ollama_model,
        "prompt": prompt,
        "images": [image_data],
        "stream": False,
    }

    try:
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(ollama_url, json=payload)
            response.raise_for_status()
    except httpx.TimeoutException as exc:
        # Must precede RequestError: timeouts are a subclass of it.
        raise HTTPException(status_code=504, detail="Ollama request timed out") from exc
    except httpx.RequestError as exc:
        raise HTTPException(status_code=502, detail=f"Failed to connect to Ollama: {str(exc)}") from exc
    except httpx.HTTPStatusError as exc:
        raise HTTPException(status_code=502, detail=f"Ollama error: {exc.response.text}") from exc

    # A non-JSON body would otherwise surface as an unhandled 500.
    try:
        result = response.json()
    except ValueError as exc:
        raise HTTPException(status_code=502, detail="Ollama returned a non-JSON response") from exc

    # Tolerate a missing/null/non-string "response" value or a non-object body.
    raw_response = result.get("response") if isinstance(result, dict) else None
    if not isinstance(raw_response, str):
        raw_response = ""

    parsed = parse_response(raw_response, request.locations, request.labels)

    return AnalyzeResponse(
        name=parsed["name"],
        description=parsed["description"],
        existing_labels=parsed["existing_labels"],
        new_labels=parsed["new_labels"],
        suggested_location=parsed["location"],
        manufacturer=parsed["manufacturer"],
        model_number=parsed["model_number"],
        serial_number=parsed["serial_number"],
        quantity=parsed["quantity"],
        purchase_price=parsed["purchase_price"],
        notes=parsed["notes"],
        raw_response=raw_response,
    )
|