import os
import shutil
import uuid
import asyncio
from typing import List, Optional, Literal, Dict, Any
from pathlib import Path
import tempfile
import atexit
import threading
import subprocess
import tempfile


import torch
import numpy as np
import imageio 
from fastapi import FastAPI, UploadFile, File, HTTPException, BackgroundTasks, Form, Query, Body

from fastapi.responses import FileResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from PIL import Image
import uvicorn

from trellis.pipelines import TrellisImageTo3DPipeline
from trellis.representations import Gaussian, MeshExtractResult
from trellis.utils import render_utils, postprocessing_utils
from easydict import EasyDict as edict


# Ask spconv for its "native" convolution algorithm.
# NOTE(review): this assignment runs after the `trellis` imports above. If
# trellis reads SPCONV_ALGO while its modules are being imported, this has no
# effect — consider moving it above those imports.
os.environ.update(SPCONV_ALGO='native')

# Inclusive upper bound for seeds (largest int32 value).
MAX_SEED = np.iinfo(np.int32).max

# Root for per-session scratch folders: <TMP_DIR>/<session_id>.
TMP_DIR = os.path.join(tempfile.gettempdir(), 'trellis_api')
Path(TMP_DIR).mkdir(parents=True, exist_ok=True)

# Loaded TRELLIS pipeline; stays None until the startup hook runs.
pipeline = None

# In-memory session registry: session_id -> session data (user_dir, status,
# progress and, after a successful generation, state/seed/video info).
sessions: Dict[str, Dict[str, Any]] = {}


class ProgressTracker:
    """Progress state for one generation job, mirrored into ``sessions``.

    Every :meth:`update` stores the new values on the instance and, when the
    session is still registered, replaces ``sessions[session_id]['progress']``
    with a fresh snapshot whose status is ``'processing'``.
    """

    def __init__(self, session_id: str):
        self.session_id = session_id
        self.stage = "Initializing"
        self.message = "Starting generation..."
        self.current_step = 0
        self.total_steps = 0
        self.progress_percent = 0.0

    def update(self, step: int, total: int, stage: str, message: str):
        """Record progress ``step`` of ``total`` and publish it.

        ``progress_percent`` becomes ``step / total * 100`` and is 0.0 whenever
        ``total`` is not positive (avoids division by zero).
        """
        self.current_step, self.total_steps = step, total
        self.stage, self.message = stage, message
        if total > 0:
            self.progress_percent = step / total * 100.0
        else:
            self.progress_percent = 0.0

        if self.session_id not in sessions:
            # Session already removed (e.g. cleaned up); nothing to publish.
            return

        sessions[self.session_id]['progress'] = dict(
            current_step=self.current_step,
            total_steps=self.total_steps,
            stage=self.stage,
            message=self.message,
            progress_percent=self.progress_percent,
            status='processing',
        )


class GenerationSettings(BaseModel):
    # Sampling parameters for one generation run.
    # Seed used when randomize_seed is False.
    seed: int = Field(0, ge=0, le=MAX_SEED)
    # When True, get_seed() draws a random seed instead of using `seed`.
    randomize_seed: bool = True
    # Sparse-structure stage: passed as cfg_strength / steps to the pipeline.
    ss_guidance_strength: float = Field(7.5, ge=0.0, le=10.0)
    ss_sampling_steps: int = Field(12, ge=1, le=100)
    # SLAT stage: passed as cfg_strength / steps to the pipeline.
    slat_guidance_strength: float = Field(3.0, ge=0.0, le=10.0)
    slat_sampling_steps: int = Field(12, ge=1, le=100)
    # Forwarded as `mode` to pipeline.run_multi_image when several images are given.
    multiimage_algo: Literal["multidiffusion", "stochastic"] = "stochastic"

class GLBExtractionSettings(BaseModel):
    # Options for baking the model into a mesh file (used by GLB and FBX export).
    # Passed as `simplify` to postprocessing_utils.to_glb.
    mesh_simplify: float = Field(0.95, ge=0.9, le=0.98)
    # Passed as `texture_size` to postprocessing_utils.to_glb.
    texture_size: int = Field(1024, ge=512, le=2048)

