Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/python-app.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,12 @@ jobs:
nohup uvicorn app.app:app &
sleep 10
curl -I http://127.0.0.1:8000

# Test batch prediction endpoint
echo "ra,dec,redshift,psfMag_r,u,g,r,i,z" > dummy_test.csv
echo "0.1,0.2,0.3,1.0,2.0,3.0,4.0,5.0,6.0" >> dummy_test.csv
curl -X POST http://127.0.0.1:8000/predict/file -H "accept: application/json" -F "payload=@dummy_test.csv;type=text/csv"
rm dummy_test.csv

pkill -f uvicorn

50 changes: 50 additions & 0 deletions Datasets/sample_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
sample_generator.py

This script is used to generate samples directly from the main dataset.
These samples are used to test the `/predict/file` route.
To run it, enter this in your command line:
```
python -m Datasets.sample_generator
```
"""

import pandas as pd
from pathlib import Path

# Both locations are relative, so they resolve against the directory the
# script is launched from.
_DATA_DIR = Path("Datasets")
DATASET_PATH = _DATA_DIR / "SDSS_DR18.csv"  # main dataset read by this script
OUTPUT_PATH = _DATA_DIR / "samples.csv"  # where the generated sample is written

def preprocess(df: pd.DataFrame) -> pd.DataFrame:
    """Reduce the raw dataset to the feature columns plus an encoded label.

    Identifier/bookkeeping columns are discarded when present, the textual
    ``class`` column is mapped to integer codes (GALAXY=0, STAR=1, QSO=2),
    and the result is narrowed to a fixed column order.

    Args:
        df: Raw dataset; must contain ``class`` and every feature column.

    Returns:
        A new DataFrame; the caller's frame is left untouched.
    """
    id_columns = ["objid", "specobjid", "run", "rerun", "camcol",
                  "field", "plate", "mjd", "fiberid"]
    label_codes = {"GALAXY": 0, "STAR": 1, "QSO": 2}
    feature_order = ["ra", "dec", "redshift", "psfMag_r",
                     "u", "g", "r", "i", "z", "class"]

    # errors="ignore" skips identifier columns that are absent from the input.
    trimmed = df.drop(columns=id_columns, errors="ignore")
    trimmed["class"] = trimmed["class"].map(label_codes)

    return trimmed[feature_order].copy()


def stratified_sample(df: pd.DataFrame, total_samples: int, class_col: str = "class", random_state: int = 42) -> pd.DataFrame:
    """Draw a sample whose class proportions mirror those of ``df``.

    Each class receives ``round(share * total_samples)`` rows; any rounding
    surplus or deficit is absorbed by the most frequent class so the quotas
    add up to ``total_samples``. A class never contributes more rows than it
    has, so the result can be smaller than requested.

    Args:
        df: Source data containing ``class_col``.
        total_samples: Desired number of rows in the sample.
        class_col: Column holding the class label to stratify on.
        random_state: Seed passed to every per-class ``sample`` call.

    Returns:
        The sampled rows (original index preserved, ``class_col`` retained),
        grouped in sorted class order. An empty frame with the same columns
        is returned when ``df`` has no labelled rows.
    """
    class_share = df[class_col].value_counts(normalize=True)
    if class_share.empty:
        # Nothing to stratify on; avoid indexing into an empty quota table.
        return df.iloc[0:0].copy()

    class_n = (class_share * total_samples).round().astype(int)
    # value_counts sorts descending, so position 0 is the largest class.
    class_n.iloc[0] += total_samples - class_n.sum()
    quota = class_n.to_dict()

    # Iterating the groups explicitly (instead of groupby.apply) keeps the
    # grouping column in the output on every pandas version.
    parts = [
        group.sample(n=min(quota[label], len(group)), random_state=random_state)
        for label, group in df.groupby(class_col)
    ]
    return pd.concat(parts)


# Load the main dataset and reduce it to the feature columns + encoded label.
df_raw = pd.read_csv(DATASET_PATH)
df_processed = preprocess(df_raw)

# Draw 100 class-balanced rows, then strip the label so the file matches the
# column layout that the `/predict/file` route accepts.
sample = stratified_sample(df_processed, total_samples=100).drop(columns=["class"])

sample.to_csv(OUTPUT_PATH, index=False)
95 changes: 83 additions & 12 deletions app/app.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
from fastapi import FastAPI, Depends, Request
from pydantic import BaseModel
from fastapi import FastAPI, Depends, Request, UploadFile
from fastapi.responses import JSONResponse, FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.exceptions import RequestValidationError
from fastapi.exceptions import RequestValidationError, HTTPException
from typing import Tuple,List
from sklearn.pipeline import Pipeline
from models.fit import main
from .schema.validation import UserInput
from pathlib import Path
from contextlib import asynccontextmanager
import joblib
import numpy as np
import pandas as pd
import joblib
import os

# Helper for loading and self-healing the artifacts
def load_or_create_models() -> Tuple[Pipeline,np.ndarray]:
model_path = Path("models","estimator.pkl")
columns_path = Path("models","column_names.pkl")
Expand All @@ -34,6 +37,36 @@ def load_or_create_models() -> Tuple[Pipeline,np.ndarray]:
return pipe,column_names
except Exception as e:
raise RuntimeError(f"Artifacts could not be loaded: {e}")

# Helper for performing feature engineering
def preprocess_data(value:BaseModel) -> dict:
    """Turn a validated request model into the feature dict used for prediction.

    The model is dumped to a plain dict, the five magnitude fields
    (``u``, ``g``, ``r``, ``i``, ``z``) are removed, and four colour indices
    built from adjacent bands are added via ``safe_sub``
    (``u_g_color``, ``g_r_color``, ``r_i_color``, ``i_z_color``).

    Args:
        value: Request model exposing ``model_dump``.

    Returns:
        Dict of the remaining fields followed by the four colour indices.
    """
    raw = value.model_dump(mode="json")
    bands = ("u", "g", "r", "i", "z")

    features = {name: item for name, item in raw.items() if name not in bands}
    # Pair each band with its neighbour: (u, g), (g, r), (r, i), (i, z).
    for first, second in zip(bands, bands[1:]):
        features[f"{first}_{second}_color"] = safe_sub(first, second, raw)

    return features

# Helper for validating user-provided csv files
def upload_validator(df:pd.DataFrame,col_names:List[str]) -> pd.DataFrame:
    """Check an uploaded table's header and coerce every value to float.

    Args:
        df: Frame parsed from the uploaded csv file.
        col_names: Expected column names, in the expected order.

    Returns:
        A copy of ``df`` with every column cast to ``float``.

    Raises:
        HTTPException: 422 when the columns (or their order) differ from
            ``col_names``, or when any value cannot be converted to float.
    """
    # list(...) lets callers pass any sequence (e.g. a tuple) without the
    # comparison failing purely on container type.
    if df.columns.tolist() != list(col_names):
        raise HTTPException(
            status_code=422,
            detail="Uploaded csv file does not match the expected "
                   "columns or their order",
        )

    try:
        return df.astype(float)
    except (ValueError, TypeError) as err:
        # Chain the cause so server logs show which value failed to convert.
        raise HTTPException(
            status_code=422,
            detail="All values must be numeric (float-compatible)",
        ) from err

@asynccontextmanager
async def lifespan(app:FastAPI):
Expand Down Expand Up @@ -91,15 +124,7 @@ def home():
def prediction_ops(value:UserInput, dep:Tuple[Pipeline,np.ndarray] = Depends(get_model)):
pipe, column_names = dep
column_names:List[str] = column_names.tolist()

# Preprocessing
value:dict = value.model_dump(mode="json")
kick = ["u","g","r","i","z"]
final_value = {key:val for key,val in value.items() if key not in kick}
final_value["u_g_color"] = safe_sub("u","g",value)
final_value["g_r_color"] = safe_sub("g","r",value)
final_value["r_i_color"] = safe_sub("r","i",value)
final_value["i_z_color"] = safe_sub("i","z",value)
final_value:dict = preprocess_data(value)

# Order Check and running prediction
final_res = []
Expand Down Expand Up @@ -127,3 +152,49 @@ def prediction_ops(value:UserInput, dep:Tuple[Pipeline,np.ndarray] = Depends(get
status_code=201, content=msg
)

@app.post("/predict/file")
async def prediction_via_file_ops(payload:UploadFile, dep: Tuple[Pipeline,np.ndarray] = Depends(get_model)):
    """Classify every row of an uploaded csv file.

    The upload must be a ``.csv`` file whose header is exactly
    ``ra, dec, redshift, psfMag_r, u, g, r, i, z``. The magnitude columns are
    replaced by four colour indices, the columns are reordered to the order
    stored alongside the model, and the pipeline is run on the whole frame.

    Args:
        payload: Uploaded file from the multipart form field ``payload``.
        dep: ``(pipeline, column_names)`` pair supplied by ``get_model``.

    Returns:
        JSONResponse (status 201) with one predicted label and one list of
        rounded class probabilities per input row.

    Raises:
        HTTPException: 422 for a non-csv upload, an unparsable file, a header
            mismatch, non-numeric values, or a file without data rows.
    """
    pipe, column_names = dep
    expected_upload_cols = ['ra', 'dec', 'redshift', 'psfMag_r', 'u', 'g', 'r', 'i', 'z']

    accepted_exts = [".csv"]
    # The filename is optional in a multipart upload; Path(None) would raise
    # and surface as a 500 instead of a validation error.
    extension = Path(payload.filename or "").suffix
    if extension.lower() not in accepted_exts:
        raise HTTPException(
            status_code=422,
            detail=f"Uploaded data must be in '.csv' format, got {extension or 'no extension'} instead"
        )

    try:
        df = pd.read_csv(payload.file)
    except Exception as e:
        # Boundary with untrusted input: any parser/decoder failure is
        # reported to the client as a validation error.
        raise HTTPException(status_code=422, detail=f"Failed to parse CSV file: {e}") from e
    df = upload_validator(df, expected_upload_cols)

    # A header-only file would otherwise reach the estimator with zero rows.
    if df.empty:
        raise HTTPException(
            status_code=422,
            detail="Uploaded csv file contains no data rows"
        )

    # Feature Engineering (Vectorized)
    df['u_g_color'] = df['u'] - df['g']
    df['g_r_color'] = df['g'] - df['r']
    df['r_i_color'] = df['r'] - df['i']
    df['i_z_color'] = df['i'] - df['z']
    df = df.drop(columns=['u', 'g', 'r', 'i', 'z'])

    # Reorder columns to match the pipeline's expected order (excluding 'class')
    model_features = [col for col in column_names.tolist() if col != "class"]
    df = df[model_features]

    raw_labels = pipe.predict(df).tolist()
    raw_proba = pipe.predict_proba(df).tolist()

    # Postprocessing
    label_map = {0: "GALAXY", 1: "STAR", 2: "QSO"}
    pred_label = [label_map.get(code) for code in raw_labels]
    pred_proba = [[round(p, 3) for p in row] for row in raw_proba]

    msg = {
        "message": "batch prediction successful",
        "prediction": pred_label,
        "probabilities": pred_proba
    }
    return JSONResponse(
        status_code=201, content=msg
    )
185 changes: 185 additions & 0 deletions app/static/script.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ document.addEventListener("DOMContentLoaded", function() {
// Initialize components
initTabNavigation();
initPredictForm();
initBatchPredictForm();
initKeyboardShortcuts();
});

Expand Down Expand Up @@ -65,6 +66,14 @@ function initPredictForm() {
const formData = new FormData(form);
const dataObj = Object.fromEntries(formData.entries());

// Manual validation check to ensure no empty values
const requiredFields = ["ra", "dec", "redshift", "psfMag_r", "u", "g", "r", "i", "z"];
for (let field of requiredFields) {
if (dataObj[field] === undefined || dataObj[field] === "") {
throw new Error("Please fill in all inputs before analyzing.");
}
}

// Convert numeric strings to numbers
for (let key in dataObj) {
if (!isNaN(dataObj[key]) && dataObj[key] !== "") {
Expand Down Expand Up @@ -286,3 +295,179 @@ window.CosmoClassifier = {
formatNumber,
debounce
};

/**
 * Batch Prediction Form Handler
 *
 * Wires the CSV upload widget (file picker + drag-and-drop zone) to the
 * `/predict/file` endpoint, then renders the response as a results table
 * and a doughnut chart of predicted class counts.
 *
 * Uses `showToast`, `setLoadingState` and the global `Chart` constructor,
 * none of which are defined inside this function.
 */
function initBatchPredictForm() {
    const fileInput = document.getElementById("batchFile");
    const dropzone = document.getElementById("fileUploadDropzone");
    const fileInfo = document.getElementById("fileInfo");
    const selectedFileName = document.getElementById("selectedFileName");
    const selectedFileSize = document.getElementById("selectedFileSize");
    const removeFileBtn = document.getElementById("removeFileBtn");
    const submitBtn = document.getElementById("batchPredictBtn");
    const form = document.getElementById("batchPredictForm");

    // The batch widget is not present on this page; nothing to wire up.
    if (!fileInput) return;

    const awaitingBatch = document.getElementById("awaiting-batch");
    const batchResult = document.getElementById("batch-result");
    const tbody = document.getElementById("batchTableBody");

    const MAX_SIZE = 5 * 1024 * 1024; // 5MB

    // Validate a chosen file and switch the UI from dropzone to file summary.
    function handleFile(file) {
        if (!file) return;

        if (!file.name.toLowerCase().endsWith('.csv')) {
            showToast("Only .csv files are allowed", "error");
            fileInput.value = "";
            return;
        }

        if (file.size > MAX_SIZE) {
            showToast("File size exceeds 5MB limit", "error");
            fileInput.value = "";
            return;
        }

        selectedFileName.textContent = file.name;
        selectedFileSize.textContent = (file.size / 1024 / 1024).toFixed(2) + " MB";

        dropzone.classList.add("hidden");
        fileInfo.classList.remove("hidden");
        submitBtn.disabled = false;
    }

    fileInput.addEventListener("change", (e) => {
        handleFile(e.target.files[0]);
    });

    // Stop the browser from navigating to / opening the dragged file.
    function preventDefaults(e) {
        e.preventDefault();
        e.stopPropagation();
    }

    ['dragenter', 'dragover', 'dragleave', 'drop'].forEach(eventName => {
        dropzone.addEventListener(eventName, preventDefaults, false);
    });

    ['dragenter', 'dragover'].forEach(eventName => {
        dropzone.addEventListener(eventName, () => {
            dropzone.style.borderColor = "var(--accent-secondary)";
        }, false);
    });

    ['dragleave', 'drop'].forEach(eventName => {
        dropzone.addEventListener(eventName, () => {
            dropzone.style.borderColor = "";
        }, false);
    });

    dropzone.addEventListener('drop', (e) => {
        const files = e.dataTransfer.files;
        if (files.length) {
            fileInput.files = files;
            handleFile(files[0]);
        }
    }, false);

    removeFileBtn.addEventListener("click", () => {
        fileInput.value = "";
        dropzone.classList.remove("hidden");
        fileInfo.classList.add("hidden");
        submitBtn.disabled = true;
    });

    // Extract a human-readable message from a failed response. The body may
    // not be JSON at all, and `detail` may be a string or a list of objects.
    async function readErrorMessage(response) {
        const fallback = "Batch prediction failed";
        let payload;
        try {
            payload = await response.json();
        } catch (parseError) {
            return `${fallback} (HTTP ${response.status})`;
        }

        const detail = payload && (payload.detail || payload.message);
        if (!detail) return fallback;
        if (typeof detail === "string") return detail;
        if (Array.isArray(detail)) {
            return detail
                .map(item => (item && item.msg) ? item.msg : JSON.stringify(item))
                .join("; ");
        }
        return JSON.stringify(detail);
    }

    form.addEventListener("submit", async (e) => {
        e.preventDefault();

        if (!fileInput.files[0]) return;

        setLoadingState(submitBtn, true);
        awaitingBatch.style.display = 'none';
        batchResult.classList.add("hidden");

        try {
            const formData = new FormData();
            formData.append("payload", fileInput.files[0]);

            const response = await fetch("/predict/file", {
                method: "POST",
                body: formData
            });

            if (!response.ok) {
                throw new Error(await readErrorMessage(response));
            }

            const data = await response.json();

            renderBatchResults(data);
            showToast('Batch classification complete!', 'success');

        } catch (error) {
            console.error(error);
            showToast(`Error: ${error.message}`, 'error');
            awaitingBatch.style.display = 'flex';
        } finally {
            setLoadingState(submitBtn, false);
        }
    });

    // Kept across submissions so the previous chart can be destroyed first.
    let batchChart = null;

    // Fill the results table and draw the class-distribution chart.
    function renderBatchResults(data) {
        const preds = Array.isArray(data.prediction) ? data.prediction : [];
        const probs = Array.isArray(data.probabilities) ? data.probabilities : [];

        tbody.innerHTML = "";

        const counts = { "GALAXY": 0, "STAR": 0, "QSO": 0 };
        const rows = document.createDocumentFragment();

        preds.forEach((pred, index) => {
            // The server sends null for a class code it cannot name.
            const label = pred == null ? "UNKNOWN" : String(pred);
            counts[label] = (counts[label] || 0) + 1;

            const rowProbs = Array.isArray(probs[index]) ? probs[index] : [];
            const confidence = rowProbs.length
                ? (Math.max(...rowProbs) * 100).toFixed(1) + "%"
                : "N/A";

            // Cells are built with textContent so response data is never
            // interpreted as markup.
            const rowCell = document.createElement("td");
            rowCell.textContent = `Row ${index + 1}`;

            const badge = document.createElement("span");
            badge.className = `batch-badge batch-${label.toLowerCase()}`;
            badge.textContent = label;
            const labelCell = document.createElement("td");
            labelCell.appendChild(badge);

            const probCell = document.createElement("td");
            probCell.textContent = confidence;

            const tr = document.createElement("tr");
            tr.appendChild(rowCell);
            tr.appendChild(labelCell);
            tr.appendChild(probCell);
            rows.appendChild(tr);
        });

        tbody.appendChild(rows);
        batchResult.classList.remove("hidden");

        // Render Chart
        const ctx = document.getElementById('batchPieChart').getContext('2d');
        if (batchChart) batchChart.destroy();

        batchChart = new Chart(ctx, {
            type: 'doughnut',
            data: {
                labels: ['GALAXY', 'STAR', 'QSO'],
                datasets: [{
                    data: [counts['GALAXY'], counts['STAR'], counts['QSO']],
                    backgroundColor: ['#7000ff', '#00d4ff', '#ff6b6b'],
                    borderWidth: 0
                }]
            },
            options: {
                responsive: true,
                maintainAspectRatio: false,
                plugins: {
                    legend: {
                        position: 'bottom',
                        labels: { color: '#ffffff' }
                    }
                }
            }
        });
    }
}
Loading
Loading