import time
from typing import List, Tuple, Union

import numpy as np
import torch
from ultralytics import YOLO

# Module logger: a child of the application-wide logger, scoped to this detector backend.
# NOTE(review): `app_logger` (like `BaseInference` and `YOLOResult` below) is not imported
# in the visible import block; confirm it is provided by a project-level import.
logger = app_logger.getChild(
    "models.detector.ultralytics"
)

class YOLOInference(BaseInference):
    def __init__(self, model_path: str, imsz: int = 640,
                 conf_threshold: float = 0.25, nms_threshold: float = 0.45,
                 device: str = "cpu"):
        """
        Build a detector backed by the official Ultralytics ``YOLO`` SDK and load its weights.

        Args:
            model_path: Location of the model file (.pt, .onnx, or .torchscript).
            imsz: Input image size, forwarded to Ultralytics as ``imgsz`` at inference time.
            conf_threshold: Minimum confidence a detection must reach to be kept.
            nms_threshold: IoU threshold used by Non-Maximum Suppression.
            device: Target device string, e.g. 'cpu' or 'cuda'.
        """
        base_config = {"device": device}
        super().__init__(config=base_config)

        # Inference settings read later by predict().
        self.model_path, self.imsz = model_path, imsz
        self.conf_threshold, self.nms_threshold = conf_threshold, nms_threshold

        # Load last, once every attribute above is in place.
        self.load_model(model_path)

    def load_model(self, model_path: str):
        """
        Load the model at ``model_path`` into ``self.model`` via the Ultralytics ``YOLO`` class.

        Ultralytics resolves weights/architecture for the supported formats (e.g. ``.pt``,
        ``.onnx``, ``.torchscript``). Only PyTorch-backed models are moved to ``self.device``
        here: Ultralytics raises ``TypeError`` when ``.to()`` is called on an exported
        (non-PyTorch) model. For those backends the device is instead supplied at inference
        time through the ``device`` argument that ``predict`` passes to Ultralytics.

        Args:
            model_path: Path to the model file.
        """
        logger.info(f"[load] Loading Ultralytics model from {model_path} on {self.device}")
        self.model = YOLO(model_path)
        # For PyTorch sources the wrapped network is a torch.nn.Module; for exported formats
        # it is just the file path, and Model.to() would fail its PyTorch-model check.
        if isinstance(getattr(self.model, "model", None), torch.nn.Module):
            self.model.to(self.device)

    def predict(self, im_bgr: Union[np.ndarray, List[np.ndarray]]) -> List[List[YOLOResult]]:
        """
        Run end-to-end detection (preprocessing, forward pass, NMS) on one or more frames.

        Ultralytics' ``predict`` performs letterboxing, normalization and NMS internally and
        rescales box coordinates back to each original frame's size.

        Args:
            im_bgr: A single BGR image (numpy array) or a sequence of BGR images.

        Returns:
            One list of ``YOLOResult`` objects per input frame, in input order. Each
            ``YOLOResult`` receives ``[x1, y1, x2, y2, confidence]`` and its source frame.
            If inference raises, the error is logged with its traceback and one empty list
            per frame is returned so the calling pipeline keeps running. An empty input
            returns ``[]`` without invoking the model.
        """
        # Normalize to a list of frames; list() also accepts tuples and other iterables.
        frames = [im_bgr] if isinstance(im_bgr, np.ndarray) else list(im_bgr)
        if not frames:
            # Never hand Ultralytics an empty source: depending on the version it either
            # raises or falls back to a default source the caller never asked for.
            logger.debug("[infer] No frames supplied to detector; skipping inference")
            return []

        # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
        start_time = time.perf_counter()
        logger.debug("[infer] Starting detector inference on %d frame(s)", len(frames))

        try:
            results = self.model.predict(
                source=frames,
                imgsz=self.imsz,
                conf=self.conf_threshold,
                iou=self.nms_threshold,
                device=self.device,
                verbose=False,
                save=False,
            )

            final_results = []
            for res, frame in zip(results, frames):
                # res.boxes.data rows are [x1, y1, x2, y2, confidence, class_id];
                # box[:5] drops class_id before wrapping.
                boxes_data = res.boxes.data.cpu().numpy()
                final_results.append([YOLOResult(box[:5], frame) for box in boxes_data])
            return final_results

        except Exception as e:
            # Deliberate best-effort boundary: keep the traceback, but don't break the pipeline.
            logger.exception("Inference error occurred: %s", e)
            return [[] for _ in frames]

        finally:
            logger.info(
                "[infer] Detector inference completed in %.2f ms for %d frame(s)",
                (time.perf_counter() - start_time) * 1000,
                len(frames),
            )