import math
import os
import random
import shapely
import tqdm
import json
from decimal import Decimal, getcontext

import functools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.patches import Rectangle
from shapely import affinity, touches, coverage_union_all, convex_hull, oriented_envelope
from shapely.geometry import Polygon, Point, MultiPolygon
from shapely.ops import unary_union
from numba import njit
from typing import Self
from scipy.optimize import minimize_scalar, bisect
from shapely.strtree import STRtree

# All Decimal arithmetic in this module runs with 25 significant digits.
getcontext().prec = 25

# Tree geometry is multiplied by this factor before it is handed to shapely
# (float64) and divided back out whenever sizes or areas are reported.
scale_factor = Decimal("1e15")


class ChristmasTree:
    """A single Christmas-tree polygon placed at a given centre and rotation.

    The outline is defined with exact ``Decimal`` coordinates, multiplied by the
    module-level ``scale_factor``, then rotated about the origin (degrees) and
    translated to the centre.  ``initial_polygon`` keeps the untransformed
    (but scaled) outline; ``polygon`` is the placed shapely ``Polygon``.
    """

    def __init__(self, center_x='0', center_y='0', angle='0'):
        """Build the tree at ``(center_x, center_y)`` rotated by ``angle`` degrees.

        All three arguments go through ``Decimal(...)``, so string inputs keep
        their exact digits.
        """
        self.center_x = Decimal(center_x)
        self.center_y = Decimal(center_y)
        self.angle = Decimal(angle)

        half = Decimal('2')
        quarter = Decimal('4')
        trunk_w, trunk_h = Decimal('0.15'), Decimal('0.2')
        base_w, mid_w, top_w = Decimal('0.7'), Decimal('0.4'), Decimal('0.25')
        tip_y = Decimal('0.8')
        tier_1_y, tier_2_y = Decimal('0.5'), Decimal('0.25')
        base_y = Decimal('0.0')
        trunk_bottom_y = -trunk_h

        # Unscaled outline, walked clockwise: tip, down the right-hand side,
        # across the bottom of the trunk, and back up the left-hand side.
        outline = (
            (Decimal('0.0'), tip_y),                 # tip
            (top_w / half, tier_1_y),                # right, top tier
            (top_w / quarter, tier_1_y),
            (mid_w / half, tier_2_y),                # right, middle tier
            (mid_w / quarter, tier_2_y),
            (base_w / half, base_y),                 # right, bottom tier
            (trunk_w / half, base_y),                # right trunk
            (trunk_w / half, trunk_bottom_y),
            (-(trunk_w / half), trunk_bottom_y),     # left trunk
            (-(trunk_w / half), base_y),
            (-(base_w / half), base_y),              # left, bottom tier
            (-(mid_w / quarter), tier_2_y),          # left, middle tier
            (-(mid_w / half), tier_2_y),
            (-(top_w / quarter), tier_1_y),          # left, top tier
            (-(top_w / half), tier_1_y),
        )
        self.initial_polygon = Polygon(
            [(x * scale_factor, y * scale_factor) for x, y in outline]
        )

        turned = affinity.rotate(self.initial_polygon, float(self.angle), origin=(0, 0))
        shift_x = float(self.center_x * scale_factor)
        shift_y = float(self.center_y * scale_factor)
        self.polygon = affinity.translate(turned, xoff=shift_x, yoff=shift_y)

    def plot(self, ax, color=None):
        """Fill this tree on ``ax`` in unscaled coordinates (colour defaults to "C0")."""
        fill_color = "C0" if color is None else color
        xs, ys = self.polygon.exterior.xy
        unscale = float(scale_factor)
        ax.fill(
            [v / unscale for v in xs],
            [v / unscale for v in ys],
            alpha=0.8, fc=fill_color, ec='black', linewidth=0.3,
        )

    @functools.cached_property
    def area(self):
        """Polygon area in unscaled units as a ``Decimal`` (computed once per instance)."""
        scaled_area = Decimal(self.polygon.area)
        return scaled_area / scale_factor / scale_factor

# Area of a single tree in unscaled units.  Every tree has the same outline,
# so one default-constructed instance is enough.
TREE_AREA = float(ChristmasTree().area)

# Float copies of the tree dimensions that ChristmasTree.__init__ spells out as
# Decimals; the numba-compiled helpers below (get_tree_vertices) read these.
# NOTE(review): the two definitions must be kept in sync by hand.
TRUNK_W = 0.15
TRUNK_H = 0.2
BASE_W = 0.7
MID_W = 0.4
TOP_W = 0.25
TIP_Y = 0.8
TIER_1_Y = 0.5
TIER_2_Y = 0.25
BASE_Y = 0.0
TRUNK_BOTTOM_Y = -TRUNK_H

@njit(cache=True)
def rotate_point(x, y, cos_a, sin_a):
    """Rotate ``(x, y)`` about the origin, given the cosine and sine of the angle.

    A positive angle rotates counter-clockwise.  Returns the new ``(x, y)`` pair.
    """
    new_x = x * cos_a - y * sin_a
    new_y = x * sin_a + y * cos_a
    return new_x, new_y

