# SPDX-License-Identifier: MIT
# Copyright (C) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation

from dataclasses import dataclass
import argparse
import fnmatch
import itertools
from pathlib import Path
from typing import List, Optional

# Subdirectory appended to --output_dir for the generated sources (see write_blobs).
GEN_DIR = ""  # in Cmake, have to generate files in same folder

# dtype key -> C++ type tag substituted for {F_dtype} in the kernel templates.
BWD_DTYPE_MAP = {"fp16": "FmhaBwdFp16", "bf16": "FmhaBwdBf16"}

# mode key -> C++ bool literal substituted for {F_mode} ("group" -> true).
MODE_MAP = {"batch": "false", "group": "true"}

# Short "t"/"f" flags stored on kernel descriptors -> C++ bool literals.
BOOL_MAP = {"t": "true", "f": "false"}

# Preamble prepended to every generated .cpp file.
# NOTE: the literal "\n" after the copyright line emits an extra blank line.
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
// auto generated by fmha_bwd_pre_post_kernel_generate.py
// TODO: Remove this file, directly generate kernel from ck
#include "fmha_bwd.hpp"
"""


# GEMM0: Q@K=S^T
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
# Is it necessary to distinguish between K0~K4?
@dataclass
class FmhaBwdDQDKDVTileSize:
    """Block/warp tiling of one dQ/dK/dV backward kernel.

    Field order is significant: instances are constructed positionally, so
    fields must stay in exactly this order.  ``name`` renders every field
    into a compact identifier.
    """

    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_bk0: int  # tile size along gemm0 unroll(F_bhdq)
    F_bk1: int  # tile size along gemm1 unroll(F_bm0)
    F_bk2: int  # tile size along gemm2 unroll(F_bhdv)
    F_bk3: int  # tile size along gemm3 unroll(F_bm0)
    F_bk4: int  # tile size along gemm4 unroll(F_bn0)
    F_bhdq: int  # q head_dim
    F_bhdv: int  # v head_dim
    F_rm0: int  # warps along q seqlen in gemm0/gemm2
    F_rn0: int  # warps along k seqlen in gemm0/gemm2
    F_rk0: int  # warps along headdim_qk/v in gemm0/gemm2 (not used)
    F_rm1: int  # warps along k seqlen in gemm1/gemm3
    F_rn1: int  # warps along headdim_qk/v in gemm1/gemm3
    F_rk1: int  # warps along q seqlen in gemm1/gemm3 (not used)
    F_rm2: int  # warps along q seqlen in gemm4
    F_rn2: int  # warps along headdim_qk in gemm4
    F_rk2: int  # warps along k seqlen in gemm4 (not used)
    F_wm0: int  # warp tile m in gemm0/gemm2/gemm4
    F_wn0: int  # warp tile n in gemm0/gemm2/gemm4
    F_wk0: int  # warp tile k in gemm0/gemm2/gemm4
    F_wm1: int  # warp tile m in gemm1/gemm3
    F_wn1: int  # warp tile n in gemm1/gemm3
    F_wk1: int  # warp tile k in gemm1/gemm3
    F_occupancy: int  # occupancy

    @property
    def name(self) -> str:
        """Identifier of the form ``b..._r..._r..._r..._w..._w..._o<occupancy>``."""

        def dims(*values: int) -> str:
            # "x"-separated rendering of a group of tile dimensions
            return "x".join(str(value) for value in values)

        segments = (
            "b"
            + dims(
                self.F_bm0,
                self.F_bn0,
                self.F_bk0,
                self.F_bk1,
                self.F_bk2,
                self.F_bk3,
                self.F_bk4,
                self.F_bhdq,
                self.F_bhdv,
            ),
            "r" + dims(self.F_rm0, self.F_rn0, self.F_rk0),
            "r" + dims(self.F_rm1, self.F_rn1, self.F_rk1),
            "r" + dims(self.F_rm2, self.F_rn2, self.F_rk2),
            "w" + dims(self.F_wm0, self.F_wn0, self.F_wk0),
            "w" + dims(self.F_wm1, self.F_wn1, self.F_wk1),
            f"o{self.F_occupancy}",
        )
        return "_".join(segments)


# TODO: design a more practical way to do it
# These are the currently supported tile sizes & pipelines.
# fmt: off
def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype: str) -> Optional[dict]:
    """Look up the supported dQ/dK/dV tile sizes and pipelines for ``dtype``.

    Returns a freshly built dict keyed by head-dim string ("32", "64", "128",
    "256"); each value is ``[FmhaBwdDQDKDVTileSize, pipeline name, pipeline name]``.
    Returns ``None`` for any dtype other than "fp16" or "bf16".
    """
    if dtype not in ("fp16", "bf16"):
        return None
    # fp16 and bf16 share one table
    return {
        "32":  [FmhaBwdDQDKDVTileSize(32, 128, 32,  32, 32,  32, 64, 32,  32,  1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"],
        "64":  [FmhaBwdDQDKDVTileSize(32, 128, 64,  32, 64,  32, 32, 64,  64,  1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"],
        "128": [FmhaBwdDQDKDVTileSize(16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"],
        "256": [FmhaBwdDQDKDVTileSize(16, 64,  256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"],
    }
# fmt: on


# C++ body for one dot_do_o kernel instance, filled in with str.format by
# FmhaBwdOGradDotOKernel.template (literal C++ braces are therefore doubled as
# {{ }}). It declares the trait/pipeline-problem/pipeline/kernel aliases and
# specializes three templates for dot_do_o_trait_{F_idx}:
#   - fmha_bwd_dot_do_o_: optionally logs the kernel name, then returns the
#     result of ck_tile::launch_kernel;
#   - fmha_bwd_dot_do_o_oneshot_: invokes the kernel directly with a
#     stream_config built from s.stream_id_;
#   - fmha_bwd_dot_do_o_get_name_: returns the kernel's GetName().
FMHA_BWD_DOT_DO_O_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_bwd_dot_do_o_trait_{F_idx} =
    ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;

using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
    typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
    /* BlockSize = */ 64,
    {F_hdim},
    {F_mode},
    fmha_bwd_dot_do_o_trait_{F_idx}>;

using fmha_bwd_dot_do_o_{F_idx} =
    typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;

using fmha_bwd_dot_do_o_kernel_{F_idx} =
    ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;

using dot_do_o_trait_{F_idx} =
    fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;

#include <iostream>

template <>
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids]                    = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    const dim3 blocks                      = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template <>
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_                               = fmha_bwd_dot_do_o_kernel_{F_idx};
    auto [kargs, grids]                    = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
    const dim3 blocks                      = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
        ck_tile::stream_config{{s.stream_id_}});
}}

template <>
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
    return k_::GetName();
}}
"""


