#!/usr/bin/env python3
"""
Tree Packing Optimizer with Bisection Search

Optimizes n-tree packing by:
1. Taking best (n-1) packings from pool
2. Blowing them up and adding an extra tree
3. Running sparrow + SA optimization
4. Using bisection to find optimal area_factor
"""

import argparse
import glob
import json
import os
import shutil
import subprocess
import sys
import tempfile
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
import numpy as np

# Add parent directory to path for santa.py imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from santa import Package


# Paths (relative to script directory).
# SCRIPT_DIR is the absolute directory containing this file; the optimizer
# binaries below are resolved against it, so they do not depend on the
# current working directory.
SCRIPT_DIR = Path(__file__).parent.resolve()
# Sparrow binary (release build inside the sparrow_c_8w checkout).
SPARROW_BIN = SCRIPT_DIR / "sparrow_c_8w" / "target" / "release" / "sparrow"
# Simulated-annealing executable; invoked below with -i/-o/-iter/-seed flags.
SA_BIN = SCRIPT_DIR / "sa_fast_v2_json"


def get_all_candidates(pool_dir: str, n_minus_1: int) -> List[Tuple[float, Path]]:
    """
    Collect every pool packing for ``n_minus_1`` trees together with its score.

    Files are matched with the glob ``n{n_minus_1}_*.json`` inside ``pool_dir``;
    the score of each file is its ``solution.square_side`` value.

    Returns:
        List of ``(square_side, path)`` tuples, in glob order.

    Raises:
        ValueError: if no file matches the pattern.
    """
    matches = list(Path(pool_dir).glob(f"n{n_minus_1}_*.json"))
    if not matches:
        raise ValueError(f"No candidates in pool for n={n_minus_1}")

    def _side_of(json_file: Path) -> float:
        # Read the square_side recorded in one sparrow-format solution file.
        with open(json_file) as handle:
            return json.load(handle)["solution"]["square_side"]

    return [(_side_of(json_file), json_file) for json_file in matches]


def softmax_select(candidates: List[Tuple[float, Path]], beta: float = 10.0) -> Path:
    """
    Draw one candidate path at random, favouring smaller ``square_side``.

    Candidate ``i`` is chosen with probability proportional to
    ``exp(-beta * square_side_i)``. A larger ``beta`` concentrates the mass
    on the smallest (best) candidates; ``beta = 0`` gives a uniform draw.

    Args:
        candidates: ``(square_side, path)`` pairs.
        beta: inverse temperature of the softmax.

    Returns:
        The path of the sampled candidate.
    """
    sides = np.array([side for side, _ in candidates])
    logits = -beta * sides
    # Shift so the largest logit is 0: prevents overflow in exp() without
    # changing the normalized distribution.
    weights = np.exp(logits - np.max(logits))
    weights = weights / np.sum(weights)
    chosen = np.random.choice(len(candidates), p=weights)
    return candidates[chosen][1]


def blow_up_and_add_tree(json_path: Path, area_factor: float, n: int,
                         output_path: Path) -> float:
    """
    Blow up an (n-1)-tree packing and add one extra tree to get an n-tree instance.

    The solution at ``json_path`` is loaded with ``Package.from_sparrow`` and
    passed to ``Package.blow_up`` with factor ``area_factor * n / (n - 1)``.
    The resulting sparrow JSON is then modified:
      - it is renamed to ``n{n}_sqpp``;
      - the demand of its first item is set to ``n``;
      - the placed item whose translation has the largest ``x * y`` product
        is duplicated.
    The duplicate has the same transformation as its source and therefore
    overlaps it; presumably the subsequent sparrow run resolves that.

    Args:
        json_path: sparrow-format JSON holding the (n-1)-tree solution.
        area_factor: extra multiplier applied on top of the n/(n-1) scaling.
        n: target number of trees (must be >= 2).
        output_path: where the n-tree JSON is written.

    Returns:
        The ``solution.square_side`` of the original (n-1)-tree packing.

    Raises:
        ValueError: if ``n < 2`` or the blown-up layout has no placed items.
    """
    import copy

    if n < 2:
        raise ValueError(f"n must be >= 2 to extend an (n-1)-tree packing, got n={n}")

    # Score of the source packing, returned so callers can compare against it.
    with open(json_path) as f:
        original_square_side = json.load(f)["solution"]["square_side"]

    pkg = Package.from_sparrow(json_path)
    blown_up = pkg.blow_up(area_factor * n / (n - 1))

    # to_sparrow writes to a file path, so round-trip through a temp file to
    # get the JSON dict. The temp file is removed even if writing/parsing fails.
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
        tmp_path = tmp.name
    try:
        blown_up.to_sparrow(tmp_path)
        with open(tmp_path) as f:
            data = json.load(f)
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)

    # Re-label the instance for n trees.
    data["name"] = f"n{n}_sqpp"
    data["items"][0]["demand"] = n

    placed_items = data["solution"]["layout"]["placed_items"]
    if not placed_items:
        raise ValueError(f"No placed items in blown-up packing from {json_path}")

    def _xy_product(item) -> float:
        trans = item["transformation"]["translation"]
        return trans[0] * trans[1]

    # max() returns the first maximal item, matching a strict '>' scan.
    source_item = max(placed_items, key=_xy_product)
    placed_items.append(copy.deepcopy(source_item))

    with open(output_path, 'w') as f:
        json.dump(data, f, indent=2)

    return original_square_side


