from typing import IO
from dae.genomic_resources.histogram import CategoricalHistogram
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use("agg")


def plot_rightlabels(
    outfile: IO,
    histogram: CategoricalHistogram,
    xlabel: str,
    _small_values_description: str | None = None,
    _large_values_description: str | None = None,
) -> None:
    """Plot a categorical histogram with a numbered legend on the right.

    The most frequent categories are drawn as bars numbered 1..N along the
    x-axis; the (possibly trimmed) category name for each number is listed
    to the right of the axes. Categories beyond the per-column limit are
    merged into a single trailing "Other" bar.

    Args:
        outfile: Destination passed to ``Figure.savefig``.
        histogram: Histogram whose ``raw_values`` maps category -> count and
            whose ``config.y_log_scale`` selects a logarithmic y-axis.
        xlabel: Column name; used as the x-axis label and as the key into
            the per-column bar-limit table.
        _small_values_description: Unused; kept for interface compatibility.
        _large_values_description: Unused; kept for interface compatibility.
    """
    # Number of individual bars to show per column; the remaining categories
    # are aggregated into "Other". Unknown columns default to 10.
    column_bars_mapping = {
        'CLNDN': 10,
        'CLNDNINCL': 15,
        'CLNDISDB': 15,
        'CLNDISDBINCL': 15,
        'CLNREVSTAT': 10,
        'CLNSIG': 20,
        'CLNSIGCONF': 15,
        'CLNSIGINCL': 15,
        'CLNVC': 8,
        'CLNVCSO': 10,
        'CLNVI': 10,
        'GENEINFO': 20,
        'MC': 15,
        'ONCDN': 5,
        'ONCREVSTAT': 5,
        'RS': 20,
        'SCIDN': 10,
        'SCIDISDB': 15,
        'SCI': 10,
        'SCIREVSTAT': 10,
    }
    noofbars = column_bars_mapping.get(xlabel, 10)

    # Most frequent categories first (sorted() is stable, so ties keep
    # their original order).
    values = sorted(histogram.raw_values.items(), key=lambda x: -x[1])

    top_values = values[:noofbars]
    other_values = values[noofbars:]
    if other_values:
        top_values.append(("Other", sum(count for _, count in other_values)))

    # Category keys are not guaranteed to be strings; normalise them so the
    # len()/slicing used for trimming below cannot fail.
    labels = [str(label) for label, _ in top_values]
    counts = [count for _, count in top_values]

    # Create exactly one figure at the intended size. The former
    # plt.figure(...) + plt.subplots() pair created a second, default-sized
    # figure for drawing and leaked the first.
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        positions = list(range(1, len(labels) + 1))
        ax.bar(
            x=positions,
            height=counts,
            log=histogram.config.y_log_scale,
            align="center",
        )
        ax.set_xticks(positions)
        ax.set_xticklabels(positions, rotation=0, ha='center')

        # Legend entries "N: name", evenly spaced top to bottom just right
        # of the axes (x = 1.05 in axes coordinates).
        total_labels = len(labels)
        for index, label in enumerate(labels):
            trimmed_label = label[:30] + '...' if len(label) > 30 else label
            ax.text(
                1.05, 1 - (index / total_labels),
                f"{index + 1}: {trimmed_label}",
                horizontalalignment='left',
                verticalalignment='top',
                transform=ax.transAxes,
                fontsize=10,
            )

        ax.set_xlabel(f"\n{xlabel}")
        ax.set_ylabel("count")
        fig.tight_layout()
        fig.savefig(outfile, bbox_inches='tight')
    finally:
        # Close (not merely clear) the figure so repeated calls do not
        # accumulate open figures in pyplot's figure manager.
        plt.close(fig)




import matplotlib.pyplot as plt