@dataclass
class FmhaBwdOGradDotOKernel:
    """Codegen descriptor for one dot_do_o kernel instance.

    ``template`` renders the C++ source for this instance, and
    ``name``/``filename`` encode every field except ``F_idx``.
    """

    F_idx: int  # symbol-disambiguation counter, not a tunable
    F_hdim: int  # head dim, emitted as {F_hdim}
    F_dtype: str  # key into BWD_DTYPE_MAP
    F_spad: str  # "t"/"f" flag (via BOOL_MAP), first Traits argument
    F_dvpad: str  # "t"/"f" flag (via BOOL_MAP), second Traits argument
    F_mode: str  # key into MODE_MAP
    F_occupancy: int

    @property
    def template(self) -> str:
        """Full .cpp text: the common header followed by the formatted body."""
        substitutions = {
            "F_idx": self.F_idx,
            "F_hdim": self.F_hdim,
            "F_dtype": BWD_DTYPE_MAP[self.F_dtype],
            "F_spad": BOOL_MAP[self.F_spad],
            "F_dvpad": BOOL_MAP[self.F_dvpad],
            "F_mode": MODE_MAP[self.F_mode],
            "F_occupancy": self.F_occupancy,
        }
        header = FMHA_BWD_KERNEL_HEADER
        return header + FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(**substitutions)

    @property
    def name(self) -> str:
        """Kernel identifier; the pad suffix is ``p`` + flags, or ``npad``."""
        pad_flags = ("s" if self.F_spad == "t" else "") + (
            "dv" if self.F_dvpad == "t" else ""
        )
        pad_suffix = f"p{pad_flags}" if pad_flags else "npad"
        return (
            f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}"
            f"_o{self.F_occupancy}_{pad_suffix}"
        )

    @property
    def filename(self) -> str:
        """Name of the generated translation unit."""
        return f"{self.name}.cpp"