class GenerationResponse(BaseModel):
    # Immediate reply from POST /generate_3d; generation continues in the background.
    session_id: str
    # Expected preview location under /download/{session_id}/.
    video_path: str
    message: str
    status: str = Field(default="processing")

class ExtractionResponse(BaseModel):
    # Result of an /extract_* call.
    session_id: str
    # Server-side path of the file that was written.
    file_path: str
    # Relative URL (under /download/{session_id}/) to fetch that file.
    download_url: str
    message: str
    status: str = Field(default="completed")

class ProgressResponse(BaseModel):
    # Payload of GET /progress/{session_id}, built from the session's 'progress' dict.
    session_id: str
    # Position on the coarse step scale maintained by ProgressTracker.
    current_step: int
    total_steps: int
    # Human-readable phase name and detail text.
    stage: str
    message: str
    progress_percent: float
    # e.g. 'processing', 'completed' or 'error'.
    status: str
    # Whether the preview has been written, and where to download it.
    video_ready: bool = Field(default=False)
    video_url: Optional[str] = Field(default=None)


app = FastAPI(
    title="TRELLIS Image to 3D API",
    version="1.0.0",
    description="API server for converting images to 3D models using TRELLIS",
)

# Accept cross-origin requests from any origin, with any method and header.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive (FastAPI's CORS guide says wildcards cannot be used together with
# credentials). If browser clients never send cookies/auth headers, consider
# allow_credentials=False or an explicit origin list.
_cors_options = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_options)

def cleanup_sessions():
    """Remove every tracked session together with its scratch directory."""
    # Snapshot the ids first because cleanup_session() deletes from `sessions`.
    pending = tuple(sessions)
    for sid in pending:
        cleanup_session(sid)

def cleanup_session(session_id: str):
    """Forget ``session_id`` and delete ``TMP_DIR/<session_id>`` if it exists.

    Unknown ids are ignored. Directory removal is best-effort (errors ignored).
    """
    if session_id not in sessions:
        return
    session_dir = os.path.join(TMP_DIR, session_id)
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir, ignore_errors=True)
    sessions.pop(session_id)

def pack_state(gs: Gaussian, mesh: MeshExtractResult) -> dict:
    """Snapshot a Gaussian and a mesh as host-side numpy arrays.

    ``'gaussian'`` holds ``gs.init_params`` plus the five parameter tensors
    (``_xyz``, ``_features_dc``, ``_scaling``, ``_rotation``, ``_opacity``);
    ``'mesh'`` holds ``vertices`` and ``faces``. ``unpack_state`` rebuilds
    objects from this dict.
    """
    def to_host(tensor):
        # Move to CPU and convert to numpy so the snapshot holds no GPU tensors.
        return tensor.cpu().numpy()

    gaussian = dict(gs.init_params)
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        gaussian[attr] = to_host(getattr(gs, attr))

    return {
        'gaussian': gaussian,
        'mesh': {
            'vertices': to_host(mesh.vertices),
            'faces': to_host(mesh.faces),
        },
    }

def unpack_state(state: dict) -> tuple:
    """Rebuild ``(Gaussian, mesh)`` on the CUDA device from a ``pack_state`` dict.

    The mesh comes back as an ``EasyDict`` with ``vertices`` and ``faces``
    tensors, not as a ``MeshExtractResult``.
    """
    g = state['gaussian']
    gs = Gaussian(
        aabb=g['aabb'],
        sh_degree=g['sh_degree'],
        # Keep this spelling: it must match both the stored key and the keyword used here.
        mininum_kernel_size=g['mininum_kernel_size'],
        scaling_bias=g['scaling_bias'],
        opacity_bias=g['opacity_bias'],
        scaling_activation=g['scaling_activation'],
    )
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(g[attr], device='cuda'))

    m = state['mesh']
    mesh = edict(
        vertices=torch.tensor(m['vertices'], device='cuda'),
        faces=torch.tensor(m['faces'], device='cuda'),
    )
    return gs, mesh