# def run_sparrow(input_path: Path, output_path: Path, exploration_sec: int = 10) -> bool:
#     """Run sparrow optimizer with exploration phase only."""
#     cmd = [
#         str(SPARROW_BIN),
#         "-i", str(input_path),
#         "-e", str(exploration_sec),
#         "-c", "0"
#     ]
    
#     try:
#         result = subprocess.run(cmd, capture_output=True, text=True, timeout=exploration_sec + 30)
        
#         # Sparrow writes output to same directory with modified filename
#         # We need to find the output file - it's typically the same name or with suffix
#         # Actually, let's check what sparrow produces
#         possible_outputs = [
#             input_path.with_suffix('.solution.json'),
#             input_path.parent / "output" / input_path.name,
#             input_path  # It might overwrite in place
#         ]
        
#         # The output is likely written in the output folder or same location
#         output_dir = SCRIPT_DIR / "sparrow_c_8w" / "output"
#         if output_dir.exists():
#             for f in output_dir.glob("*.json"):
#                 shutil.copy(f, output_path)
#                 return True
        
#         # If sparrow writes to input location, copy it
#         if input_path.exists():
#             shutil.copy(input_path, output_path)
#             return True
            
#         return False
#     except subprocess.TimeoutExpired:
#         return False
#     except Exception as e:
#         print(f"Sparrow error: {e}")
#         return False


def run_sa_single(input_path: Path, output_path: Path, seed: int,
                  iterations: int = 350000) -> Optional[float]:
    """
    Run one SA optimization and return the resulting square_side.

    Invokes ``SA_BIN`` with ``-i/-o/-iter/-seed`` (600 s timeout, output
    captured and discarded), then reads ``solution.square_side`` from
    ``output_path``.

    Any file already at ``output_path`` is removed before SA starts, so a
    failed run is never reported as successful by reading an old result.

    Returns:
        The square_side written by SA, or None if the run failed, timed out,
        or left no readable output.
    """
    cmd = [
        str(SA_BIN),
        "-i", str(input_path),
        "-o", str(output_path),
        "-iter", str(iterations),
        "-seed", str(seed)
    ]

    try:
        # Clear a stale result from an earlier run before launching SA.
        try:
            output_path.unlink()
        except FileNotFoundError:
            pass

        subprocess.run(cmd, capture_output=True, text=True, timeout=600)

        if not output_path.exists():
            return None
        with open(output_path) as f:
            data = json.load(f)
        return data["solution"]["square_side"]
    except Exception as e:
        # Best-effort: any failure (timeout, bad JSON, missing keys) -> None.
        print(f"SA error (seed={seed}): {e}")
        return None