def get_bwd_dot_do_o_blobs(
    kernel_filter: Optional[str],
) -> List[FmhaBwdOGradDotOKernel]:
    """Enumerate the dot_do_o kernel instances to generate.

    One instance is produced for every supported dtype and head dimension
    (as listed by get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype), every mode
    in MODE_MAP, and every seqlen/dv padding combination. Group mode without
    seqlen padding is skipped.

    Args:
        kernel_filter: fnmatch-style pattern matched against each kernel's
            ``name``. ``None`` or ``""`` keeps every kernel.

    Returns:
        The matching kernel descriptors, in generation order.
    """

    # TODO: we don't support tuning yet, so occupancy is one fixed value for
    #       every dtype/hdim; support this in future
    def get_occupancy(dtype: str, hdim: int) -> int:
        return 2

    gen = []

    for dtype in BWD_DTYPE_MAP:
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dvpad in itertools.product(
            d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"]
        ):
            # group mode is only generated with seqlen padding enabled
            if mode == "group" and spad == "f":
                continue
            hdim = int(hdim_str)
            k = FmhaBwdOGradDotOKernel(
                F_idx=0,
                F_hdim=hdim,
                F_dtype=dtype,
                F_spad=spad,
                F_dvpad=dvpad,
                F_mode=mode,
                F_occupancy=get_occupancy(dtype, hdim),
            )
            # An empty or missing (None) filter means "keep everything".
            # Previously None reached fnmatch and raised TypeError.
            if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            gen.append(k)

    return gen


# C++ body for one convert_dq kernel instance, filled in with str.format by
# FmhaBwdConvertQGradKernel.template (literal C++ braces are therefore doubled
# as {{ }}). The pipeline problem converts AccDataType to QGradDataType with
# BlockSize 256 and the {F_bm0}x{F_bn0} tile. It specializes three templates
# for convert_dq_trait_{F_idx}:
#   - fmha_bwd_convert_dq_: optionally logs the kernel name, then returns the
#     result of ck_tile::launch_kernel;
#   - fmha_bwd_convert_dq_oneshot_: invokes the kernel directly with a
#     stream_config built from s.stream_id_;
#   - fmha_bwd_convert_dq_get_name_: returns the kernel's GetName().
FMHA_BWD_CONVERT_DQ_KERNEL_BODY = """
using fmha_dtype_{F_idx} = {F_dtype};

using fmha_bwd_convert_dq_trait_{F_idx} =
    ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;

using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
    ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
        typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
        typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
        /* BlockSize = */ 256,
        {F_bm0},
        {F_bn0},
        {F_hdim},
        {F_mode},
        {F_deterministic},
        fmha_bwd_convert_dq_trait_{F_idx}>;

using fmha_bwd_convert_dq_{F_idx} =
    typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;

using fmha_bwd_convert_dq_kernel_{F_idx} =
    ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;

using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
                                                             {F_dtype},
                                                             {F_mode},
                                                             {F_spad},
                                                             {F_dpad},
                                                             {F_deterministic}>;

#include <iostream>

template <>
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
{{
    using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
    if(s.log_level_ > 0)
        std::cout << ", " << k_::GetName() << std::flush;
    auto [kargs, grids]                    = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
    const dim3 blocks                      = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    return ck_tile::launch_kernel(
        s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
}}

template <>
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
                                                            fmha_bwd_args a)
{{
    using k_                               = fmha_bwd_convert_dq_kernel_{F_idx};
    auto [kargs, grids]                    = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
    const dim3 blocks                      = k_::BlockSize();
    constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
    ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
        ck_tile::stream_config{{s.stream_id_}});
}}

template <>
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
{{
    using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
    return k_::GetName();
}}
"""