def plot_bottomlabels(
    outfile: IO,
    histogram: CategoricalHistogram,
    xlabel: str,
    _small_values_description: str | None = None,
    _large_values_description: str | None = None,
) -> None:
    """Plot a categorical histogram with category names under the bars.

    The most frequent categories are drawn as bars whose tick labels are the
    category names rotated by 45 degrees. Categories beyond the per-column
    limit are merged into a single trailing "Other" bar.

    Args:
        outfile: Destination passed to ``Figure.savefig``.
        histogram: Histogram whose ``raw_values`` maps category -> count and
            whose ``config.y_log_scale`` selects a logarithmic y-axis.
        xlabel: Column name; used as the x-axis label and as the key into
            the per-column bar-limit table.
        _small_values_description: Unused; kept for interface compatibility.
        _large_values_description: Unused; kept for interface compatibility.
    """
    # Number of individual bars to show per column; the remaining categories
    # are aggregated into "Other". Unknown columns default to 10.
    column_bars_mapping = {
        'CLNDN': 10,
        'CLNDNINCL': 15,
        'CLNDISDB': 15,
        'CLNDISDBINCL': 15,
        'CLNREVSTAT': 10,
        'CLNSIG': 20,
        'CLNSIGCONF': 15,
        'CLNSIGINCL': 15,
        'CLNVC': 8,
        'CLNVCSO': 10,
        'CLNVI': 10,
        'GENEINFO': 20,
        'MC': 15,
        'ONCDN': 5,
        'ONCREVSTAT': 5,
        'RS': 20,
        'SCIDN': 10,
        'SCIDISDB': 15,
        'SCI': 10,
        'SCIREVSTAT': 10,
    }
    noofbars = column_bars_mapping.get(xlabel, 10)

    # Most frequent categories first (sorted() is stable for ties).
    values = sorted(histogram.raw_values.items(), key=lambda x: -x[1])

    top_values = values[:noofbars]
    other_values = values[noofbars:]
    if other_values:
        top_values.append(("Other", sum(count for _, count in other_values)))

    labels = [label for label, _ in top_values]
    counts = [count for _, count in top_values]

    # Create exactly one figure at the intended size. The former
    # plt.figure(...) + plt.subplots() pair drew into a second,
    # default-sized figure and leaked the first.
    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        positions = range(len(labels))
        ax.bar(
            x=positions,
            height=counts,
            log=histogram.config.y_log_scale,
            align="center",
        )
        # Category names as tick labels, rotated so long names don't overlap.
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')

        ax.set_xlabel(f"\n{xlabel}")
        ax.set_ylabel("count")
        fig.tight_layout()
        fig.savefig(outfile, bbox_inches='tight')
    finally:
        # Close (not merely clear) the figure so repeated calls do not
        # accumulate open figures in pyplot's figure manager.
        plt.close(fig)




def plot_origin(
    outfile: IO,
    histogram: CategoricalHistogram,
    xlabel: str,
    _small_values_description: str | None = None,
    _large_values_description: str | None = None,
) -> None:
    """Plot histogram and save it into outfile."""

    # Mapping dictionary based on bit encoding
    bit_mapping = {
            1: "germline",
        2: "somatic",
        4: "inherited",
        8: "paternal",
        16: "maternal",
        32: "de-novo",
        64: "biparental",
        128: "uniparental",
        256: "not-tested",
        512: "tested-inconclusive"
        }

    def decode_value(value):
        names = []
        try:
            value = int(value)
            for bit, name in bit_mapping.items():
                if value & bit:
                    names.append(name)
        except ValueError:
            names.append("unknown")  # Handle any unexpected non-integer values
        return "+".join(names) if names else "unknown"


    # Get the number of bars based on the column name or default to 10
    noofbars = 10

    # Sort values by counts in descending order
    values = list(sorted(histogram.raw_values.items(), key=lambda x: -x[1]))

    # Split values into top 'noofbars' and the rest
    top_values = values[:noofbars]
    other_values = values[noofbars:]

    # Combine the rest into an "Other" category if there are more categories
    if other_values:
        other_count = sum(v[1] for v in other_values)
        top_values.append(("Other", other_count))

    # Extract labels and counts for plotting, decoding the labels
# Extract labels and counts for plotting, decoding the labels unless it's "."
# Extract labels and counts for plotting, decoding the labels unless it's "." or "Other"
    labels = ["None" if v[0] == "None" else v[0] if v[0] == "Other" else decode_value(v[0]) for v in top_values]

    counts = [v[1] for v in top_values]

    # Adjust figure size to be more rectangular
    plt.figure(figsize=(12, 6))  # Width and height adjusted for rectangle shape
    _, ax = plt.subplots()

    # Plot bars with labels on the x-axis
    ax.bar(
        x=range(len(labels)),
        height=counts,
        log=histogram.config.y_log_scale,
        align="center"
    )

    # Set labels on the x-axis instead of numbers, and rotate them by 45 degrees
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')

    plt.xlabel(f"\n{xlabel}")
    plt.ylabel("count")
    plt.tight_layout()

    # Save the plot
    plt.savefig(outfile, bbox_inches='tight')
    plt.clf()



def plot_empty(
    outfile: IO,
    histogram: CategoricalHistogram,
    xlabel: str,
    _small_values_description: str | None = None,
    _large_values_description: str | None = None,
) -> None:
    """Save a blank 12x6-inch figure into outfile.

    Takes the same arguments as the other ``plot_*`` functions (presumably so
    it can be swapped in for them), but ignores ``histogram`` and ``xlabel``
    and draws nothing.

    Args:
        outfile: Destination passed to ``Figure.savefig``.
        histogram: Ignored.
        xlabel: Ignored.
        _small_values_description: Unused; kept for interface compatibility.
        _large_values_description: Unused; kept for interface compatibility.
    """
    fig = plt.figure(figsize=(12, 6))
    try:
        fig.savefig(outfile, bbox_inches='tight')
    finally:
        # Close (not merely clear) the figure so repeated calls do not
        # accumulate open figures in pyplot's figure manager.
        plt.close(fig)