def run_sa_parallel(input_path: Path, output_dir: Path, num_seeds: int = 8,
                    iterations: int = 350000) -> Tuple[Optional[Path], float]:
    """
    Run ``SA_BIN`` on one input with seeds ``0..num_seeds-1`` and keep the best result.

    Seeds are launched in batches of ``os.cpu_count()`` concurrent processes.
    Each batch is waited on before the next one starts; a process still
    running after 2 hours is killed. Seed ``s`` writes
    ``output_dir/sa_seed_{s}.json``.

    Returns:
        ``(best_output_path, best_square_side)`` for the seed with the smallest
        ``solution.square_side``, or ``(None, inf)`` if no seed left a
        readable result.
    """
    output_dir.mkdir(parents=True, exist_ok=True)

    num_cores = os.cpu_count() or 1
    print(f"Running {num_seeds} SA seeds with {num_cores} cores")

    output_paths = []

    for batch_start in range(0, num_seeds, num_cores):
        batch_end = min(batch_start + num_cores, num_seeds)

        processes = []
        for seed in range(batch_start, batch_end):
            output_path = output_dir / f"sa_seed_{seed}.json"
            output_paths.append(output_path)
            # A leftover file would otherwise be read as this seed's result.
            if output_path.exists():
                output_path.unlink()

            cmd = [
                str(SA_BIN),
                "-i", str(input_path),
                "-o", str(output_path),
                "-iter", str(iterations),
                "-seed", str(seed)
            ]
            # Discard child output: piping it without ever reading the pipes
            # deadlocks once the OS pipe buffer fills.
            proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL,
                                    stderr=subprocess.DEVNULL)
            processes.append(proc)

        for proc in processes:
            try:
                proc.wait(timeout=7200)  # 2 hour timeout for long runs
            except subprocess.TimeoutExpired:
                proc.kill()
                proc.wait()  # reap the killed child so it does not linger as a zombie

    best_path = None
    best_square_side = float('inf')

    for output_path in output_paths:
        if not output_path.exists():
            continue
        try:
            with open(output_path) as f:
                square_side = json.load(f)["solution"]["square_side"]
            if square_side < best_square_side:
                best_square_side = square_side
                best_path = output_path
        except (OSError, ValueError, KeyError, TypeError) as e:
            # Partial/corrupt output (e.g. from a killed process): skip this seed.
            print(f"Skipping unreadable SA output {output_path}: {e}")

    return best_path, best_square_side


@dataclass
class TrialResult:
    """
    Outcome of one optimization trial.

    Instances are passed from round to round so that later rounds can
    restart from the files an earlier round produced.
    """
    # square_side of the (n-1)-tree packing the trial started from
    original_side: float
    # best square_side reached; float('inf') when the trial failed
    final_side: float
    # copy of sparrow's output for this trial, or None if sparrow failed
    sparrow_out_path: Optional[Path]
    # best SA output file, or None if no SA seed produced a result
    best_sa_path: Optional[Path]
    # True once sparrow reported a feasible solution and its output was collected
    success: bool