@dataclass
class FmhaBwdConvertQGradKernel:
    """Codegen descriptor for one convert_dq kernel instance.

    ``template`` renders the C++ source for this instance, and
    ``name``/``filename`` encode every field except ``F_idx``.
    """

    F_idx: int  # symbol-disambiguation counter, not a tunable
    F_hdim: int  # head dim, emitted as {F_hdim}
    F_dtype: str  # key into BWD_DTYPE_MAP
    F_bm0: int  # tile size along q seqlen (block size)
    F_bn0: int  # tile size along k seqlen
    F_spad: str  # "t"/"f" flag (via BOOL_MAP), first Traits argument
    F_dpad: str  # "t"/"f" flag (via BOOL_MAP), second Traits argument
    F_mode: str  # key into MODE_MAP
    F_occupancy: int  # emitted as the third Traits argument
    F_deterministic: str  # "t"/"f" flag (via BOOL_MAP)

    @property
    def template(self) -> str:
        """Full .cpp text: the common header followed by the formatted body."""
        fields = dict(
            F_idx=self.F_idx,
            F_hdim=self.F_hdim,
            F_dtype=BWD_DTYPE_MAP[self.F_dtype],
            F_bm0=self.F_bm0,
            F_bn0=self.F_bn0,
            F_spad=BOOL_MAP[self.F_spad],
            F_dpad=BOOL_MAP[self.F_dpad],
            F_mode=MODE_MAP[self.F_mode],
            F_occupancy=self.F_occupancy,
            F_deterministic=BOOL_MAP[self.F_deterministic],
        )
        header = FMHA_BWD_KERNEL_HEADER
        return header + FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(**fields)

    @property
    def name(self) -> str:
        """Kernel identifier: base, pad suffix (``p``+flags or ``npad``), determinism."""
        pad_flags = "".join(
            tag
            for tag, flag in (("s", self.F_spad), ("d", self.F_dpad))
            if flag == "t"
        )
        pad_part = "p" + pad_flags if pad_flags else "npad"
        det_part = (
            "deterministic" if self.F_deterministic == "t" else "ndeterministic"
        )
        return "_".join(
            [
                f"fmha_bwd_convert_dq_d{self.F_hdim}",
                f"{self.F_dtype}",
                f"b{self.F_bm0}x{self.F_bn0}",
                f"{self.F_mode}",
                f"o{self.F_occupancy}",
                pad_part,
                det_part,
            ]
        )

    @property
    def filename(self) -> str:
        """Name of the generated translation unit."""
        return f"{self.name}.cpp"


def get_bwd_convert_dq_blobs(
    kernel_filter: Optional[str],
) -> List[FmhaBwdConvertQGradKernel]:
    """Enumerate the convert_dq kernel instances to generate.

    One instance is produced for every supported dtype and head dimension
    (as listed by get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype), every mode
    in MODE_MAP, every seqlen/head-dim padding combination, and both
    deterministic settings. Group mode without seqlen padding is skipped.

    Args:
        kernel_filter: fnmatch-style pattern matched against each kernel's
            ``name``. ``None`` or ``""`` keeps every kernel.

    Returns:
        The matching kernel descriptors, in generation order.
    """

    # TODO: we don't support tuning yet, so occupancy is one fixed value for
    #       every dtype/hdim; support this in future
    def get_occupancy(dtype: str, hdim: int) -> int:
        return 2

    gen = []

    for dtype in BWD_DTYPE_MAP:
        d = get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype)
        if d is None:
            continue
        for hdim_str, mode, spad, dpad, deterministic in itertools.product(
            d.keys(), MODE_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]
        ):
            # group mode is only generated with seqlen padding enabled
            if mode == "group" and spad == "f":
                continue
            hdim = int(hdim_str)
            tile = d[hdim_str][0]
            k = FmhaBwdConvertQGradKernel(
                F_idx=0,
                F_hdim=hdim,
                F_dtype=dtype,
                # The q-seqlen tile is fixed at 64. The k-seqlen tile follows
                # the dq_dk_dv tile selected for this hdim.
                F_bm0=64,
                F_bn0=tile.F_bn0,
                F_spad=spad,
                F_dpad=dpad,
                F_mode=mode,
                F_occupancy=get_occupancy(dtype, hdim),
                F_deterministic=deterministic,
            )
            # An empty or missing (None) filter means "keep everything".
            # Previously None reached fnmatch and raised TypeError.
            if kernel_filter and not fnmatch.fnmatch(k.name, kernel_filter):
                continue
            gen.append(k)

    return gen


