57 changes: 44 additions & 13 deletions download_models.py
@@ -6,41 +6,72 @@

 import os
 import sys
+import shutil
 from pathlib import Path

 def download_phi2_model():
     """Download Phi-2 model to local models directory."""
     try:
-        from transformers import AutoTokenizer, AutoModelForCausalLM
+        print("🔍 Checking dependencies...")
+        try:
+            import torch
+            import transformers
+            from transformers import AutoTokenizer, AutoModelForCausalLM
+            print(f"✅ Transformers version: {transformers.__version__}")
+            print(f"✅ PyTorch version: {torch.__version__}")
+        except ImportError as e:
+            print(f"❌ Missing dependencies: {e}")
+            print("Please run: pip install transformers torch sentencepiece")
+            sys.exit(1)

         model_id = "microsoft/phi-2"
-        models_dir = Path(__file__).parent / "models" / "phi-2"
+        # Determine project root
+        project_root = Path(__file__).parent
+        models_dir = project_root / "models" / "phi-2"

+        print(f"📂 Target directory: {models_dir}")
         models_dir.mkdir(parents=True, exist_ok=True)

-        print(f"📥 Downloading {model_id} to {models_dir}")
+        # Check disk space (need ~6GB)
+        total, used, free = shutil.disk_usage(models_dir.parent)
+        free_gb = free / (1024**3)
+        print(f"💾 Free disk space: {free_gb:.2f} GB")
+        if free_gb < 7:
+            print("⚠️ WARNING: You have less than 7GB of free space. Download might fail.")
+            confirm = input("Continue anyway? (y/n): ")
+            if confirm.lower() != 'y':
+                sys.exit(0)
+
+        print(f"📥 Starting download of {model_id}...")
+        print("   This may take 5-10 minutes depending on your internet connection.")

         # Download tokenizer
         print("Downloading tokenizer...")
         tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
         tokenizer.save_pretrained(models_dir)
+        print("✅ Tokenizer saved.")

         # Download model
-        print("Downloading model (this may take a while)...")
+        print("⏳ Downloading model (approx 5.6GB)...")
+        # local_files_only=False (the default) ensures we check the remote while still using the cache
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             trust_remote_code=True,
-            torch_dtype="auto"
+            torch_dtype="auto",
+            device_map="auto" if torch.cuda.is_available() else "cpu"
         )
         model.save_pretrained(models_dir)

-        print(f"✅ Model downloaded successfully to {models_dir}")
-        print(f"📊 Model size: ~5.6GB")
+        print(f"✅ Model downloaded successfully to: {models_dir}")
+        print("🎉 You can now run the backend with Local LLM enabled!")

-    except ImportError:
-        print("❌ transformers not installed. Run: pip install transformers torch")
-        sys.exit(1)
     except Exception as e:
-        print(f"❌ Download failed: {e}")
+        print("\n❌ CRITICAL ERROR during download:")
+        print(f"{str(e)}")
+        print("\nTroubleshooting:")
+        print("1. Check your internet connection")
+        print("2. Ensure you have enough disk space (~6GB)")
+        print("3. Try running: pip install --upgrade transformers torch accelerate")
         sys.exit(1)

if __name__ == "__main__":
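Note: the script saves the files but never verifies that the copy is loadable offline. A small smoke test along the following lines could catch truncated downloads early — a sketch, not part of this diff; it assumes the same models/phi-2 layout and uses transformers' standard local_files_only flag:

# verify_phi2.py — hypothetical smoke test, not part of this PR
from pathlib import Path
from transformers import AutoConfig, AutoTokenizer

def verify_phi2_download() -> bool:
    """Load tokenizer and config strictly from disk (no network); far cheaper than loading weights."""
    models_dir = Path(__file__).parent / "models" / "phi-2"
    try:
        AutoTokenizer.from_pretrained(models_dir, local_files_only=True, trust_remote_code=True)
        AutoConfig.from_pretrained(models_dir, local_files_only=True, trust_remote_code=True)
    except OSError as e:
        print(f"❌ Local load failed: {e}")
        return False
    # Config/tokenizer alone don't prove the ~5.6GB of weights arrived; check for shards too.
    return any(models_dir.glob("*.safetensors")) or any(models_dir.glob("*.bin"))

if __name__ == "__main__":
    print("✅ Model files look complete" if verify_phi2_download() else "❌ Re-run download_models.py")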
69 changes: 55 additions & 14 deletions form-flow-backend/services/ai/conversation_agent.py
@@ -327,6 +327,10 @@ async def process_user_input(
         # Get remaining fields and current batch
         remaining_fields = session.get_remaining_fields()

+        # DEBUG: Log session state
+        logger.info(f"🔍 Session Debug - Extracted: {list(session.extracted_fields.keys())}")
+        logger.info(f"🔍 Session Debug - Remaining ({len(remaining_fields)}): {[f.get('name') for f in remaining_fields]}")

         # Determine max fields based on client type and feature flag
         # Web frontend gets grouped questions when SMART_GROUPING_ENABLED is True
         is_web_client = getattr(session, 'client_type', 'extension') == 'web'
@@ -538,17 +542,30 @@ async def process_user_input(
logger.info(f"Processing input: '{user_input[:100]}...'")
logger.info(f"Current batch fields: {[f.get('name') for f in current_batch]}")

# Get full schema for extraction and refinement
all_fields = remaining_fields
if hasattr(session, 'form_schema') and session.form_schema:
raw_schema = session.form_schema
flattened = []
for item in raw_schema:
if isinstance(item, dict) and 'fields' in item:
flattened.extend(item['fields'])
else:
flattened.append(item)
all_fields = flattened

extracted, confidence_scores, message = await self._extract_values(
session, user_input, current_batch, remaining_fields, is_voice
)

# 8. Refine and store extracted values using atomic FormDataManager
refined = self.value_refiner.refine_values(extracted, remaining_fields)
# Use simple cleaning without heavy NLP for speed, or pass True if AI refinement desired
refined = self.value_refiner.refine_values(extracted, all_fields)

for field_name, value in refined.items():
# Get field info for pattern detection
# Get field info from ALL fields to ensure type logic applies
field_info = next(
(f for f in remaining_fields if f.get('name') == field_name),
(f for f in all_fields if f.get('name') == field_name),
{}
)

@@ -737,11 +754,26 @@ async def _extract_values(

# Prepare field list for extraction - use ALL remaining fields, not just current_batch
# This allows users to provide multiple fields at once (e.g., "my name is X and email is Y")
fields_to_extract = [
field.get('label', field.get('name'))
for field in remaining_fields # Use remaining_fields, not current_batch
if field.get('name') not in extracted
]
# Prepare field list for extraction - pass FULL objects for context-aware extraction
# This allows the LLM to see options, types, and labels
# CHANGE: Use session.form_schema (ALL fields) instead of remaining_fields
# This enables "Smart Update/Correction" where the user can update any field value at any time
if hasattr(session, 'form_schema') and session.form_schema:
raw_schema = session.form_schema
# Flatten schema: If it's a list provides forms with 'fields', extract them
all_fields = []
for item in raw_schema:
if isinstance(item, dict) and 'fields' in item:
all_fields.extend(item['fields'])
else:
all_fields.append(item)
fields_to_extract = all_fields
else:
fields_to_extract = remaining_fields

# If form_schema is somehow empty (shouldn't be), fall back
if not fields_to_extract:
fields_to_extract = remaining_fields

logger.info(f"Extracting from {len(fields_to_extract)} fields: {fields_to_extract[:5]}...")

@@ -756,32 +788,41 @@
logger.info(f"Local LLM raw extraction: {new_extracted}")

# Update main extracted dict
for label, value in new_extracted.items():
# Find matching field name from label - search in remaining_fields
# Update main extracted dict
for key, value in new_extracted.items():
# Find matching field name from key - search in fields_to_extract (which effectively covers everything)
# This ensures we match against any field in the schema, not just remaining ones
field_match = next(
(f for f in remaining_fields if f.get('label', f.get('name')) == label),
(f for f in fields_to_extract if f.get('name') == key or f.get('label') == key),
None
)
if field_match:
field_name = field_match.get('name')
extracted[field_name] = value
confidence[field_name] = new_confidence.get(label, 0.8)
confidence[field_name] = new_confidence.get(key, 0.8)

if extracted:
logger.info(f"Local LLM extracted: {list(extracted.keys())}")

# Generate confirmation message
extracted_labels = []
for field_name in extracted.keys():
field_info = next((f for f in current_batch if f.get('name') == field_name), {})
# Look up label in fields_to_extract to ensure we get labels for ANY field (even past ones)
field_info = next((f for f in fields_to_extract if f.get('name') == field_name), {})
extracted_labels.append(field_info.get('label', field_name))

if len(extracted_labels) == 1:
message = f"Got your {extracted_labels[0]}!"
else:
message = f"Got your {', '.join(extracted_labels)}!"
# Limit to first 3 to avoid super long messages
if len(extracted_labels) > 3:
message = f"Got your {', '.join(extracted_labels[:3])} and others!"
else:
message = f"Got your {', '.join(extracted_labels)}!"

# Add next question if more fields remain
# We accept that 'extracted' might contain fields NOT in 'remaining_fields' (updates)
# So we just filter remaining_fields by what is now extracted
remaining_after = [f for f in remaining_fields if f.get('name') not in extracted]
if remaining_after:
next_batches = self.clusterer.create_batches(remaining_after)
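A side note on the lookups this file's hunks add: every extracted key triggers a linear next(...) scan over fields_to_extract, and the label loop scans again per field. On large forms, building a name/label index once per call would be cheaper — a sketch under the field-dict shape assumed throughout this diff, not code from the repo:

from typing import Any, Dict, List

def build_field_index(fields: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Map each field's 'name' and 'label' to the field dict for O(1) lookups."""
    index: Dict[str, Dict[str, Any]] = {}
    for field in fields:
        for key in (field.get('name'), field.get('label')):
            if key and key not in index:  # first match wins, mirroring next(...)
                index[key] = field
    return index

Inside _extract_values this would replace the generator scans:
field_index = build_field_index(fields_to_extract)
field_match = field_index.get(key)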
4 changes: 2 additions & 2 deletions form-flow-backend/services/ai/extraction/value_refiner.py
@@ -161,8 +161,8 @@ def _refine_against_options(self, value: str, options: List[str]) -> Optional[str]:

         for opt in options:
             if isinstance(opt, dict):
-                # Prefer label for matching user speech
-                opt_str = opt.get('label') or opt.get('name') or opt.get('value') or str(opt)
+                # Prefer label/text for matching user speech
+                opt_str = opt.get('label') or opt.get('text') or opt.get('name') or opt.get('value') or str(opt)
             else:
                 opt_str = str(opt)

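The matching logic that consumes opt_str is collapsed in this view. For context, a typical stdlib-only shape for that comparison step might look like the following — an illustration under assumptions, not the repo's actual implementation:

import difflib
from typing import List, Optional

def match_option(value: str, option_strings: List[str]) -> Optional[str]:
    """Return the option closest to a (possibly misheard) user value, or None."""
    lowered = {s.lower(): s for s in option_strings}
    if value.lower() in lowered:  # exact match first
        return lowered[value.lower()]
    # Fuzzy fallback tolerates speech-recognition noise like "femail"
    close = difflib.get_close_matches(value.lower(), list(lowered), n=1, cutoff=0.75)
    return lowered[close[0]] if close else None

print(match_option("femail", ["Male", "Female", "Other"]))  # -> Female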