def run_single_trial(candidate_path: Path, area_factor: float, n: int,
                     trial_id: int, work_dir: Path, sa_iterations: int = 350000,
                     sparrow_exploration_sec: int = 10,
                     sa_num_seeds: int = 8
                     ) -> TrialResult:
    """
    Run the blow-up -> sparrow -> SA pipeline for one candidate.

    1. ``blow_up_and_add_tree`` writes ``work_dir/trial_{trial_id}/blown_up.json``.
    2. Sparrow runs on it (``-e sparrow_exploration_sec -c 0``, ``cwd=SCRIPT_DIR``).
       The trial succeeds if ``SCRIPT_DIR/output/log.txt`` contains
       ``"feasible solution found!"`` and ``output/final_n{n}_sqpp.json``
       exists; that file is copied to ``trial_dir/sparrow_out.json``.
    3. SA runs with ``sa_num_seeds`` seeds on the copied sparrow output.

    Sparrow's log and final-output files live at fixed paths shared by every
    trial, so both are deleted before sparrow starts. Otherwise a crashed or
    timed-out run could be judged by a previous trial's files.

    Returns:
        TrialResult. On sparrow failure ``final_side`` is ``inf``, both paths
        are None and ``success`` is False.
    """
    trial_dir = work_dir / f"trial_{trial_id}"
    trial_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Blow up and add tree
    blown_up_path = trial_dir / "blown_up.json"
    original_square_side = blow_up_and_add_tree(
        candidate_path, area_factor, n, blown_up_path
    )

    # Step 2: Run sparrow optimization
    sparrow_output = trial_dir / "sparrow_out.json"
    sparrow_log = SCRIPT_DIR / "output" / "log.txt"
    # Sparrow writes its result to output/final_<name>.json in SCRIPT_DIR
    sparrow_expected_output = SCRIPT_DIR / "output" / f"final_n{n}_sqpp.json"

    for stale in (sparrow_log, sparrow_expected_output):
        if stale.exists():
            stale.unlink()

    cmd = [
        str(SPARROW_BIN),
        "-i", str(blown_up_path),
        "-e", str(sparrow_exploration_sec),
        "-c", "0"
    ]

    sparrow_success = False
    try:
        subprocess.run(cmd, capture_output=True, text=True,
                       timeout=sparrow_exploration_sec + 60,
                       cwd=str(SCRIPT_DIR))
        if sparrow_log.exists():
            sparrow_success = "feasible solution found!" in sparrow_log.read_text()
    except Exception as e:
        print(f"Sparrow failed in trial {trial_id}: {e}")

    if not sparrow_success:
        print(f"Trial {trial_id}: Sparrow did not find feasible solution")
        return TrialResult(original_square_side, float('inf'), None, None, False)

    if not sparrow_expected_output.exists():
        print(f"Warning: Sparrow output not found at {sparrow_expected_output}")
        return TrialResult(original_square_side, float('inf'), None, None, False)
    shutil.copy(sparrow_expected_output, sparrow_output)

    # Step 3: Run SA with parallel seeds (only reached if sparrow succeeded)
    sa_output_dir = trial_dir / "sa_outputs"
    best_path, new_square_side = run_sa_parallel(
        sparrow_output, sa_output_dir,
        num_seeds=sa_num_seeds, iterations=sa_iterations
    )

    return TrialResult(original_square_side, new_square_side, sparrow_output, best_path, True)


def bisection_evaluate(area_factor: float, pool_dir: str, n: int,
                       work_dir: Path, sa_iterations: int = 350000,
                       sparrow_exploration_sec: int = 10,
                       num_trials: int = 8, softmax_beta: float = 10.0,
                       min_successes: int = 2
                       ) -> Tuple[int, List[TrialResult]]:
    """
    Judge whether ``area_factor`` is generous enough for sparrow to succeed.

    Runs ``num_trials`` trials. Each one uses an (n-1)-tree pool candidate
    drawn with ``softmax_select(beta=softmax_beta)``, so the same candidate
    may be drawn more than once. A trial counts when its ``success`` flag is
    set, i.e. sparrow found a feasible layout. Note that this counts
    feasibility, not a smaller square_side than the source packing.

    Returns:
        ``(+1, results)`` if at least ``min_successes`` trials succeeded,
        otherwise ``(-1, results)``. ``results`` holds every TrialResult,
        successful or not.
    """
    candidates = get_all_candidates(pool_dir, n - 1)

    results = []
    successes = 0

    for trial_id in range(num_trials):
        candidate = softmax_select(candidates, beta=softmax_beta)

        trial_result = run_single_trial(
            candidate, area_factor, n, trial_id, work_dir, sa_iterations,
            sparrow_exploration_sec=sparrow_exploration_sec
        )
        results.append(trial_result)

        print(f"Trial {trial_id}: candidate={candidate.name}, area_factor={area_factor:.4f}, "
              f"original_side={trial_result.original_side:.4f}, new_side={trial_result.final_side:.4f}")

        if trial_result.success:
            successes += 1

    sign = 1 if successes >= min_successes else -1
    return sign, results