def write_single_bwd_dot_do_o_kernel(
    kernel: FmhaBwdOGradDotOKernel, autogen_dir: Path
) -> None:
    """Write ``kernel.template`` to ``autogen_dir / kernel.filename``, replacing any existing file."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)


def write_single_bwd_convert_dq_kernel(
    kernel: FmhaBwdConvertQGradKernel, autogen_dir: Path
) -> None:
    """Write ``kernel.template`` to ``autogen_dir / kernel.filename``, replacing any existing file."""
    destination = autogen_dir / kernel.filename
    destination.write_text(kernel.template)


def write_blobs(output_dir: Optional[str], filters_list: List[str]) -> None:
    """Generate the dot_do_o and convert_dq kernel sources into a directory.

    Args:
        output_dir: Destination directory, with ``GEN_DIR`` appended. When
            ``None``, files are written next to this script. The directory is
            created if needed.
        filters_list: One filter spec per comma-separated ``-f`` entry. Each
            spec is ``"<dot_do_o pattern>@<convert_dq pattern>[@...]"``.
            Missing or empty patterns match every kernel of that kind.
    """
    if output_dir is None:
        out_path = Path(__file__).parent
    else:
        out_path = Path(output_dir) / GEN_DIR

    out_path.mkdir(parents=True, exist_ok=True)

    # Collect kernels keyed by filename so overlapping specs write each file
    # only once. For example, two specs that both leave the convert_dq pattern
    # empty previously regenerated every convert_dq kernel twice. The filename
    # encodes every template field (F_idx is always 0), so duplicates are
    # identical.
    dot_do_o_kernels = {}
    convert_dq_kernels = {}
    for spec in filters_list:
        patterns = spec.split("@")
        # pad to three slots; only the first two are consumed by this script
        patterns.extend([""] * (3 - len(patterns)))
        for kernel in get_bwd_dot_do_o_blobs(patterns[0]):
            dot_do_o_kernels.setdefault(kernel.filename, kernel)
        for kernel in get_bwd_convert_dq_blobs(patterns[1]):
            convert_dq_kernels.setdefault(kernel.filename, kernel)

    for kernel in dot_do_o_kernels.values():
        write_single_bwd_dot_do_o_kernel(kernel, out_path)
    for kernel in convert_dq_kernels.values():
        write_single_bwd_convert_dq_kernel(kernel, out_path)


if __name__ == "__main__":
    # Command-line entry point: parse options, then emit the generated .cpp files.
    cli = argparse.ArgumentParser(
        prog="generate",
        description="gen API for CK fmha kernel",
    )
    cli.add_argument(
        "-o",
        "--output_dir",
        required=False,
        help="write all the blobs into a directory",
    )
    # TODO: if using filter, must apply same value to output_dir and list_blobs
    # Filter syntax: comma-separated specs, each split on "@" by write_blobs
    # into "<dot_do_o pattern>@<convert_dq pattern>".
    cli.add_argument(
        "-f",
        "--filter",
        required=False,
        default="",
        help="filter out kernels that need to generate, using fnmatch module",
    )

    parsed = cli.parse_args()
    write_blobs(parsed.output_dir, parsed.filter.split(","))