@njit(cache=True)
def get_tree_vertices(cx, cy, angle_deg):
    """Return the tree outline as a (15, 2) float64 array in world coordinates.

    The unit outline (tip, right side, trunk, left side) is rotated
    counter-clockwise by ``angle_deg`` degrees about the tree origin and then
    shifted by ``(cx, cy)``.  No scale factor is applied, and the ring is open:
    the first vertex is not repeated at the end.
    """
    theta = angle_deg * math.pi / 180.0
    c = math.cos(theta)
    s = math.sin(theta)
    outline = np.array([
        [0.0, TIP_Y],
        [TOP_W/2, TIER_1_Y], [TOP_W/4, TIER_1_Y],
        [MID_W/2, TIER_2_Y], [MID_W/4, TIER_2_Y],
        [BASE_W/2, BASE_Y],
        [TRUNK_W/2, BASE_Y], [TRUNK_W/2, TRUNK_BOTTOM_Y],
        [-TRUNK_W/2, TRUNK_BOTTOM_Y], [-TRUNK_W/2, BASE_Y],
        [-BASE_W/2, BASE_Y],
        [-MID_W/4, TIER_2_Y], [-MID_W/2, TIER_2_Y],
        [-TOP_W/4, TIER_1_Y], [-TOP_W/2, TIER_1_Y],
    ], dtype=np.float64)
    px = outline[:, 0]
    py = outline[:, 1]
    out = np.empty_like(outline)
    # Same rotation as rotate_point, applied element-wise, then translated.
    out[:, 0] = (px * c - py * s) + cx
    out[:, 1] = (px * s + py * c) + cy
    return out

@njit(cache=True)
def vertices_post_rotation(seeds, global_rot):
    """Rotate a whole layout by ``global_rot`` degrees and return every tree vertex.

    ``seeds`` is a float array whose rows are ``(x, y, angle_deg)``.  Each centre
    is rotated about the origin and each tree angle is increased by
    ``global_rot``.  The 15 vertices of every tree are stacked into a single
    ``(15 * n_trees, 2)`` array.
    """
    theta = global_rot * math.pi / 180.0
    c = math.cos(theta)
    s = math.sin(theta)
    # For a row vector (x, y, a), (x, y, a) @ rot == (x*c - y*s, x*s + y*c, a):
    # a counter-clockwise rotation of the centre, with the angle column untouched.
    rot = np.array([[c, s, 0.0], [-s, c, 0.0], [0.0, 0.0, 1.0]], dtype=np.float64)
    moved = seeds @ rot
    moved = moved + np.array([0.0, 0.0, global_rot], dtype=np.float64)

    count = seeds.shape[0]
    out = np.empty((count * 15, 2), dtype=np.float64)
    for k in range(count):
        start = 15 * k
        out[start:start + 15] = get_tree_vertices(moved[k, 0], moved[k, 1], moved[k, 2])
    return out

@njit(cache=True)
def oriented_bbox_minx_miny(seeds, global_rot):
    """Return ``(min_x, min_y)`` of all tree vertices after rotating the layout by ``global_rot`` degrees."""
    lo_x, lo_y, _, _ = polygon_bounds(vertices_post_rotation(seeds, global_rot))
    return lo_x, lo_y

@njit(cache=True)
def oriented_bbox(seeds, global_rot):
    """Return ``(width, height)`` of the axis-aligned box around all tree vertices
    after rotating the layout by ``global_rot`` degrees."""
    lo_x, lo_y, hi_x, hi_y = polygon_bounds(vertices_post_rotation(seeds, global_rot))
    width = hi_x - lo_x
    height = hi_y - lo_y
    return width, height

@njit(cache=True)
def oriented_bbox_area(seeds, global_rot):
    """Area of the bounding *square* of the layout rotated by ``global_rot`` degrees.

    Despite the name, this is the longer bounding-box side squared, not the
    rectangle area ``width * height``.  That matches how ``Package.bbox``
    scores a packing (``max(width, height)``).
    """
    vertices = vertices_post_rotation(seeds, global_rot)
    min_x, min_y, max_x, max_y = polygon_bounds(vertices)
    width = max_x - min_x
    height = max_y - min_y
    side = height if height > width else width
    return side**2
    
@njit(cache=True)
def polygon_bounds(vertices):
    """Return ``(min_x, min_y, max_x, max_y)`` over the rows of an (N, 2) vertex array.

    Row 0 seeds the running extrema, so at least one vertex is required.
    """
    lo_x = vertices[0, 0]
    hi_x = vertices[0, 0]
    lo_y = vertices[0, 1]
    hi_y = vertices[0, 1]
    for row in range(1, vertices.shape[0]):
        px = vertices[row, 0]
        py = vertices[row, 1]
        if px < lo_x:
            lo_x = px
        if px > hi_x:
            hi_x = px
        if py < lo_y:
            lo_y = py
        if py > hi_y:
            hi_y = py
    return lo_x, lo_y, hi_x, hi_y