def bisection_search(pool_dir: str, n: int, results_dir: Path,
                     low: float = 1.0333, high: float = 1.1, steps: int = 10,
                     sa_iterations: int = 350000,
                     sparrow_exploration_sec: int = 10) -> Tuple[float, List[TrialResult]]:
    """
    Round 1: bisect ``area_factor`` over ``[low, high]``.

    Each step evaluates the midpoint with ``bisection_evaluate``:
      - ``+1`` (enough feasible trials) moves ``high`` down to the midpoint;
      - ``-1`` moves ``low`` up to it.

    Returns:
        ``(best_area_factor, all_successful_results)``. ``best_area_factor`` is
        the final ``high``: the smallest midpoint judged feasible, or the
        initial ``high`` when no step was. ``all_successful_results`` holds the
        successful TrialResults from every step.
    """
    print(f"\n{'='*60}")
    print(f"ROUND 1: Bisection search in range [{low}, {high}] for n={n}")
    print(f"{'='*60}")

    all_results = []

    for step in range(steps):
        mid = (low + high) / 2
        work_dir = results_dir / "round1" / f"step_{step}_af_{mid:.4f}"
        work_dir.mkdir(parents=True, exist_ok=True)

        print(f"\n=== Bisection Step {step+1}/{steps}: area_factor = {mid:.4f} ===")

        sign, step_results = bisection_evaluate(
            mid, pool_dir, n, work_dir, sa_iterations,
            sparrow_exploration_sec=sparrow_exploration_sec
        )

        successful = [r for r in step_results if r.success]
        all_results.extend(successful)

        best_side = min((r.final_side for r in successful), default=float('inf'))
        print(f"Result: sign={sign:+d}, successful={len(successful)}/{len(step_results)}, "
              f"best_side={best_side:.4f}")

        # +1: feasible enough at mid, so try smaller factors; -1: need more room.
        if sign > 0:
            high = mid
        else:
            low = mid

    print(f"\nRound 1 complete: {len(all_results)} successful trials collected")
    # A midpoint that failed is never reported as "best".
    return high, all_results


def run_round2(round1_results: List[TrialResult], n: int, results_dir: Path,
               top_k: int = 16, sparrow_sec: int = 60,
               sa_iterations: int = 3500000,
               sa_num_seeds: int = 8) -> List[TrialResult]:
    """
    Round 2: rerun sparrow (longer) and SA (more iterations) from the best Round-1 trials.

    The ``top_k`` Round-1 results with the smallest ``final_side`` are chosen.
    For each one, sparrow runs for ``sparrow_sec`` seconds with that trial's
    sparrow output as input (``cwd=SCRIPT_DIR``). If sparrow reports a
    feasible solution, its ``output/final_n{n}_sqpp.json`` is copied into the
    trial directory and SA runs with ``sa_num_seeds`` seeds on it.

    Sparrow's log and output files sit at fixed shared paths, so both are
    deleted before each run. This prevents a failed run from being judged by
    a previous run's files.

    Returns:
        One TrialResult per trial that got through sparrow and produced an
        SA output file.
    """
    print(f"\n{'='*60}")
    print(f"ROUND 2: Top {top_k} results, sparrow {sparrow_sec}s, SA {sa_iterations} iter")
    print(f"{'='*60}")

    sorted_results = sorted(round1_results, key=lambda r: r.final_side)[:top_k]
    print(f"Selected {len(sorted_results)} best results from Round 1")

    round2_dir = results_dir / "round2"
    round2_dir.mkdir(parents=True, exist_ok=True)

    round2_results = []
    sparrow_log = SCRIPT_DIR / "output" / "log.txt"
    sparrow_expected_output = SCRIPT_DIR / "output" / f"final_n{n}_sqpp.json"

    for idx, r1_result in enumerate(sorted_results):
        print(f"\n--- Round 2 Trial {idx+1}/{len(sorted_results)} ---")
        print(f"Resuming from: {r1_result.sparrow_out_path} (side={r1_result.final_side:.4f})")

        if r1_result.sparrow_out_path is None:
            print(f"Round 2 Trial {idx}: no sparrow output to resume from, skipping")
            continue

        trial_dir = round2_dir / f"trial_{idx}"
        trial_dir.mkdir(parents=True, exist_ok=True)
        sparrow_output = trial_dir / "sparrow_out.json"

        for stale in (sparrow_log, sparrow_expected_output):
            if stale.exists():
                stale.unlink()

        # Restart sparrow from the Round-1 sparrow output for sparrow_sec seconds
        cmd = [
            str(SPARROW_BIN),
            "-i", str(r1_result.sparrow_out_path),
            "-e", str(sparrow_sec),
            "-c", "0"
        ]

        sparrow_success = False
        try:
            subprocess.run(cmd, capture_output=True, text=True,
                           timeout=sparrow_sec + 120,
                           cwd=str(SCRIPT_DIR))
            if sparrow_log.exists():
                sparrow_success = "feasible solution found!" in sparrow_log.read_text()
        except Exception as e:
            print(f"Sparrow failed: {e}")

        if not sparrow_success:
            print(f"Round 2 Trial {idx}: Sparrow did not find feasible solution")
            continue

        if not sparrow_expected_output.exists():
            print(f"Warning: Sparrow output not found at {sparrow_expected_output}")
            continue
        shutil.copy(sparrow_expected_output, sparrow_output)

        sa_output_dir = trial_dir / "sa_outputs"
        best_path, new_side = run_sa_parallel(
            sparrow_output, sa_output_dir,
            num_seeds=sa_num_seeds, iterations=sa_iterations
        )

        print(f"Round 2 Trial {idx}: side={new_side:.4f}")

        if best_path:
            round2_results.append(TrialResult(
                r1_result.original_side, new_side,
                sparrow_output, best_path, True
            ))

    print(f"\nRound 2 complete: {len(round2_results)} successful trials")
    return round2_results