def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return ``seed``, or a random seed in ``[0, MAX_SEED)`` when ``randomize_seed``."""
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)

async def load_image_from_upload(upload_file: UploadFile) -> Image.Image:
    """Read an uploaded file and decode it as an RGBA PIL image.

    Args:
        upload_file: The uploaded image; its full contents are read into memory.

    Returns:
        The decoded image converted to ``'RGBA'`` mode.

    Raises:
        Whatever ``Image.open`` / ``convert`` raise for undecodable data.
    """
    # The module has no top-level `import io`, so without this local import the
    # helper failed with NameError on first use.
    import io

    contents = await upload_file.read()
    return Image.open(io.BytesIO(contents)).convert('RGBA')

def _render_preview_frames(outputs, size=(256, 256), num_frames=30):
    """Render turntable previews and return side-by-side color|normal frames.

    Renders ``num_frames`` views of the Gaussian (``'color'``) and of the mesh
    (``'normal'``) with ``render_utils.render_video``, resizes each view to
    ``size`` (width, height) and concatenates them along axis 1, so every
    returned frame is two views wide.
    """
    import cv2

    color_frames = render_utils.render_video(outputs['gaussian'][0], num_frames=num_frames)['color']
    normal_frames = render_utils.render_video(outputs['mesh'][0], num_frames=num_frames)['normal']

    combined = []
    for color, normal in zip(color_frames, normal_frames):
        color_small = cv2.resize(color, size, interpolation=cv2.INTER_LANCZOS4)
        normal_small = cv2.resize(normal, size, interpolation=cv2.INTER_LANCZOS4)
        combined.append(np.concatenate([color_small, normal_small], axis=1))
    return combined


def _encode_preview(frames, user_dir, frame_width, frame_height):
    """Write the preview into ``user_dir`` and return the path actually written.

    Encoders are tried in order:
      1. H.264 MP4 through an explicit imageio writer.
      2. A plain ``imageio.mimsave`` MP4.
      3. A tiny GIF (``preview.gif``).
    ``frame_width`` and ``frame_height`` are used only in the log message.

    Raises:
        Exception: if all three encoders fail.
    """
    import cv2

    video_path = os.path.join(user_dir, 'preview.mp4')

    # 1) Fast, widely playable H.264 (baseline profile, yuv420p, faststart).
    try:
        writer = imageio.get_writer(
            video_path,
            fps=10,
            codec='libx264',
            output_params=[
                '-an',
                '-profile:v', 'baseline',
                '-level', '3.0',
                '-pix_fmt', 'yuv420p',
                '-preset', 'ultrafast',
                '-tune', 'zerolatency',
                '-crf', '28',
                '-movflags', '+faststart',
                '-max_muxing_queue_size', '1024',
                '-color_primaries', 'bt709',
                '-color_trc',       'bt709',
                '-colorspace',      'bt709',
            ],
        )
        try:
            for frame in frames:
                if frame.dtype != np.uint8:
                    # Presumably float frames in [0, 1]; scale to 8-bit.
                    frame = (frame * 255).astype(np.uint8)
                writer.append_data(frame)
        finally:
            # Close even when a frame fails, so the underlying writer is released
            # before the next attempt writes to the same path.
            writer.close()
        print(f"Video encoded with ultra-fast settings: {len(frames)} frames at {frame_width}x{frame_height}")
        return video_path
    except Exception as e:
        print(f"Ultra-fast encoding failed: {e}")

    # 2) imageio's default MP4 path.
    try:
        imageio.mimsave(
            video_path,
            frames,
            fps=10,
            quality=6,
            macro_block_size=None,
        )
        print("Video encoded with basic optimized settings")
        return video_path
    except Exception as e:
        print(f"Basic MP4 encoding failed: {e}")

    # 3) Last resort: every 3rd frame, at most 15 frames, downscaled to 128x64.
    try:
        gif_path = os.path.join(user_dir, 'preview.gif')
        gif_frames = [
            cv2.resize(frame, (128, 64), interpolation=cv2.INTER_NEAREST)
            for frame in frames[::3][:15]
            if len(frame.shape) == 3 and frame.shape[2] == 3
        ]
        imageio.mimsave(gif_path, gif_frames, fps=5, duration=0.2)
        print("Created tiny GIF fallback")
        return gif_path
    except Exception as e:
        print(f"GIF creation failed: {e}")

    raise Exception("All video encoding methods failed")


def generate_3d_with_progress(session_id: str, pil_images: list, settings: GenerationSettings):
    """Run TRELLIS for one session and publish progress and results into ``sessions``.

    Steps: choose the seed, run the single- or multi-image pipeline, render and
    encode a preview, then store the packed model state and preview URL on
    ``sessions[session_id]``.

    Args:
        session_id: Key into ``sessions``; files are written to ``TMP_DIR/<session_id>``.
        pil_images: Images that were already preprocessed (the pipeline is called
            with ``preprocess_image=False``). One image uses ``pipeline.run``;
            several use ``pipeline.run_multi_image`` with ``settings.multiimage_algo``.
        settings: Sampling parameters.

    Errors are not raised. They are printed with a traceback and recorded as
    ``status='error'`` on the session, if it still exists. The CUDA cache is
    emptied on both success and failure.
    """
    import traceback

    try:
        progress = ProgressTracker(session_id)
        user_dir = os.path.join(TMP_DIR, session_id)

        sessions[session_id]['progress'] = {
            'current_step': 0,
            'total_steps': 100,
            'stage': 'Initializing',
            'message': 'Starting generation...',
            'progress_percent': 0.0,
            'status': 'processing'
        }

        actual_seed = get_seed(settings.randomize_seed, settings.seed)
        print(f"Using seed: {actual_seed}")

        # Coarse progress scale: both samplers' step counts plus 20 units for
        # pre/post-processing. Progress is not updated during the pipeline call.
        total_steps = settings.ss_sampling_steps + settings.slat_sampling_steps + 20
        progress.update(0, total_steps, "Preprocessing", "Preparing images...")
        progress.update(5, total_steps, "Sparse Structure", "Generating sparse structure...")

        run_kwargs = dict(
            seed=actual_seed,
            formats=["gaussian", "mesh"],
            preprocess_image=False,
            sparse_structure_sampler_params={
                "steps": settings.ss_sampling_steps,
                "cfg_strength": settings.ss_guidance_strength,
            },
            slat_sampler_params={
                "steps": settings.slat_sampling_steps,
                "cfg_strength": settings.slat_guidance_strength,
            },
        )
        if len(pil_images) == 1:
            outputs = pipeline.run(pil_images[0], **run_kwargs)
        else:
            outputs = pipeline.run_multi_image(pil_images, mode=settings.multiimage_algo, **run_kwargs)

        progress.update(total_steps - 15, total_steps, "Rendering", "Creating preview video...")
        frame_size = (256, 256)
        frames = _render_preview_frames(outputs, size=frame_size)
        video_path = _encode_preview(frames, user_dir, frame_size[0] * 2, frame_size[1])
        # Build the URL from the file actually written: the GIF fallback
        # produces preview.gif, not preview.mp4.
        video_url = f"/download/{session_id}/{os.path.basename(video_path)}"

        progress.update(total_steps - 5, total_steps, "Finalizing", "Saving model state...")
        state = pack_state(outputs['gaussian'][0], outputs['mesh'][0])
        sessions[session_id].update({
            'state': state,
            'user_dir': user_dir,
            'seed': actual_seed,
            'status': 'completed',
            'video_ready': True,
            'video_url': video_url,
        })

        progress.update(total_steps, total_steps, "Complete", "3D model generated successfully!")
        final_progress = sessions[session_id]['progress']
        final_progress['status'] = 'completed'
        final_progress['video_ready'] = True
        final_progress['video_url'] = video_url

    except Exception as e:
        # This runs as a background task with no caller to report to, so keep the traceback in the log.
        traceback.print_exc()
        if session_id in sessions:
            sessions[session_id]['progress'] = {
                'current_step': 0,
                'total_steps': 100,
                'stage': 'Error',
                'message': f'Generation failed: {str(e)}',
                'progress_percent': 0.0,
                'status': 'error'
            }
            sessions[session_id]['status'] = 'error'
    finally:
        # Release cached GPU memory whether generation succeeded or failed.
        torch.cuda.empty_cache()

@app.on_event("startup")
async def startup_event():
    """Load the pretrained TRELLIS image-to-3D pipeline into the module global and move it to CUDA."""
    global pipeline
    model_id = "JeffreyXiang/TRELLIS-image-large"
    print("Loading TRELLIS pipeline...")
    pipeline = TrellisImageTo3DPipeline.from_pretrained(model_id)
    pipeline.cuda()
    print("Pipeline loaded successfully!")

@app.on_event("shutdown")
async def shutdown_event():
    """On application shutdown, delete every session and its scratch directory."""
    # cleanup_sessions is also registered with atexit at module level.
    cleanup_sessions()


# Also run cleanup_sessions() at interpreter exit, in addition to the FastAPI
# shutdown hook. Running it twice is harmless: the second call finds `sessions` empty.
atexit.register(cleanup_sessions)

@app.get("/")
async def root():
    """Liveness probe: confirms the API process is serving requests."""
    return dict(message="TRELLIS Image to 3D API is running", status="healthy")

@app.get("/health")
async def health_check():
    """Report whether the pipeline is loaded, CUDA availability and the tracked session count."""
    pipeline_loaded = pipeline is not None
    cuda_available = torch.cuda.is_available()
    return dict(
        status="healthy",
        pipeline_loaded=pipeline_loaded,
        cuda_available=cuda_available,
        active_sessions=len(sessions),
    )

def _is_within(path: str, root: str) -> bool:
    """Return True when ``path`` is ``root`` or lies underneath it.

    Both arguments are expected to be absolute, already-resolved paths.
    """
    try:
        return os.path.commonpath([path, root]) == root
    except ValueError:
        # e.g. paths on different drives on Windows.
        return False


@app.get("/download/{session_id}/{filename:path}")
async def download_file(session_id: str, filename: str):
    """Serve a file from a session's working directory.

    ``filename`` may contain subdirectories (e.g. ``textures/Image_0.png``).
    Sessions that are no longer tracked in memory fall back to
    ``TMP_DIR/<session_id>``.

    Raises:
        HTTPException: 403 if the resolved path leaves the session directory
            (or the session directory is not strictly inside TMP_DIR), 404 if
            no regular file exists there, 500 for unexpected errors.
    """
    try:
        if session_id in sessions:
            base_path = sessions[session_id]['user_dir']
        else:
            base_path = os.path.join(TMP_DIR, session_id)

        # Security check on fully resolved paths ('..' segments and symlinks
        # removed). A string-prefix test is not enough: session_id='..' would
        # make the whole system temp dir the base, and '/a/b' is a string
        # prefix of '/a/bc'.
        tmp_root = os.path.realpath(TMP_DIR)
        base_path = os.path.realpath(base_path)
        file_path = os.path.realpath(os.path.join(base_path, filename))
        if (
            base_path == tmp_root
            or not _is_within(base_path, tmp_root)
            or not _is_within(file_path, base_path)
        ):
            raise HTTPException(status_code=403, detail="Access denied")

        # isfile rather than exists, so a directory is answered with 404 instead
        # of being handed to FileResponse.
        if not os.path.isfile(file_path):
            raise HTTPException(status_code=404, detail=f"File not found: {filename}")

        return FileResponse(
            path=file_path,
            filename=os.path.basename(filename),
            media_type='application/octet-stream'
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error downloading file: {str(e)}") from e
    
@app.post("/extract_fbx", response_model=ExtractionResponse)
async def extract_fbx(
    session_id: str = Query(..., description="Session ID"),  # Query parameter
    settings: GLBExtractionSettings = Body(default=GLBExtractionSettings())  # Body parameter with defaults
):
    """Export the session's model as FBX by converting an intermediate GLB with Blender.

    Steps:
      1. Rebuild the Gaussian and mesh from the stored state and bake
         ``tmp_model.glb`` with the requested simplification and texture size.
      2. Copy ``blender_convert.py`` (resolved relative to the current working
         directory) into the session directory and run
         ``blender --background --python <script> -- <glb> <fbx>``.
      3. Check that ``model.fbx`` exists, then count the image files in the
         session's ``textures`` directory for the response message.

    Raises:
        HTTPException: 404 for an unknown session and 400 if the model is not
            ready (both passed through unchanged). 500 if Blender fails or writes
            no FBX, or on any other error.
    """
    # NOTE(review): this async endpoint does GPU baking and a blocking
    # subprocess.run on the event loop, so other requests (e.g. /progress polling)
    # stall until Blender exits. Consider a plain `def` endpoint or a threadpool offload.
    try:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        data = sessions[session_id]
        if data["status"] != "completed":
            raise HTTPException(status_code=400, detail="Model not ready yet")

        # Setup paths
        tmp_dir = data["user_dir"]
        tmp_glb = os.path.join(tmp_dir, "tmp_model.glb")
        final_fbx = os.path.join(tmp_dir, "model.fbx")

        # Generate GLB first
        gs, mesh = unpack_state(data["state"])
        glb = postprocessing_utils.to_glb(
            gs, mesh,
            simplify=settings.mesh_simplify,
            texture_size=settings.texture_size,
            verbose=False
        )
        glb.export(tmp_glb)

        # Size heuristic only: a small GLB may mean textures were not embedded.
        glb_size = os.path.getsize(tmp_glb) / (1024 * 1024)
        print(f"GLB file size: {glb_size:.2f} MB")
        if glb_size < 1.0:
            print("Warning: GLB seems small, textures might be missing")

        # Copy the Blender script next to the session files.
        script_path = os.path.join(tmp_dir, "blender_convert.py")
        with open("blender_convert.py", "r", encoding="utf-8") as src, open(script_path, "w", encoding="utf-8") as dst:
            dst.write(src.read())

        # Run Blender conversion
        subprocess.run([
            "blender", "--background",
            "--python", script_path,
            "--", tmp_glb, final_fbx
        ], check=True)

        # A zero exit status alone does not prove the script wrote the FBX.
        if not os.path.isfile(final_fbx):
            raise HTTPException(status_code=500, detail="Blender conversion failed: no FBX file was produced")

        # Presumably populated by the Blender script; only counted for the message.
        textures_dir = os.path.join(tmp_dir, "textures")
        texture_files = []
        if os.path.exists(textures_dir):
            texture_files = [f for f in os.listdir(textures_dir) if f.endswith(('.png', '.jpg', '.jpeg'))]

        return ExtractionResponse(
            session_id=session_id,
            file_path=final_fbx,
            download_url=f"/download/{session_id}/model.fbx",
            message=f"FBX extracted with {len(texture_files)} texture(s)",
            status="completed"
        )

    except HTTPException:
        # Keep 404/400/explicit 500 details instead of re-wrapping them below.
        raise
    except subprocess.CalledProcessError as e:
        raise HTTPException(status_code=500, detail=f"Blender conversion failed: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error extracting FBX: {str(e)}") from e
    finally:
        # Release cached GPU memory from the GLB bake (as extract_glb does).
        torch.cuda.empty_cache()



@app.post("/generate_3d", response_model=GenerationResponse)
async def generate_3d(
    background_tasks: BackgroundTasks,
    images: List[UploadFile] = File(...),
    
    seed: int = Form(0),
    randomize_seed: bool = Form(True),
    ss_guidance_strength: float = Form(7.5),
    ss_sampling_steps: int = Form(12),
    slat_guidance_strength: float = Form(3.0),
    slat_sampling_steps: int = Form(12),
    multiimage_algo: str = Form("stochastic")
):
    """Start 3D generation from one or more uploaded images.

    Validates the form parameters and creates a new session (a directory under
    TMP_DIR plus an entry in ``sessions``). Each image is decoded and run through
    ``pipeline.preprocess_image``. Two background tasks are then queued:
    ``generate_3d_with_progress``, followed by ``cleanup_session_delayed`` with
    a 3600 s delay. The response returns immediately; poll
    ``/progress/{session_id}`` for status.

    Raises:
        HTTPException: 400 for invalid parameters (passed through unchanged).
            500 for any other failure, in which case the partially created
            session and its directory are removed.
    """
    session_id = None
    try:
        if not images:
            raise HTTPException(status_code=400, detail="No images provided")

        if not (0 <= seed <= MAX_SEED):
            raise HTTPException(status_code=400, detail=f"Seed must be between 0 and {MAX_SEED}")
        if not (0.0 <= ss_guidance_strength <= 10.0):
            raise HTTPException(status_code=400, detail="SS guidance strength must be between 0.0 and 10.0")
        if not (1 <= ss_sampling_steps <= 100):
            raise HTTPException(status_code=400, detail="SS sampling steps must be between 1 and 100")
        if not (0.0 <= slat_guidance_strength <= 10.0):
            raise HTTPException(status_code=400, detail="SLAT guidance strength must be between 0.0 and 10.0")
        if not (1 <= slat_sampling_steps <= 100):
            raise HTTPException(status_code=400, detail="SLAT sampling steps must be between 1 and 100")
        if multiimage_algo not in ["multidiffusion", "stochastic"]:
            raise HTTPException(status_code=400, detail="multiimage_algo must be 'multidiffusion' or 'stochastic'")

        settings = GenerationSettings(
            seed=seed,
            randomize_seed=randomize_seed,
            ss_guidance_strength=ss_guidance_strength,
            ss_sampling_steps=ss_sampling_steps,
            slat_guidance_strength=slat_guidance_strength,
            slat_sampling_steps=slat_sampling_steps,
            multiimage_algo=multiimage_algo
        )

        print(f"=== RECEIVED SETTINGS FROM UNITY ===")
        print(f"seed: {seed}")
        print(f"randomize_seed: {randomize_seed}")
        print(f"ss_guidance_strength: {ss_guidance_strength}")
        print(f"ss_sampling_steps: {ss_sampling_steps}")
        print(f"slat_guidance_strength: {slat_guidance_strength}")
        print(f"slat_sampling_steps: {slat_sampling_steps}")
        print(f"multiimage_algo: {multiimage_algo}")
        print(f"=== END RECEIVED SETTINGS ===")

        session_id = str(uuid.uuid4())
        user_dir = os.path.join(TMP_DIR, session_id)
        os.makedirs(user_dir, exist_ok=True)

        sessions[session_id] = {
            'user_dir': user_dir,
            'status': 'processing',
            'video_ready': False,
            'progress': {
                'current_step': 0,
                'total_steps': 100,
                'stage': 'Initializing',
                'message': 'Starting generation...',
                'progress_percent': 0.0,
                'status': 'processing'
            }
        }

        pil_images = []
        import io
        for img_file in images:
            contents = await img_file.read()
            pil_image = Image.open(io.BytesIO(contents)).convert('RGBA')
            processed_image = pipeline.preprocess_image(pil_image)
            pil_images.append(processed_image)

        background_tasks.add_task(generate_3d_with_progress, session_id, pil_images, settings)
        background_tasks.add_task(cleanup_session_delayed, session_id, 3600)

        return GenerationResponse(
            session_id=session_id,
            video_path=f"/download/{session_id}/preview.mp4",
            message="3D model generation started",
            status="processing"
        )

    except HTTPException:
        # Keep validation errors as 400 instead of re-wrapping them as 500.
        raise
    except Exception as e:
        # Neither background task (including the delayed cleanup) will run when
        # this handler raises, so drop the half-created session here.
        if session_id is not None:
            cleanup_session(session_id)
            shutil.rmtree(os.path.join(TMP_DIR, session_id), ignore_errors=True)
        raise HTTPException(status_code=500, detail=f"Error starting 3D model generation: {str(e)}") from e


@app.get("/progress/{session_id}", response_model=ProgressResponse)
async def get_progress(session_id: str):
    """Report the generation progress stored for ``session_id``.

    Missing progress fields fall back to neutral defaults (e.g. status
    ``'unknown'``). ``video_ready`` and ``video_url`` are read from the session
    itself, not from its progress dict.

    Raises:
        HTTPException: 404 if the session is unknown (re-raised unchanged rather
            than turned into a 500); 500 on any other error.
    """
    try:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        session_data = sessions[session_id]
        progress = session_data.get('progress', {})

        return ProgressResponse(
            session_id=session_id,
            current_step=progress.get('current_step', 0),
            total_steps=progress.get('total_steps', 100),
            stage=progress.get('stage', 'Unknown'),
            message=progress.get('message', 'No progress information'),
            progress_percent=progress.get('progress_percent', 0.0),
            status=progress.get('status', 'unknown'),
            video_ready=session_data.get('video_ready', False),
            video_url=session_data.get('video_url')
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error getting progress: {str(e)}") from e

@app.post("/extract_glb", response_model=ExtractionResponse)
async def extract_glb(
    session_id: str = Query(..., description="Session ID"),  # Query parameter
    settings: GLBExtractionSettings = Body(default=GLBExtractionSettings())  # Body parameter with defaults
):
    """Bake the session's model into ``model.glb`` inside its session directory.

    Raises:
        HTTPException: 404 for an unknown session and 400 if generation has not
            completed (both passed through unchanged); 500 on any other error.
    """
    # NOTE(review): the GLB bake runs synchronously on the event loop and blocks
    # other requests while it runs; consider a plain `def` endpoint or a threadpool offload.
    try:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        session_data = sessions[session_id]
        if session_data.get('status') != 'completed':
            raise HTTPException(status_code=400, detail="Model generation not completed yet")

        state = session_data['state']
        user_dir = session_data['user_dir']

        gs, mesh = unpack_state(state)

        glb = postprocessing_utils.to_glb(
            gs, mesh,
            simplify=settings.mesh_simplify,
            texture_size=settings.texture_size,
            verbose=False
        )

        glb_path = os.path.join(user_dir, 'model.glb')
        glb.export(glb_path)

        return ExtractionResponse(
            session_id=session_id,
            file_path=glb_path,
            download_url=f"/download/{session_id}/model.glb",
            message="GLB file extracted successfully"
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error extracting GLB: {str(e)}") from e
    finally:
        # Release cached GPU memory on failure paths too (e.g. an OOM during to_glb).
        torch.cuda.empty_cache()


@app.post("/extract_gaussian", response_model=ExtractionResponse)
async def extract_gaussian(session_id: str = Query(..., description="Session ID")):
    """Write the session's Gaussians to ``model.ply`` inside its session directory.

    Raises:
        HTTPException: 404 for an unknown session and 400 if generation has not
            completed (both passed through unchanged); 500 on any other error.
    """
    try:
        if session_id not in sessions:
            raise HTTPException(status_code=404, detail="Session not found")

        session_data = sessions[session_id]
        if session_data.get('status') != 'completed':
            raise HTTPException(status_code=400, detail="Model generation not completed yet")

        state = session_data['state']
        user_dir = session_data['user_dir']

        # Only the Gaussian is needed for the PLY export.
        gs, _ = unpack_state(state)

        ply_path = os.path.join(user_dir, 'model.ply')
        gs.save_ply(ply_path)

        return ExtractionResponse(
            session_id=session_id,
            file_path=ply_path,
            download_url=f"/download/{session_id}/model.ply",
            message="Gaussian PLY file extracted successfully"
        )

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error extracting Gaussian: {str(e)}") from e
    finally:
        # Release cached GPU memory whether or not the export succeeded.
        torch.cuda.empty_cache()



async def cleanup_session_delayed(session_id: str, delay: int):
    """Sleep ``delay`` seconds, then remove the session if it is still registered."""
    await asyncio.sleep(delay)
    if session_id not in sessions:
        return
    cleanup_session(session_id)

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="TRELLIS Image to 3D API Server")
    parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    parser.add_argument("--reload", action="store_true", help="Enable auto-reload for development")

    args = parser.parse_args()

    # Reload mode needs an "module:attr" import string. Derive the module name
    # from this file instead of hard-coding "app", which only worked when the
    # file was named app.py.
    module_name = os.path.splitext(os.path.basename(__file__))[0]

    uvicorn.run(
        f"{module_name}:app" if args.reload else app,
        host=args.host,
        port=args.port,
        reload=args.reload
    )