class Submission:
    """
    A complete submission: one ``Package`` of trees per puzzle size n.

    Each CSV row has an ``id`` of the form ``"<n>_<tree index>"``, and its
    ``x``/``y``/``deg`` values carry an ``s`` prefix.  All columns are read as
    strings so coordinates keep their exact digits.
    """
    def __init__(self, csv_file, name="monke"):
        self.df = pd.read_csv(csv_file, dtype=str)
        self.name = name
        print(self.df.head())
        self.parse()

    def parse(self):
        """Group the CSV rows by puzzle size n and build ``self.packages``.

        Packages are stored in ascending n, so ``self.packages[n - 1]`` is the
        n-tree package when every size from 1 upwards is present.
        """
        grouped = {}
        rows = tqdm.tqdm(self.df.iterrows(), desc = "Reading submission: ", total=len(self.df))
        for _, row in rows:
            n = int(row["id"].split("_")[0])
            # strip("s") removes the "s" prefix the file format puts on every value.
            seed = [row["x"].strip("s"), row["y"].strip("s"), row["deg"].strip("s")]
            grouped.setdefault(n, []).append(seed)
        self.packages = [Package(grouped[n]) for n in sorted(grouped)]

    def write(self, filepath):
        """Write the submission CSV (columns ``id``, ``x``, ``y``, ``deg``) to ``filepath``.

        Values get their ``s`` prefix back, and ids are ``"<n:03d>_<tree index>"``.
        Packages are written in ``self.packages`` order.
        """
        rows = []
        for package in self.packages:
            n = package.n
            for t, (cx, cy, angle) in enumerate(package.seeds):
                rows.append({
                    "id": f"{n:03d}_{t}",
                    "x": f"s{cx}",
                    "y": f"s{cy}",
                    "deg": f"s{angle}",
                })
        pd.DataFrame(rows).to_csv(filepath, index=False)

    def add(self, package, better=True):
        """Store ``package`` in slot ``package.n - 1``.

        With ``better=True`` the slot is replaced only when the new package has a
        strictly smaller bounding-square side than the current one.
        """
        slot = package.n - 1
        if better:
            current_side = self.packages[slot].bbox()[0]
            new_side = package.bbox()[0]
            if not new_side < current_side:
                return
        self.packages[slot] = package

    def score(self):
        """Compute and print the submission score.

        The score is the sum over packages of ``side_length**2 / n``, as returned
        by ``Package.bbox``.  Partial sums per block of ten packages are printed
        as well.

        Returns:
            ``(total, (n_list, side_length_list, area_per_tree_list))``.
        """
        n_list, side_length_list, area_per_tree_list = [], [], []
        for package in self.packages:
            side_length, area_per_tree = package.bbox()
            n_list.append(package.n)
            side_length_list.append(side_length)
            area_per_tree_list.append(area_per_tree)
        area_per_tree_sum = sum(area_per_tree_list)
        print(f"Score = {area_per_tree_sum:.4f}")
        for start in range(0, len(area_per_tree_list), 10):
            print(f"{start+1}-{start+10}:{sum(area_per_tree_list[start:start+10]):.4f}")
        details = (n_list, side_length_list, area_per_tree_list)
        return area_per_tree_sum, details

    def plot_scores(self, axs=None, fname="scores.png"):
        """Plot side length, per-n score, and per-n score split by the parity of n.

        Draws onto the three axes in ``axs``.  If ``axs`` is None, a new 1x3
        figure is created and saved to ``fname``.  The black curve in the middle
        panel is the reversed cumulative sum divided by n.  When the packages
        cover n = 1..N, that is the mean score of packages n..N.
        """
        area_per_tree_sum, details = self.score()
        n_list, side_length_list, area_per_tree_list = details
        created_fig = axs is None
        if created_fig:
            fig, axs = plt.subplots(1, 3, figsize=(30, 6))
        label = f"{self.name} x {area_per_tree_sum:.3f}"
        axs[0].plot(n_list, side_length_list, ".-", label = label)
        per_tree = np.array(area_per_tree_list, dtype=np.float64)
        axs[1].plot(n_list, per_tree, ".-", label = label)
        tail_mean = (np.cumsum(per_tree[::-1]) / np.array(n_list))[::-1]
        axs[1].plot(n_list, tail_mean, ".-", color="k")
        # Split by the parity of n itself, not by list position (n starts at 1).
        n_arr = np.array(n_list)
        is_even = n_arr % 2 == 0
        axs[2].plot(n_arr[is_even], per_tree[is_even], ".-",
                    label = f"{self.name} (even) x {per_tree[is_even].sum():.3f}")
        axs[2].plot(n_arr[~is_even], per_tree[~is_even], ".-",
                    label = f"{self.name} (odd) x {per_tree[~is_even].sum():.3f}")
        if created_fig:
            fig.savefig(fname, bbox_inches="tight", dpi=300)

    def plot_all(self, axs=None, fname="submission.png"):
        """Draw every package on a 10-column grid and save the figure to ``fname``.

        Package n goes to row ``(n-1)//10``, column ``(n-1)%10``.  A supplied
        ``axs`` must be a 2-D axes array large enough for the largest n; the
        figure it belongs to is the one that is saved.
        """
        if axs is None:
            n_max = max((package.n for package in self.packages), default=1)
            n_rows = max(1, math.ceil(n_max / 10))
            fig, axs = plt.subplots(n_rows, 10, figsize=(20, n_rows * 2), squeeze=False)
        else:
            fig = np.asarray(axs).flat[0].figure
        for package in tqdm.tqdm(self.packages, desc = "Plotting: "):
            n = package.n
            package.plot(axs[(n - 1) // 10, (n - 1) % 10])
        fig.savefig(fname, bbox_inches="tight", dpi=600)

    def write_scores(self):
        """Write the per-package scores (``n``, ``score``) to ``<name without .csv>_scores.csv``."""
        _, (n_list, _, area_per_tree_list) = self.score()
        scdf = pd.DataFrame({"n": n_list, "score": area_per_tree_list})
        print(scdf.head())
        # removesuffix also handles names without ".csv", which replace() would
        # have left unchanged (overwriting a file named after the submission).
        _sccsv_name = f"{self.name.removesuffix('.csv')}_scores.csv"
        print(_sccsv_name)
        scdf.to_csv(_sccsv_name)

    def plot_selection(self, indices, axs=None, fname="submission.png"):
        """Draw only the packages whose n is in ``indices``, ten per row, and save to ``fname``.

        ``squeeze=False`` keeps ``axs`` two-dimensional even for a single row.
        """
        wanted = set(indices)
        if axs is None:
            n_rows = max(1, math.ceil(len(indices) / 10))
            fig, axs = plt.subplots(n_rows, 10, figsize=(20, n_rows * 2), squeeze=False)
        else:
            fig = np.asarray(axs).flat[0].figure
        selected = (package for package in self.packages if package.n in wanted)
        for count, package in enumerate(selected):
            package.plot(axs[count // 10, count % 10])
        fig.savefig(fname, bbox_inches="tight", dpi=300)


class Package:
    """
    A packing of n Christmas trees.

    ``self.seeds`` holds one row ``(x, y, angle_deg)`` per tree, exactly as given.
    Rows built by ``Submission.parse`` are strings, which lets ``ChristmasTree``
    rebuild the exact Decimal geometry.  All numeric work (numba helpers,
    scaling, JSON export) therefore goes through ``self.float_seeds``.
    """
    def __init__(self, seeds):
        self.seeds = np.array(seeds)

    @property
    def float_seeds(self) -> np.ndarray:
        """``self.seeds`` as a fresh float64 array (string seeds are parsed)."""
        return self.seeds.astype(np.float64)

    @staticmethod
    def _read_json(path):
        """Load and return the JSON document at ``path``, closing the file afterwards."""
        with open(path) as fh:
            return json.load(fh)

    @staticmethod
    def _fill_outline(ax, geom):
        """Draw the exterior of ``geom`` (scaled coordinates) on ``ax`` in unscaled units."""
        x, y = geom.exterior.xy
        unscale = float(scale_factor)
        x_scaled = [coord / unscale for coord in x]
        y_scaled = [coord / unscale for coord in y]
        ax.fill(x_scaled, y_scaled, alpha=0.8, fc='none', ec='black', linewidth=0.3)

    def has_overlap(self) -> bool:
        """Return True if any two tree polygons overlap; merely touching is allowed."""
        trees = self.trees
        if len(trees) <= 1:
            return False

        polygons = [t.polygon for t in trees]
        # STRtree limits the exact checks to pairs whose bounding boxes meet.
        tree_index = STRtree(polygons)

        for i, poly in enumerate(polygons):
            for idx in tree_index.query(poly):
                if idx == i:
                    continue
                other = polygons[idx]
                # Intersecting without merely touching means the interiors overlap.
                if poly.intersects(other) and not poly.touches(other):
                    return True
        return False

    def blow_up_tp_sparrow_validity(self):
        """Spread tree centres apart until the layout, rounded to 7 decimals, has no overlap.

        Centres are multiplied by a factor ``f`` that grows by 0.01 % per step;
        angles are unchanged.  (Presumably this mimics the reduced precision
        sparrow works with; confirm.)  Returns ``(package, f)``.  When no scaling
        is needed, the returned package keeps the original seeds unchanged.
        """
        base = self.float_seeds
        f = 1
        f32_package = Package(self.seeds)
        while Package(np.round(f32_package.float_seeds, decimals=7)).has_overlap():
            f *= 1.0001
            f32_package = Package(base * [f, f, 1])
        print(f"scale factor at no f32 overlap = {f:.7e}")
        return f32_package, f

    @classmethod
    def from_sparrow(cls, sparrow_json_file):
        """Build a package from a sparrow solution JSON.

        Each entry of ``solution.layout.placed_items`` contributes one seed
        ``[*translation, rotation]``.
        """
        results = cls._read_json(sparrow_json_file)
        placed = results["solution"]["layout"]["placed_items"]
        seeds = [
            [*x["transformation"]["translation"], x["transformation"]["rotation"]]
            for x in placed
        ]
        return cls(seeds)

    def to_sparrow(self, op_json_filename, sparrow_json_file_template="sqpp_result.json", edge_gap=1e-4):
        """Write this package as a sparrow square-packing solution to ``op_json_filename``.

        Trees are shifted so that the layout's bounding box starts at
        ``(edge_gap, edge_gap)``.  ``square_side`` is set to the larger
        bounding-box side plus ``2 * edge_gap``.  All other fields come from
        ``sparrow_json_file_template``.
        """
        seeds = self.float_seeds
        minx, miny = oriented_bbox_minx_miny(seeds, 0)
        minx -= edge_gap
        miny -= edge_gap
        placed_items = [
            {
                "item_id": 0,
                "transformation": {
                    "rotation": seed[2],
                    "translation": [seed[0] - minx, seed[1] - miny],
                },
            }
            for seed in seeds
        ]

        json_tmp = self._read_json(sparrow_json_file_template)
        json_tmp["solution"]["layout"]["placed_items"] = placed_items
        json_tmp["name"] = f"n{self.n}_sqpp"
        json_tmp["items"][0]["demand"] = self.n
        json_tmp["solution"]["square_side"] = max(oriented_bbox(seeds, 0)) + 2 * edge_gap
        with open(op_json_filename, "w") as fout:
            json.dump(json_tmp, fout, indent=4)

    def blow_up(self, area_factor):
        """Return a copy whose bounding-square area is ``area_factor`` times larger.

        Only tree centres are scaled; angles are kept.  The length factor is
        found by bisection in ``[sqrt(area_factor), 2 * sqrt(area_factor)]``.
        """
        seeds = self.float_seeds
        cur_area = max(oriented_bbox(seeds, 0))**2

        def area_gap(x):
            return area_factor - max(oriented_bbox(seeds * [x, x, 1], 0))**2 / cur_area

        root = area_factor**0.5
        length_factor = bisect(area_gap, root, 2 * root)
        return Package(seeds * [length_factor, length_factor, 1])

    def to_sparrow_with_freeze(self, op_json_filename, freeze, sparrow_json_file_template="/home/vineet/Hobbies/ml/kaggle/santa2025/spyrrow/sqpp_result.json"):
        """Write a sparrow instance in which the trees inside ``freeze`` form one rigid block.

        Trees whose centre lies inside the polygon ``freeze`` are replaced by a
        single extra item (id 1).  Its shape is the merged outline of those
        trees, and it is placed with zero rotation and translation.  All other
        trees remain free items at their current positions.
        """
        freeze_poly = Polygon(freeze)
        seeds = self.float_seeds

        placed_items = []
        frozen_poly_seeds = []
        for seed in seeds:
            if freeze_poly.contains(Point(seed[0], seed[1])):
                frozen_poly_seeds.append(seed)
                continue
            placed_items.append({
                "item_id": 0,
                "transformation": {
                    "rotation": seed[2],
                    "translation": [seed[0], seed[1]],
                },
            })
        frozen_poly_n = len(frozen_poly_seeds)

        frozen_block_vertices = Package(frozen_poly_seeds).compute_external_boundary(
            buffer_dist=1e-3, simplify_tolerance=0
        )
        placed_items.append({
            "item_id": 1,
            "transformation": {"rotation": 0, "translation": [0, 0]},
        })
        blk_shape_entry = {
            "id": 1,
            "shape": {
                "type": "simple_polygon",
                "data": frozen_block_vertices.tolist(),
            },
            # NOTE(review): this is the string "null", not JSON null; confirm sparrow accepts it.
            "min_quality": "null",
            "demand": 1,
        }

        json_tmp = self._read_json(sparrow_json_file_template)
        json_tmp["solution"]["layout"]["placed_items"] = placed_items
        json_tmp["name"] = f"n{self.n}_sqpp"
        json_tmp["items"][0]["demand"] = self.n - frozen_poly_n
        json_tmp["items"].append(blk_shape_entry)
        json_tmp["solution"]["square_side"] = max(oriented_bbox(seeds, 0))
        with open(op_json_filename, "w") as fout:
            json.dump(json_tmp, fout, indent=4)

    def to_sparrow2(self, op_json_filename, sparrow_json_file_template="/home/vineet/Hobbies/ml/kaggle/santa2025/spyrrow/sqpp_result.json", block=None):
        """Write a sparrow instance, optionally with ``block`` (a Package) as one rigid item.

        If ``block`` is given, it becomes item id 1, translated by minus its own
        bounding-box minimum.  Every tree of this package whose centre lies
        inside the block outline is then dropped from the free items.  Free trees
        are shifted so that the layout's bounding-box minimum is not negative.
        """
        seeds = self.float_seeds
        placed_items = []

        blk_n = 0
        block_poly = None
        blk_shape_entry = None
        if block is not None:
            blk_minx, blk_miny = oriented_bbox_minx_miny(block.float_seeds, 0)
            blk_n = block.n
            placed_items.append({
                "item_id": 1,
                "transformation": {"rotation": 0, "translation": [-blk_minx, -blk_miny]},
            })
            block_poly_vertices = block.compute_external_boundary(buffer_dist=1e-6, simplify_tolerance=0)
            block_poly = Polygon(block_poly_vertices)
            blk_shape_entry = {
                "id": 1,
                "shape": {
                    "type": "simple_polygon",
                    "data": block_poly_vertices.tolist(),
                },
                # NOTE(review): this is the string "null", not JSON null; confirm sparrow accepts it.
                "min_quality": "null",
                "demand": 1,
            }

        minx, miny = oriented_bbox_minx_miny(seeds, 0)
        minx = min(0, minx)
        miny = min(0, miny)
        for seed in seeds:
            if block_poly is not None and block_poly.contains(Point(seed[0], seed[1])):
                continue
            placed_items.append({
                "item_id": 0,
                "transformation": {
                    "rotation": seed[2],
                    "translation": [seed[0] - minx, seed[1] - miny],
                },
            })

        json_tmp = self._read_json(sparrow_json_file_template)
        json_tmp["solution"]["layout"]["placed_items"] = placed_items
        json_tmp["name"] = f"n{self.n}_sqpp"
        json_tmp["items"][0]["demand"] = self.n - blk_n
        if blk_shape_entry is not None:
            json_tmp["items"].append(blk_shape_entry)
        json_tmp["solution"]["square_side"] = max(oriented_bbox(seeds, 0))
        with open(op_json_filename, "w") as fout:
            json.dump(json_tmp, fout, indent=4)

    def to_sparrow_as_block(
            self, op_json_filename,
            sparrow_json_file_template="/home/vineet/Hobbies/ml/kaggle/santa2025/spyrrow/santa_spp_template.json",
            n_extra_trees=16
    ):
        """Write a sparrow instance in which this whole package is one rigid item.

        The block gets a new id (the template's largest id + 1).  Its float
        seeds are stored under ``shape.seeds`` so they can be expanded again
        later.  The template's first item gets demand ``n_extra_trees``.  If the
        template has ``strip_height``, it is set to the block height + 0.01.
        """
        json_tmp = self._read_json(sparrow_json_file_template)
        id_max = max(x["id"] for x in json_tmp["items"])
        seeds = self.float_seeds
        block_entry = {
            "id": id_max + 1,
            "demand": 1,
            "shape": {
                "type": "simple_polygon",
                # Numeric seeds, so downstream code can do arithmetic on them.
                "seeds": seeds.tolist(),
                "data": self.compute_external_boundary().tolist(),
            },
        }
        json_tmp["items"].append(block_entry)
        # Extra free trees: assumes the tree item is the first entry of "items".
        json_tmp["items"][0]["demand"] = n_extra_trees
        block_height = oriented_bbox(seeds, 0)[1]
        if "strip_height" in json_tmp:
            json_tmp["strip_height"] = block_height + 0.01
        with open(op_json_filename, "w") as fout:
            json.dump(json_tmp, fout, indent=4)

    def compute_external_boundary(self, buffer_dist: float = 1e-4, simplify_tolerance: float = 1e-4) -> np.ndarray:
        """
        Compute the external boundary of the tree arrangement.

        Each tree is grown by ``buffer_dist``, the results are unioned, and the
        union is shrunk back by the same distance.  Interior holes are dropped.

        Args:
            buffer_dist: Distance to buffer/unbuffer. Adjust based on gap size between polygons.
                        Smaller = closer to actual shapes, larger = smoother boundary.
            simplify_tolerance: Tolerance for Douglas-Peucker simplification to remove redundant
                            collinear points. Set to 0 to disable simplification.

        Returns:
            Array of shape (n_points, 2) representing the external boundary coordinates.

        Raises:
            ValueError: if the trees do not merge into a single polygon
                (e.g. disjoint groups); a larger ``buffer_dist`` may help.
        """
        shapely_polygons = []
        for ring in self.vertices:
            poly = Polygon(ring)
            if not poly.is_valid:
                # Repair the invalid polygon before growing it.
                poly = poly.buffer(0)
            shapely_polygons.append(poly.buffer(buffer_dist))

        shrunk = unary_union(shapely_polygons).buffer(-buffer_dist)
        if shrunk.geom_type != "Polygon":
            raise ValueError(
                f"trees do not merge into a single polygon (got {shrunk.geom_type}); "
                "try a larger buffer_dist"
            )

        # Keep only the exterior ring (ignore any interior holes).
        exterior_only = Polygon(shrunk.exterior)
        if simplify_tolerance > 0:
            exterior_only = exterior_only.simplify(simplify_tolerance, preserve_topology=True)
        return np.array(exterior_only.exterior.coords)

    @functools.cached_property
    def vertices(self):
        """Vertices of every tree, shape (n, 15, 2), in unscaled float coordinates."""
        return np.array([
            get_tree_vertices(cx=seed[0], cy=seed[1], angle_deg=seed[2])
            for seed in self.float_seeds
        ])

    @functools.cached_property
    def trees(self):
        """One ``ChristmasTree`` per seed row, built from the seeds' string form."""
        return [
            ChristmasTree(center_x=str(center_x), center_y=str(center_y), angle=str(angle))
            for center_x, center_y, angle in self.seeds
        ]

    @property
    def n(self):
        """Number of trees in the package."""
        return len(self.seeds)

    def bbox(self, ax=None):
        """Bounding square of the package.

        Returns ``(side_length, side_length**2 / n)`` as Decimals in unscaled
        units, and stores ``self.bbox_width`` / ``self.bbox_height``.  If ``ax``
        is given, the square is drawn on it in red, centred along the shorter
        bounding-box dimension.
        """
        all_polygons = [t.polygon for t in self.trees]
        bounds = unary_union(all_polygons).bounds
        minx, miny, maxx, maxy = (Decimal(b) / scale_factor for b in bounds)

        width = maxx - minx
        height = maxy - miny
        self.bbox_width = width
        self.bbox_height = height

        side_length = max(width, height)
        square_x = minx if width >= height else minx - (side_length - width) / 2
        square_y = miny if height >= width else miny - (side_length - height) / 2

        area_per_tree = side_length**2 / self.n

        if ax is not None:
            bounding_square = Rectangle(
                (float(square_x), float(square_y)),
                float(side_length),
                float(side_length),
                fill=False,
                edgecolor='red',
                linewidth=0.3,
                linestyle='-',
            )
            ax.add_patch(bounding_square)

        return side_length, area_per_tree

    def plot(self, ax = None, cmap = None):
        """Draw the trees (cycling through 20 colours) plus the bounding square.

        If ``ax`` is None, a new 2x2-inch figure is created and ``(fig, ax)`` is
        returned; otherwise nothing is returned.
        """
        ax_was_None = ax is None
        if ax_was_None:
            fig, ax = plt.subplots(1, 1, figsize=(2, 2))
        if cmap is None:
            cmap = plt.get_cmap('tab20')
        ax.set_axis_off()
        for i_tree, tree in enumerate(self.trees):
            tree.plot(ax, cmap(i_tree % 20))
        side_length, score = self.bbox(ax)
        ax.set_title(f"n = {self.n}, s = {side_length:.4f}, sc = {score:.4f}", fontsize=6)
        if ax_was_None:
            return fig, ax

    def convex_hull(self, ax=None):
        """Convex hull of all trees (scaled coordinates), optionally outlined on ``ax``."""
        all_polygons = [t.polygon for t in self.trees]
        hull = MultiPolygon(all_polygons).convex_hull
        if ax is not None:
            self._fill_outline(ax, hull)
        return hull

    def concave_hull(self, ax=None):
        """Normalized concave hull (ratio 0.1) of all trees, optionally outlined on ``ax``."""
        all_polygons = [t.polygon for t in self.trees]
        concave_hull = shapely.normalize(shapely.concave_hull(MultiPolygon(all_polygons), ratio=0.1))
        if ax is not None:
            self._fill_outline(ax, concave_hull)
        return concave_hull

    @property
    def convex_hull_area(self):
        """Convex hull area in unscaled units, as a Decimal."""
        return Decimal(self.convex_hull().area)/scale_factor/scale_factor

    @functools.cached_property
    def convex_hull_area_per_tree(self):
        """Convex hull area divided by the number of trees (computed once)."""
        return self.convex_hull_area/self.n

    def oriented_envelope(self, ax=None):
        """Find the global rotation in [-5, 5] degrees that minimises the bounding-square area.

        Sets ``obb_area`` (longer side squared, see ``oriented_bbox_area``),
        ``obb_area_per_tree``, ``obb_angle``, ``obb_lenx`` / ``obb_leny`` and
        ``obb_ratio``.  If ``ax`` is given, the rotated axes are drawn on it.
        """
        seeds = self.float_seeds
        res = minimize_scalar(lambda rot: oriented_bbox_area(seeds, rot), bounds=(-5, 5), method='bounded')
        angle = res.x
        self.obb_area = res.fun
        self.obb_area_per_tree = self.obb_area/self.n
        self.obb_angle = angle
        self.obb_lenx, self.obb_leny = oriented_bbox(seeds, angle)
        self.obb_ratio = self.obb_leny/self.obb_lenx
        if ax is not None:
            ax.plot([0, math.cos(math.radians(angle))], [0, -math.sin(math.radians(angle))], color="black")
            ax.plot([0, math.sin(math.radians(angle))], [0, math.cos(math.radians(angle))], color="black")

class Utils:
    """Helpers that operate on several submissions at once."""

    @staticmethod
    def compare_scores(submissions, fname="score_comparison.png"):
        """Overlay the score curves of several submissions and save the figure to ``fname``.

        Every submission draws onto the same three axes via ``plot_scores``.
        The two right-hand panels share a fixed 0.325-0.4 y-range, with tick
        labels repeated on a twin right-hand axis.
        """
        fig, axs = plt.subplots(1, 3, figsize=(96, 18))
        for submission in submissions:
            submission.plot_scores(axs)
        for ax in axs:
            ax.legend(fontsize = 20, loc = "upper center")
            ax.set_xticks(range(0, 201, 5))
            ax.grid(True)
        axs[0].xaxis.tick_top()
        ticks = np.arange(0.325, 0.4, 0.005)
        for ax in axs[1:]:
            ax.set_yticks(ticks)
            ax.set_ylim(0.325, 0.4)
            # Twin axis so the same tick labels also appear on the right side.
            right = ax.twinx()
            right.set_ylim(ax.get_ylim())  # keep both scales identical
            right.set_yticks(ticks)
        fig.savefig(fname, bbox_inches="tight", dpi=300)

    # TODO: a merge() helper is planned here but not implemented yet.

def get_line_segments(obj):
    """Return the edges of ``obj`` as rows ``(x1, y1, x2, y2)``.

    Supported inputs:
      * ``np.ndarray`` of shape (N, 2): one ring of vertices;
      * ``np.ndarray`` of shape (M, N, 2): M rings, with their edges stacked;
      * ``ChristmasTree``: its polygon exterior, divided by ``scale_factor``;
      * ``Package``: the exteriors of all its trees, stacked.

    Each vertex is paired with the next one, and the last wraps to the first.
    Shapely exteriors already repeat the first point at the end, so for them
    the wrap-around edge has zero length.

    Raises:
        TypeError: for any other input type or array rank.
    """
    def ring_segments(coords):
        # Pair every vertex with its successor, cyclically.
        coords = np.asarray(coords)
        return np.hstack([coords, np.roll(coords, shift=-1, axis=0)])

    if isinstance(obj, np.ndarray):
        if obj.ndim == 2:
            return ring_segments(obj)
        if obj.ndim == 3:
            return np.vstack([ring_segments(ring) for ring in obj])
        raise TypeError(f"expected a 2-D or 3-D vertex array, got shape {obj.shape}")

    if isinstance(obj, ChristmasTree):
        unscale = float(scale_factor)
        return ring_segments(np.array(obj.polygon.exterior.coords) / unscale)

    if isinstance(obj, Package):
        unscale = float(scale_factor)
        return np.vstack([
            ring_segments(np.array(tree.polygon.exterior.coords) / unscale)
            for tree in obj.trees
        ])

    raise TypeError(f"unsupported object type: {type(obj).__name__}")

def calc_min_distance(obj0, obj1, axis):
    """Smallest gap between two shapes measured along one axis.

    ``axis == 0`` measures along x; any other value measures along y (see
    ``calc_min_distance_axis_numba``).  Vertices of each shape are checked
    against the edges of the other, and the smaller of the two results is
    returned.
    """
    segs_a = get_line_segments(obj0)
    segs_b = get_line_segments(obj1)
    a_to_b = calc_min_distance_axis_numba(segs_a, segs_b, axis)
    b_to_a = calc_min_distance_axis_numba(segs_b, segs_a, axis)
    return min(a_to_b, b_to_a)

# def _calc_min_distance_x(obj0_coords, obj1_coords):
#     min_xd = 1e30
#     for i in range(len(obj0_coords)-1):
#         x1, y1 = obj0_coords[i]
#         x2, y2 = obj0_coords[i+1]
#         for j in range(len(obj1_coords)):
#             x, y = obj1_coords[j]
#             # print(x1, y1, x2, y2, x, y)
#             if y >= min(y1, y2) and y <= max(y1, y2):
#                 if y2 == y1:
#                     xd = min(abs(x1 - x), abs(x2 - x))
#                 else:
#                     xd = abs(x1 + ((x2 - x1) / (y2 - y1)) * (y - y1) - x)
#                 if xd < min_xd:
#                     min_xd = xd
#     return min_xd

@njit(fastmath=True, cache=True)
def calc_min_distance_axis_numba(obj0_line_segments, obj1_line_segments, axis):
    """
    Minimum axis-aligned distance from the vertices of obj1 to the edges of obj0.

    Both inputs are (N, 4) edge arrays ``(x1, y1, x2, y2)``.  Only the start
    point (first two columns) of each ``obj1_line_segments`` row is used as a
    query vertex.

    axis = 0 -> x-distance (vertex and edge compared at the same y)
    axis = 1 -> y-distance (vertex and edge compared at the same x)
    Any non-zero axis behaves like 1.

    A vertex whose perpendicular coordinate lies outside an edge's span is
    skipped for that edge.  Returns 1e30 if no vertex/edge pair qualifies.
    """
    # Within each (x, y) pair: col_a is the measured coordinate and col_b is the
    # coordinate the edge is sliced at.
    if axis == 0:
        col_a = 0
        col_b = 1
    else:
        col_a = 1
        col_b = 0

    best = 1e30
    for i in range(obj0_line_segments.shape[0]):
        a1 = obj0_line_segments[i, col_a]
        b1 = obj0_line_segments[i, col_b]
        a2 = obj0_line_segments[i, col_a + 2]
        b2 = obj0_line_segments[i, col_b + 2]
        lo = b1 if b1 < b2 else b2
        hi = b2 if b2 > b1 else b1
        flat = b1 == b2

        for j in range(obj1_line_segments.shape[0]):
            a = obj1_line_segments[j, col_a]
            b = obj1_line_segments[j, col_b]
            if b < lo or b > hi:
                continue
            if flat:
                # Edge runs parallel to the measuring axis: use the nearer endpoint.
                gap1 = abs(a1 - a)
                gap2 = abs(a2 - a)
                gap = gap1 if gap1 < gap2 else gap2
            else:
                # Interpolate the edge at the vertex's b-coordinate.
                gap = abs(a1 + (a2 - a1) * (b - b1) / (b2 - b1) - a)
            if gap < best:
                best = gap
    return best

# def calc_min_distance_y(obj0, obj1):
#     min_yd = 1e30
#     coords = []
#     for obj in [obj0, obj1]:
#         if isinstance(obj, ChristmasTree):
#             obj_coords = obj.polygon.exterior.coords
#         elif isinstance(obj, Package):
#             obj_coords = []
#             for tree in obj.trees:
#                 obj_coords += tree.polygon.exterior.coords
#         coords.append(obj_coords)
#     obj0_coords, obj1_coords = coords
#     for i in range(len(obj0_coords)-1):
#         x1, y1 = obj0_coords[i]
#         x2, y2 = obj0_coords[i+1]
#         for j in range(len(obj1_coords)):
#             x, y = obj1_coords[j]
#             # print(x1, y1, x2, y2, x, y)
#             if x >= min(x1, x2) and x <= max(x1, x2):
#                 if x2 == x1:
#                     yd = min(abs(y1 - y), abs(y2 - y))
#                 else:
#                     yd = abs(y1 + ((y2 - y1) / (x2 - x1)) * (x - x1) - y)
#                 if yd < min_yd:
#                     min_yd = yd
#     return min_yd / float(scale_factor)

def unfreeze(inp_json_file, op_json_file, unfrozen_json_file):
    """Expand a frozen block back into individual trees in a sparrow solution.

    Args:
        inp_json_file: sparrow input JSON.  The block is the item with
            ``id == 1``; its original tree seeds ``(x, y, angle)`` are stored
            under ``shape.seeds``.
        op_json_file: sparrow output JSON in which the block was placed as
            ``item_id == 1``.
        unfrozen_json_file: path the expanded solution is written to.

    Every placement of item 1 is replaced by one ``item_id == 0`` entry per
    seed.  Each seed centre is rotated by the block rotation (degrees) and
    shifted by the block translation, and the block rotation is added to the
    seed angle.  The expanded trees are appended after the remaining
    placements, and item 1 is removed from ``items``.

    Raises:
        ValueError: if ``inp_json_file`` has no item with id 1.
    """
    with open(inp_json_file) as fh:
        inp_json = json.load(fh)
    block_item = next((item for item in inp_json["items"] if item["id"] == 1), None)
    if block_item is None:
        raise ValueError(f"{inp_json_file}: no item with id 1 (frozen block) found")
    seeds = block_item["shape"]["seeds"]

    with open(op_json_file) as fh:
        j = json.load(fh)

    # Build new lists instead of appending to / popping from the list being iterated.
    kept, unfrozen = [], []
    for x in j["solution"]["layout"]["placed_items"]:
        if x["item_id"] != 1:
            kept.append(x)
            continue
        print(x)
        trans_x, trans_y = x["transformation"]["translation"]
        rot = x["transformation"]["rotation"]
        cos_a = np.cos(np.radians(rot))
        sin_a = np.sin(np.radians(rot))
        for seed in seeds:
            cx_rot, cy_rot = rotate_point(seed[0], seed[1], cos_a, sin_a)
            entry = {'item_id': 0, 'transformation': {'rotation': float(seed[2] + rot), 'translation': [cx_rot + trans_x, cy_rot + trans_y]}}
            print(entry)
            unfrozen.append(entry)

    j["solution"]["layout"]["placed_items"] = kept + unfrozen
    j["items"] = [item for item in j["items"] if item["id"] != 1]

    with open(unfrozen_json_file, "w") as fout:
        json.dump(j, fout, indent=4)