def run_round3(round2_results: List[TrialResult], n: int, results_dir: Path,
               top_k: int = 2, sa_iterations: int = 35000000) -> List[TrialResult]:
    """
    Round 3: very long SA runs on the best Round-2 trials (sparrow is not run).

    The ``top_k`` Round-2 results with the smallest ``final_side`` are chosen.
    For each one, SA (8 seeds, ``sa_iterations`` each) restarts from that
    trial's *sparrow* output, not from its best SA output. ``n`` is not used
    in this function.

    Returns:
        One TrialResult per trial whose SA run produced an output file.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"ROUND 3: Top {top_k} results, SA {sa_iterations} iter (no sparrow)")
    print(f"{banner}")

    selected = sorted(round2_results, key=lambda res: res.final_side)[:top_k]
    total = len(selected)
    print(f"Selected {total} best results from Round 2")

    round3_dir = results_dir / "round3"
    round3_dir.mkdir(parents=True, exist_ok=True)

    outcomes = []
    for trial_no, prev in enumerate(selected):
        print(f"\n--- Round 3 Trial {trial_no+1}/{total} ---")
        print(f"Resuming from: {prev.sparrow_out_path} (side={prev.final_side:.4f})")

        trial_dir = round3_dir / f"trial_{trial_no}"
        trial_dir.mkdir(parents=True, exist_ok=True)

        best_path, new_side = run_sa_parallel(
            prev.sparrow_out_path, trial_dir / "sa_outputs",
            num_seeds=8, iterations=sa_iterations,
        )
        print(f"Round 3 Trial {trial_no}: side={new_side:.4f}")

        if not best_path:
            continue
        outcomes.append(
            TrialResult(prev.original_side, new_side, prev.sparrow_out_path, best_path, True)
        )

    print(f"\nRound 3 complete: {len(outcomes)} successful trials")
    return outcomes


def _next_free_pool_path(pool_dir: Path, n: int, start: int) -> Path:
    # Return the first pool_dir/n{n}_c{k}.json with k >= start that does not exist yet.
    k = start
    while True:
        candidate = pool_dir / f"n{n}_c{k}.json"
        if not candidate.exists():
            return candidate
        k += 1


def main():
    """
    Command-line entry point: run the three optimization rounds for ``--n`` trees.

    - Round 1 bisects ``area_factor`` over ``[--low, --high]``.
    - Round 2 reruns sparrow and longer SA on the best Round-1 trials.
    - Round 3 runs very long SA on the best Round-2 trials; it is skipped if
      Round 2 produced nothing.

    Results from all rounds are ranked by ``final_side``. The best two that
    have an SA output file on disk are copied into the pool as
    ``n{n}_c{k}.json``, choosing an unused ``k`` so existing pool files are
    never overwritten. They are also copied into the timestamped results
    directory.
    """
    parser = argparse.ArgumentParser(
        description="Optimize n-tree packing using 3-round optimization"
    )
    parser.add_argument('--n', type=int, required=True,
                        help='Target n value (will use n-1 from pool)')
    parser.add_argument('--pool-dir', type=str, default='pool',
                        help='Pool directory containing candidate packings')
    parser.add_argument('--results-dir', type=str, default='optimization_results',
                        help='Directory to save all results')
    parser.add_argument('--bisection-steps', type=int, default=10,
                        help='Number of bisection steps in Round 1')
    parser.add_argument('--low', type=float, default=1.0333,
                        help='Lower bound for area_factor')
    parser.add_argument('--high', type=float, default=1.1,
                        help='Upper bound for area_factor')
    # Round-specific parameters
    parser.add_argument('--r1-sa-iter', type=int, default=350000,
                        help='Round 1: SA iterations per seed')
    parser.add_argument('--r1-sparrow-sec', type=int, default=10,
                        help='Round 1: Sparrow exploration seconds')
    parser.add_argument('--r2-top-k', type=int, default=16,
                        help='Round 2: Number of top results to use')
    parser.add_argument('--r2-sparrow-sec', type=int, default=60,
                        help='Round 2: Sparrow exploration seconds')
    parser.add_argument('--r2-sa-iter', type=int, default=3500000,
                        help='Round 2: SA iterations per seed')
    parser.add_argument('--r3-top-k', type=int, default=2,
                        help='Round 3: Number of top results to use')
    parser.add_argument('--r3-sa-iter', type=int, default=35000000,
                        help='Round 3: SA iterations per seed')

    args = parser.parse_args()

    # Setup paths
    pool_dir = SCRIPT_DIR / args.pool_dir
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_dir = SCRIPT_DIR / args.results_dir / f"n{args.n}_{timestamp}"
    results_dir.mkdir(parents=True, exist_ok=True)

    print(f"Pool directory: {pool_dir}")
    print(f"Results directory: {results_dir}")
    print(f"Target n: {args.n}")

    # Round 1: Bisection search
    best_af, round1_results = bisection_search(
        str(pool_dir), args.n, results_dir,
        low=args.low, high=args.high, steps=args.bisection_steps,
        sa_iterations=args.r1_sa_iter,
        sparrow_exploration_sec=args.r1_sparrow_sec
    )

    if not round1_results:
        print("ERROR: No successful results from Round 1!")
        return

    # Round 2: Extended optimization on the top Round-1 results
    round2_results = run_round2(
        round1_results, args.n, results_dir,
        top_k=args.r2_top_k, sparrow_sec=args.r2_sparrow_sec,
        sa_iterations=args.r2_sa_iter
    )

    # Round 3 only runs when Round 2 produced something; initialize it here
    # so the merge below also works when Round 3 is skipped.
    round3_results = []
    if not round2_results:
        print("WARNING: No successful results from Round 2, using Round 1 best")
    else:
        # Round 3: Super long SA on the top Round-2 results
        round3_results = run_round3(
            round2_results, args.n, results_dir,
            top_k=args.r3_top_k, sa_iterations=args.r3_sa_iter
        )

    # Rank results from all rounds together (R1, R2 and R3).
    final_results = sorted(round1_results + round2_results + round3_results,
                           key=lambda r: r.final_side)

    # Final output
    print(f"\n{'='*60}")
    print("OPTIMIZATION COMPLETE!")
    print(f"{'='*60}")
    print(f"Best area_factor: {best_af:.4f}")
    print(f"Best square_side: {final_results[0].final_side:.4f}")

    # Only results whose SA output still exists on disk can be saved.
    saveable = [r for r in final_results
                if r.best_sa_path and r.best_sa_path.exists()]

    # Copy the best 2 results to the pool without overwriting existing files.
    next_num = len(list(pool_dir.glob(f"n{args.n}_*.json"))) + 1
    for i, result in enumerate(saveable[:2]):
        pool_output = _next_free_pool_path(pool_dir, args.n, next_num)
        shutil.copy(result.best_sa_path, pool_output)
        print(f"Result #{i+1} (side={result.final_side:.4f}) saved to pool: {pool_output}")

        # Also save to results dir
        final_output = results_dir / f"best_{i+1}_n{args.n}_side{result.final_side:.4f}.json"
        shutil.copy(result.best_sa_path, final_output)

    if not saveable:
        print("WARNING: No valid result found!")


# Run the optimizer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

