{"metadata":{"kernelspec":{"language":"python","display_name":"Python 3","name":"python3"},"language_info":{"name":"python","version":"3.10.12","mimetype":"text/x-python","codemirror_mode":{"name":"ipython","version":3},"pygments_lexer":"ipython3","nbconvert_exporter":"python","file_extension":".py"},"kaggle":{"accelerator":"gpu","dataSources":[{"sourceId":84969,"databundleVersionId":10033515,"sourceType":"competition"},{"sourceId":9862305,"sourceType":"datasetVersion","datasetId":6052780},{"sourceId":9867543,"sourceType":"datasetVersion","datasetId":6040935},{"sourceId":9869730,"sourceType":"datasetVersion","datasetId":6058495},{"sourceId":10395422,"sourceType":"datasetVersion","datasetId":6440886},{"sourceId":10435988,"sourceType":"datasetVersion","datasetId":6462312},{"sourceId":10452453,"sourceType":"datasetVersion","datasetId":6470284},{"sourceId":10470773,"sourceType":"datasetVersion","datasetId":6483258},{"sourceId":10474567,"sourceType":"datasetVersion","datasetId":6485689},{"sourceId":10478633,"sourceType":"datasetVersion","datasetId":6488359},{"sourceId":10482716,"sourceType":"datasetVersion","datasetId":6490513},{"sourceId":10485103,"sourceType":"datasetVersion","datasetId":6491850},{"sourceId":10488089,"sourceType":"datasetVersion","datasetId":6493715},{"sourceId":10488137,"sourceType":"datasetVersion","datasetId":6493750},{"sourceId":10489284,"sourceType":"datasetVersion","datasetId":6494464},{"sourceId":10494499,"sourceType":"datasetVersion","datasetId":6497657},{"sourceId":10496075,"sourceType":"datasetVersion","datasetId":6498674},{"sourceId":10504523,"sourceType":"datasetVersion","datasetId":6503077},{"sourceId":10506804,"sourceType":"datasetVersion","datasetId":6504535},{"sourceId":10506838,"sourceType":"datasetVersion","datasetId":6504563},{"sourceId":10511964,"sourceType":"datasetVersion","datasetId":6506934},{"sourceId":10512215,"sourceType":"datasetVersion","datasetId":6507042},{"sourceId":10516725,"sourceType":"datasetVersion","datasetId":6509658},{"sourceId":10516731,"sourceType":"datasetVersion","datasetId":6509663},{"sourceId":10524637,"sourceType":"datasetVersion","datasetId":6513754},{"sourceId":10524950,"sourceType":"datasetVersion","datasetId":6513947},{"sourceId":10536804,"sourceType":"datasetVersion","datasetId":6519603},{"sourceId":10537309,"sourceType":"datasetVersion","datasetId":6519916},{"sourceId":10560529,"sourceType":"datasetVersion","datasetId":6533803},{"sourceId":10566113,"sourceType":"datasetVersion","datasetId":6538258},{"sourceId":10566172,"sourceType":"datasetVersion","datasetId":6538298},{"sourceId":10568425,"sourceType":"datasetVersion","datasetId":6539704},{"sourceId":10568434,"sourceType":"datasetVersion","datasetId":6539712},{"sourceId":10586989,"sourceType":"datasetVersion","datasetId":6552051},{"sourceId":10607541,"sourceType":"datasetVersion","datasetId":6566451},{"sourceId":10607811,"sourceType":"datasetVersion","datasetId":6566575}],"dockerImageVersionId":30823,"isInternetEnabled":false,"language":"python","sourceType":"notebook","isGpuEnabled":true}},"nbformat_minor":4,"nbformat":4,"cells":[{"cell_type":"markdown","source":"# Baseline UNet training + prediction/submission\n\n\nThis is the notebook I cobbled together to wrap my head around this challenge.\nI don't garuantee that the results are great, only that it works from end-to-end. \n\nIt trains a basic UNet and makes a submission. \n\nIt's based on these three notebooks: \n\n1. [3D U-Net : Training Only](https://www.kaggle.com/code/ahsuna123/3d-u-net-training-only)\n2. 
[3D U-Net PyTorch Lightning distributed training](https://www.kaggle.com/code/zhuowenzhao11/3d-u-net-pytorch-lightning-distributed-training)\n3. [3d-unet using 2d image encoder](https://www.kaggle.com/code/hengck23/3d-unet-using-2d-image-encoder/notebook)\n\n\nI've pre-computed the input data and stored it as numpy arrays so it doesn't have to be extracted every time the notebook is run. ","metadata":{}},{"cell_type":"markdown","source":"## Installing offline deps\n\nAs this is a code competition, there is no internet access. \nSo we have to do some silly things to get dependencies in here. \nWhy is asciitree such a PITA? ","metadata":{}},{"cell_type":"code","source":"deps_path = '/kaggle/input/czii-cryoet-dependencies'","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:40:53.768588Z","iopub.execute_input":"2025-02-01T13:40:53.768799Z","iopub.status.idle":"2025-02-01T13:40:53.772658Z","shell.execute_reply.started":"2025-02-01T13:40:53.768778Z","shell.execute_reply":"2025-02-01T13:40:53.771842Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"! cp -r /kaggle/input/czii-cryoet-dependencies/asciitree-0.3.3/ asciitree-0.3.3/","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:40:53.773657Z","iopub.execute_input":"2025-02-01T13:40:53.773932Z","iopub.status.idle":"2025-02-01T13:40:53.957223Z","shell.execute_reply.started":"2025-02-01T13:40:53.773904Z","shell.execute_reply":"2025-02-01T13:40:53.956228Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"! pip wheel asciitree-0.3.3/asciitree-0.3.3/","metadata":{"_kg_hide-output":true,"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:40:53.958423Z","iopub.execute_input":"2025-02-01T13:40:53.958769Z","iopub.status.idle":"2025-02-01T13:40:57.223787Z","shell.execute_reply.started":"2025-02-01T13:40:53.958737Z","shell.execute_reply":"2025-02-01T13:40:57.222829Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!pip install asciitree-0.3.3-py3-none-any.whl","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:40:57.224785Z","iopub.execute_input":"2025-02-01T13:40:57.225006Z","iopub.status.idle":"2025-02-01T13:41:01.186637Z","shell.execute_reply.started":"2025-02-01T13:40:57.224987Z","shell.execute_reply":"2025-02-01T13:41:01.185621Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"! 
pip install -q --no-index --find-links {deps_path} --requirement {deps_path}/requirements.txt","metadata":{"_kg_hide-output":true,"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:01.187666Z","iopub.execute_input":"2025-02-01T13:41:01.188011Z","iopub.status.idle":"2025-02-01T13:41:10.086363Z","shell.execute_reply.started":"2025-02-01T13:41:01.187978Z","shell.execute_reply":"2025-02-01T13:41:10.085552Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from typing import List, Tuple, Union\nimport numpy as np\nimport torch\nfrom monai.data import DataLoader, Dataset, CacheDataset, decollate_batch\nfrom monai.transforms import (\n    Compose,\n    EnsureChannelFirstd,\n    Orientationd,\n    RandFlipd,\n    RandRotate90d,\n    RandAffined,\n    RandGaussianNoised,\n    RandGaussianSmoothd,\n    RandScaleIntensityd,\n    RandShiftIntensityd,\n    RandAdjustContrastd,\n    RandHistogramShiftd,\n    RandCropByLabelClassesd,\n    NormalizeIntensityd,\n    RandZoomd,\n    AsDiscrete,\n)\nfrom monai.losses import (\n    DiceLoss,\n    DiceFocalLoss,\n    DiceCELoss,\n    TverskyLoss,\n    GeneralizedDiceLoss,\n    FocalLoss,\n)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:10.089175Z","iopub.execute_input":"2025-02-01T13:41:10.089397Z","iopub.status.idle":"2025-02-01T13:41:28.306871Z","shell.execute_reply.started":"2025-02-01T13:41:10.089377Z","shell.execute_reply":"2025-02-01T13:41:28.306190Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import os\nos.environ[\"JAX_PLATFORM_NAME\"] = \"cpu\"  # force JAX to run in CPU mode","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:28.308102Z","iopub.execute_input":"2025-02-01T13:41:28.308834Z","iopub.status.idle":"2025-02-01T13:41:28.312337Z","shell.execute_reply.started":"2025-02-01T13:41:28.308810Z","shell.execute_reply":"2025-02-01T13:41:28.311602Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Define some helper functions\n\n\n### Patching helper functions\n\nThese are mostly used to split large volumes into smaller ones and stitch them back together. 
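\n\nFor intuition, here's roughly how the two helpers below fit together (a sketch, not run in this notebook; `volume` stands in for any 3D numpy array and `pred_patches` for the per-patch model outputs):\n\n```python\npatches, coords = extract_3d_patches_minimal_overlap([volume], patch_size=96)\n# ... run the model on each patch to get pred_patches ...\nmask = reconstruct_array(pred_patches, coords, original_shape=volume.shape)\n```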
","metadata":{}},{"cell_type":"code","source":"def calculate_patch_starts(dimension_size: int, patch_size: int) -> List[int]:\n    \"\"\"\n    Calculate the starting positions of patches along a single dimension\n    with minimal overlap to cover the entire dimension.\n    \n    Parameters:\n    -----------\n    dimension_size : int\n        Size of the dimension\n    patch_size : int\n        Size of the patch in this dimension\n        \n    Returns:\n    --------\n    List[int]\n        List of starting positions for patches\n    \"\"\"\n    if dimension_size <= patch_size:\n        return [0]\n        \n    # Calculate number of patches needed\n    n_patches = np.ceil(dimension_size / patch_size)\n    \n    if n_patches == 1:\n        return [0]\n    \n    # Calculate overlap\n    total_overlap = (n_patches * patch_size - dimension_size) / (n_patches - 1)\n    \n    # Generate starting positions\n    positions = []\n    for i in range(int(n_patches)):\n        pos = int(i * (patch_size - total_overlap))\n        if pos + patch_size > dimension_size:\n            pos = dimension_size - patch_size\n        if pos not in positions:  # Avoid duplicates\n            positions.append(pos)\n    \n    return positions\n\ndef extract_3d_patches_minimal_overlap(arrays: List[np.ndarray], patch_size: int) -> Tuple[List[np.ndarray], List[Tuple[int, int, int]]]:\n    \"\"\"\n    Extract 3D patches from multiple arrays with minimal overlap to cover the entire array.\n    \n    Parameters:\n    -----------\n    arrays : List[np.ndarray]\n        List of input arrays, each with shape (m, n, l)\n    patch_size : int\n        Size of cubic patches (a x a x a)\n        \n    Returns:\n    --------\n    patches : List[np.ndarray]\n        List of all patches from all input arrays\n    coordinates : List[Tuple[int, int, int]]\n        List of starting coordinates (x, y, z) for each patch\n    \"\"\"\n    if not arrays or not isinstance(arrays, list):\n        raise ValueError(\"Input must be a non-empty list of arrays\")\n    \n    # Verify all arrays have the same shape\n    shape = arrays[0].shape\n    if not all(arr.shape == shape for arr in arrays):\n        raise ValueError(\"All input arrays must have the same shape\")\n    \n    if patch_size > min(shape):\n        raise ValueError(f\"patch_size ({patch_size}) must be smaller than smallest dimension {min(shape)}\")\n    \n    m, n, l = shape\n    patches = []\n    coordinates = []\n    \n    # Calculate starting positions for each dimension\n    x_starts = calculate_patch_starts(m, patch_size)\n    y_starts = calculate_patch_starts(n, patch_size)\n    z_starts = calculate_patch_starts(l, patch_size)\n    \n    # Extract patches from each array\n    for arr in arrays:\n        for x in x_starts:\n            for y in y_starts:\n                for z in z_starts:\n                    patch = arr[\n                        x:x + patch_size,\n                        y:y + patch_size,\n                        z:z + patch_size\n                    ]\n                    patches.append(patch)\n                    coordinates.append((x, y, z))\n    \n    return patches, coordinates\n\n# Note: I should probably averge the overlapping areas, \n# but here they are just overwritten by the most recent one. 
\n\ndef reconstruct_array(patches: List[np.ndarray], \n                     coordinates: List[Tuple[int, int, int]], \n                     original_shape: Tuple[int, int, int]) -> np.ndarray:\n    \"\"\"\n    Reconstruct array from patches.\n    \n    Parameters:\n    -----------\n    patches : List[np.ndarray]\n        List of patches to reconstruct from\n    coordinates : List[Tuple[int, int, int]]\n        Starting coordinates for each patch\n    original_shape : Tuple[int, int, int]\n        Shape of the original array\n        \n    Returns:\n    --------\n    np.ndarray\n        Reconstructed array\n    \"\"\"\n    reconstructed = np.zeros(original_shape, dtype=np.int64)  # int64 because the patches hold integer class labels\n    \n    patch_size = patches[0].shape[0]\n    \n    for patch, (x, y, z) in zip(patches, coordinates):\n        reconstructed[\n            x:x + patch_size,\n            y:y + patch_size,\n            z:z + patch_size\n        ] = patch\n\n    return reconstructed","metadata":{"_kg_hide-input":true,"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:28.313141Z","iopub.execute_input":"2025-02-01T13:41:28.313765Z","iopub.status.idle":"2025-02-01T13:41:28.335506Z","shell.execute_reply.started":"2025-02-01T13:41:28.313733Z","shell.execute_reply":"2025-02-01T13:41:28.334829Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Submission helper functions\n\nThese help with getting the submission into the correct format","metadata":{}},{"cell_type":"code","source":"import pandas as pd\n\ndef dict_to_df(coord_dict, experiment_name):\n    \"\"\"\n    Convert dictionary of coordinates to pandas DataFrame.\n    \n    Parameters:\n    -----------\n    coord_dict : dict\n        Dictionary where keys are labels and values are Nx3 coordinate arrays\n    experiment_name : str\n        Name of the experiment (run) the coordinates belong to\n        \n    Returns:\n    --------\n    pd.DataFrame\n        DataFrame with columns ['experiment', 'particle_type', 'x', 'y', 'z']\n    \"\"\"\n    # Create lists to store data\n    all_coords = []\n    all_labels = []\n    \n    # Process each label and its coordinates\n    for label, coords in coord_dict.items():\n        all_coords.append(coords)\n        all_labels.extend([label] * len(coords))\n    \n    # Concatenate all coordinates\n    all_coords = np.vstack(all_coords)\n    \n    df = pd.DataFrame({\n        'experiment': experiment_name,\n        'particle_type': all_labels,\n        'x': all_coords[:, 0],\n        'y': all_coords[:, 1],\n        'z': all_coords[:, 2]\n    })\n\n    return df","metadata":{"_kg_hide-input":true,"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:28.336294Z","iopub.execute_input":"2025-02-01T13:41:28.336591Z","iopub.status.idle":"2025-02-01T13:41:28.358382Z","shell.execute_reply.started":"2025-02-01T13:41:28.336563Z","shell.execute_reply":"2025-02-01T13:41:28.357690Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Reading in the data","metadata":{}},{"cell_type":"code","source":"TRAIN_DATA_DIR = \"/kaggle/input/czii-numpy-dataset-20250107\"\nTEST_DATA_DIR = \"/kaggle/input/czii-cryo-et-object-identification\"","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:28.359136Z","iopub.execute_input":"2025-02-01T13:41:28.359460Z","iopub.status.idle":"2025-02-01T13:41:28.376204Z","shell.execute_reply.started":"2025-02-01T13:41:28.359413Z","shell.execute_reply":"2025-02-01T13:41:28.375622Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Initialize the model\n\nThis model is pretty 
much directly copied from [3D U-Net PyTorch Lightning distributed training](https://www.kaggle.com/code/zhuowenzhao11/3d-u-net-pytorch-lightning-distributed-training)","metadata":{}},{"cell_type":"code","source":"import pytorch_lightning as pl\n\nfrom monai.networks.nets import UNet\nfrom monai.metrics import DiceMetric\nfrom torch.optim.lr_scheduler import (\n    CosineAnnealingWarmRestarts,\n    OneCycleLR,\n    ReduceLROnPlateau\n)\n\nclass Model(pl.LightningModule):\n    def __init__(self, spatial_dims=3, in_channels=1, out_channels=7,\n                 channels=(48, 64, 80, 80), strides=(2, 2, 1),\n                 num_res_units=1, lr=1e-3,\n                 scheduler_type='one_cycle'):\n        super().__init__()\n        self.save_hyperparameters()\n\n        # Model\n        self.model = UNet(\n            spatial_dims=self.hparams.spatial_dims,\n            in_channels=self.hparams.in_channels,\n            out_channels=self.hparams.out_channels,\n            channels=self.hparams.channels,\n            strides=self.hparams.strides,\n            num_res_units=self.hparams.num_res_units,\n            norm='batch',  # BatchNorm3d\n            dropout=0.2,\n        )\n\n        # Loss function\n        self.loss_fn = TverskyLoss(\n            include_background=True,\n            to_onehot_y=True,\n            softmax=True,\n            alpha=0.5,\n            beta=0.95\n        )\n\n        # Metric\n        self.metric_fn = DiceMetric(\n            include_background=False,\n            reduction=\"mean\",\n            get_not_nans=False\n        )\n\n        # Learning rate and scheduler settings\n        self.lr = lr\n        self.scheduler_type = scheduler_type\n\n        # Lists for collecting per-step outputs\n        self.training_step_outputs = []\n        self.validation_step_outputs = []\n\n        # Per-class metric weights\n        # NOTE: this ordering has to match the label encoding the checkpoints\n        # were trained with; it is not the alphabetical id_to_name mapping\n        # used at inference time further below.\n        self.class_weights = torch.tensor([1.0, 1.0, 1.0, 2.0, 2.0, 0.0])\n\n        # Storage for validation outputs\n        self.validation_outputs = []\n\n    def forward(self, x):\n        return self.model(x)\n\n    def validation_step(self, batch, batch_idx):\n        x, y = batch['image'], batch['label']\n        y_hat = self(x)\n        val_loss = self.loss_fn(y_hat, y)\n\n        metric_val_outputs = [AsDiscrete(argmax=True, to_onehot=self.hparams.out_channels)(i)\n                             for i in decollate_batch(y_hat)]\n        metric_val_labels = [AsDiscrete(to_onehot=self.hparams.out_channels)(i)\n                            for i in decollate_batch(y)]\n\n        self.metric_fn(y_pred=metric_val_outputs, y=metric_val_labels)\n        metrics = self.metric_fn.aggregate(reduction=\"mean_batch\")\n\n        # Move the per-class weights to the metric's device\n        class_weights = self.class_weights.to(metrics.device)\n\n        # Weighted overall metric\n        weighted_metric = (metrics * class_weights).sum() / class_weights.sum()\n\n        # Logging\n        self.log('val_loss', val_loss, on_step=False, on_epoch=True)\n        self.log('val_metric', weighted_metric, on_step=False, on_epoch=True)\n\n        output = {\n            'val_loss': val_loss.detach(),\n            'val_metric': weighted_metric.detach(),\n            'class_metrics': metrics.detach()\n        }\n\n        self.validation_outputs.append(output)\n\n        # Per-batch progress printout\n        print(f\"\\nEpoch {self.current_epoch}, Validation batch {batch_idx}\")\n        print(f\"Loss: {val_loss:.4f}, Metric: {weighted_metric:.4f}\")\n\n        return output\n\n    def training_step(self, batch, batch_idx):\n        x, y = batch['image'], batch['label']\n        y_hat = self(x)\n        loss = self.loss_fn(y_hat, y)\n\n        # Log the epoch-level training loss\n        self.log(\"train_loss\", loss, on_step=False, on_epoch=True)\n\n        print(f\"Epoch {self.current_epoch}, Training batch {batch_idx}, Loss: {loss.item():.4f}\")\n\n        self.training_step_outputs.append(loss)\n        return loss\n\n    def on_train_epoch_end(self):\n        epoch_mean = torch.stack(self.training_step_outputs).mean()\n        print(f\"\\n{'='*40}\")\n        print(f\"Epoch {self.current_epoch} Training completed\")\n        print(f\"Average training loss: {epoch_mean:.4f}\")\n        print(f\"{'='*40}\\n\")\n        self.training_step_outputs.clear()\n\n    def on_validation_epoch_start(self):\n        self.validation_outputs = []\n\n    def on_validation_epoch_end(self):\n        if not self.validation_outputs:\n            print(\"No validation outputs found!\")\n            return\n\n        try:\n            # Average over the validation batches\n            avg_loss = torch.stack([x['val_loss'] for x in self.validation_outputs]).mean()\n            avg_metric = torch.stack([x['val_metric'] for x in self.validation_outputs]).mean()\n            class_metrics = torch.stack([x['class_metrics'] for x in self.validation_outputs]).mean(dim=0)\n\n            print(f\"\\n{'='*70}\")\n            print(f\"Validation Epoch {self.current_epoch} Summary\")\n            print(f\"{'='*70}\")\n            print(f\"Average Loss: {avg_loss:.4f}\")\n            print(f\"Average Weighted Metric: {avg_metric:.4f}\")\n            print(\"\\nClass-wise Performance:\")\n            # Assumed to follow the training label order (see the class_weights note)\n            class_names = ['Ribosome', 'Virus-like', 'Apo-ferritin',\n                          'Thyroglobulin (Hard)', 'β-galactosidase (Hard)', 'Beta-amylase (Not evaluated)']\n\n            for name, metric in zip(class_names, class_metrics):\n                print(f\"  {name:<20} {metric:.4f}\")\n            print(f\"{'='*70}\\n\")\n\n        except Exception as e:\n            print(f\"Error in validation epoch end: {str(e)}\")\n        finally:\n            # Reset the metric state\n            self.metric_fn.reset()\n            self.validation_outputs = []\n\n    def configure_optimizers(self):\n        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr)\n\n        if self.scheduler_type == 'one_cycle':\n            # Before the dataloader is attached, fall back to a cosine scheduler\n            if not hasattr(self.trainer, 'train_dataloader') or self.trainer.train_dataloader is None:\n                print(\"Warning: train_dataloader not set, switching to cosine scheduler\")\n                scheduler = CosineAnnealingWarmRestarts(\n                    optimizer,\n                    T_0=10,\n                    T_mult=2,\n                    eta_min=1e-6\n                )\n                scheduler_config = {\n                    \"scheduler\": scheduler,\n                    \"interval\": \"epoch\",\n                    \"frequency\": 1\n                }\n            else:\n                steps_per_epoch = len(self.trainer.train_dataloader())\n                total_steps = steps_per_epoch * self.trainer.max_epochs\n\n                scheduler = OneCycleLR(\n                    optimizer,\n                    max_lr=self.lr,\n                    total_steps=total_steps,\n                    pct_start=0.3,\n                    div_factor=25.0,\n                    final_div_factor=1e4\n                )\n                scheduler_config = {\n                    \"scheduler\": scheduler,\n                    \"interval\": \"step\",\n                    \"frequency\": 1\n                }\n\n        elif self.scheduler_type == 'plateau':\n            scheduler = ReduceLROnPlateau(\n                optimizer,\n                mode='max',\n                factor=0.5,\n                patience=100,\n                min_lr=1e-6,\n                verbose=True\n            )\n            scheduler_config = {\n                \"scheduler\": scheduler,\n                \"interval\": \"epoch\",\n                \"monitor\": 'val_metric',\n                \"frequency\": 1\n            }\n\n        else:  # cosine as default\n            scheduler = CosineAnnealingWarmRestarts(\n                optimizer,\n                T_0=10,\n                T_mult=2,\n                eta_min=1e-6\n            )\n            scheduler_config = {\n                \"scheduler\": scheduler,\n                \"interval\": \"epoch\",\n                \"frequency\": 1\n            }\n\n        return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler_config}","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:28.377182Z","iopub.execute_input":"2025-02-01T13:41:28.377473Z","iopub.status.idle":"2025-02-01T13:41:29.498549Z","shell.execute_reply.started":"2025-02-01T13:41:28.377445Z","shell.execute_reply":"2025-02-01T13:41:29.497655Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"channels = (64, 128, 256, 256)\nstrides_pattern = (2, 2, 1)\nnum_res_units = 1\nlearning_rate = 1e-3\nnum_epochs = 1000\n\nmodel = Model(channels=channels, strides=strides_pattern, num_res_units=num_res_units, lr=learning_rate)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:29.499320Z","iopub.execute_input":"2025-02-01T13:41:29.499570Z","iopub.status.idle":"2025-02-01T13:41:29.585832Z","shell.execute_reply.started":"2025-02-01T13:41:29.499548Z","shell.execute_reply":"2025-02-01T13:41:29.584958Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Train the model\n\n","metadata":{}},{"cell_type":"markdown","source":"Let there be gradients!\n\nLocally this config seems to train for about 1000 steps before the model starts overfitting. 
","metadata":{}},{"cell_type":"code","source":"# # 체크포인트 로드\n# (2) 체크포인트 로드\nckpt_paths = [\n    \"/kaggle/input/20250122-v34-7fold/fold0.ckpt\",\n    \"/kaggle/input/20250122-v34-7fold/fold1.ckpt\",\n    \"/kaggle/input/20250122-v34-7fold/fold2.ckpt\",\n]\n\n# 여러 모델을 불러와 리스트에 저장\nmodels_ensemble = []\nfor cp in ckpt_paths:\n    m = Model.load_from_checkpoint(cp)\n    m.eval()\n    m.to(\"cuda\")\n    models_ensemble.append(m)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:29.586704Z","iopub.execute_input":"2025-02-01T13:41:29.586972Z","iopub.status.idle":"2025-02-01T13:41:33.325318Z","shell.execute_reply.started":"2025-02-01T13:41:29.586942Z","shell.execute_reply":"2025-02-01T13:41:33.324691Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# # 학습 시작 전에 print 문 추가\n# print(\"Starting training...\")\n# trainer.fit(model, train_loader, valid_loader)\n# print(\"Training completed!\")\n# torch.save(model.state_dict(), 'final_model.pth')\n\n# # pth\n# model.load_state_dict(torch.load('/kaggle/input/20250115-cz/final_model.pth'))\n\n# # 모델을 평가 모드로 설정 (테스트 또는 추론 시)\n# model.eval()","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:33.326110Z","iopub.execute_input":"2025-02-01T13:41:33.326315Z","iopub.status.idle":"2025-02-01T13:41:33.329542Z","shell.execute_reply.started":"2025-02-01T13:41:33.326297Z","shell.execute_reply":"2025-02-01T13:41:33.328845Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"## Predict on the test set\n\n","metadata":{}},{"cell_type":"code","source":"import json\ncopick_config_path = TRAIN_DATA_DIR + \"/copick.config\"\n\nwith open(copick_config_path) as f:\n    copick_config = json.load(f)\n\ncopick_config['static_root'] = '/kaggle/input/czii-cryo-et-object-identification/test/static'\n\ncopick_test_config_path = 'copick_test.config'\n\nwith open(copick_test_config_path, 'w') as outfile:\n    json.dump(copick_config, outfile)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:33.330571Z","iopub.execute_input":"2025-02-01T13:41:33.330948Z","iopub.status.idle":"2025-02-01T13:41:33.359582Z","shell.execute_reply.started":"2025-02-01T13:41:33.330919Z","shell.execute_reply":"2025-02-01T13:41:33.358797Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import copick\n\nroot = copick.from_file(copick_test_config_path)\n\ncopick_user_name = \"copickUtils\"\ncopick_segmentation_name = \"paintedPicks\"\nvoxel_size = 10\ntomo_type = \"denoised\"","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:33.360414Z","iopub.execute_input":"2025-02-01T13:41:33.360669Z","iopub.status.idle":"2025-02-01T13:41:34.228203Z","shell.execute_reply.started":"2025-02-01T13:41:33.360639Z","shell.execute_reply":"2025-02-01T13:41:34.227356Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"# Non-random transforms to be cached\ninference_transforms = Compose([\n    EnsureChannelFirstd(keys=[\"image\"], channel_dim=\"no_channel\"),\n    NormalizeIntensityd(keys=\"image\"),\n    Orientationd(keys=[\"image\"], axcodes=\"RAS\")\n])","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:34.228975Z","iopub.execute_input":"2025-02-01T13:41:34.229827Z","iopub.status.idle":"2025-02-01T13:41:34.234083Z","shell.execute_reply.started":"2025-02-01T13:41:34.229801Z","shell.execute_reply":"2025-02-01T13:41:34.233160Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import 
cc3d\n\nid_to_name = {1: \"apo-ferritin\", \n              2: \"beta-amylase\",\n              3: \"beta-galactosidase\", \n              4: \"ribosome\", \n              5: \"thyroglobulin\", \n              6: \"virus-like-particle\"}","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:34.235015Z","iopub.execute_input":"2025-02-01T13:41:34.235281Z","iopub.status.idle":"2025-02-01T13:41:34.268334Z","shell.execute_reply.started":"2025-02-01T13:41:34.235250Z","shell.execute_reply":"2025-02-01T13:41:34.267420Z"}},"outputs":[],"execution_count":null},{"cell_type":"markdown","source":"### Iterate over the test set\n\n\nBelow we will: \n1. Read in a run\n2. Normalize and reorient it with the cached inference transforms\n3. Run sliding-window inference (128x128x128 windows) with the model ensemble + TTA\n4. Softmax + argmax the stitched logits into a segmentation mask\n5. Find the connected components for each class\n6. Find the centroids of the connected components\n7. Add them to the dataframe\n\nThen do this for all runs. \n\nThis can probably be optimized quite a bit. ","metadata":{}},{"cell_type":"code","source":"import numpy as np\nimport torch\nimport cc3d\nimport time  # for timing\n\nfrom monai.inferers import sliding_window_inference\nfrom monai.data import CacheDataset\n\n# Simple 3D flip helpers\ndef flip_x_3d(tensor: torch.Tensor) -> torch.Tensor:\n    return torch.flip(tensor, dims=[2])  # flip along x\n\ndef flip_y_3d(tensor: torch.Tensor) -> torch.Tensor:\n    return torch.flip(tensor, dims=[3])  # flip along y\n\ndef flip_z_3d(tensor: torch.Tensor) -> torch.Tensor:\n    return torch.flip(tensor, dims=[4])  # flip along z\n\ndef identity_3d(tensor: torch.Tensor) -> torch.Tensor:\n    return tensor\n\n# List of (forward_transform, inverse_transform) pairs\ntta_ops = [\n    (identity_3d, identity_3d),\n    (flip_z_3d, flip_z_3d),\n]\n\ndef ensemble_tta_predictor(sub_volume: torch.Tensor) -> torch.Tensor:\n    \"\"\"\n    sub_volume : a 3D sub-volume tensor of shape (B=1, C=1, D, H, W).\n    For every model in models_ensemble, run each TTA pair\n    (fwd_op -> model inference -> inv_op) and return the mean of all logits.\n\n    Returns : (B=1, out_channels=7, D, H, W)\n    \"\"\"\n    all_logits = []\n\n    with torch.no_grad():\n        # (1) iterate over the ensemble models\n        for model in models_ensemble:\n            # (2) iterate over the TTA transforms\n            for fwd_op, inv_op in tta_ops:\n                # 1) forward transform\n                vol_t = fwd_op(sub_volume)          # (1,1,D,H,W)\n                # 2) model inference (logits)\n                logits = model(vol_t)               # (1,7,D,H,W)\n                # 3) invert back to the original coordinates\n                logits_inv = inv_op(logits)         # (1,7,D,H,W)\n\n                # collect\n                all_logits.append(logits_inv)\n\n    # Average over models x TTA variants -> (1,7,D,H,W)\n    final_logits = torch.mean(torch.stack(all_logits, dim=0), dim=0)\n    return final_logits\n\n# -----------------------------------------------------------\n# Sliding window + ensemble + TTA inference (no post-processing)\n# -----------------------------------------------------------\nBLOB_THRESHOLD = 250\nclasses = [1, 2, 3, 4, 5, 6]\n\n# Start timing the whole pipeline\ntotal_start = time.time()\n\nwith torch.no_grad():\n    location_df = []\n\n    for run in root.runs:\n        run_start = time.time()  # start time for this run\n\n        print(run)\n\n        # 1) Load the volume (10 Å voxels)\n        load_start = time.time()\n        tomo = run.get_voxel_spacing(10)\n        tomo_arr = tomo.get_tomogram(tomo_type).numpy()  # shape: (X, Y, Z)\n        load_end = time.time()\n        print(f\"[Timer] Volume load time: {load_end - load_start:.3f} sec\")\n\n        # 2) Dataset prep (cached preprocessing)\n        
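# (My note: CacheDataset applies the deterministic inference_transforms\n        # (EnsureChannelFirstd, NormalizeIntensityd, Orientationd) once up\n        # front, so the sliding-window loop below sees preprocessed data.)\n        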
prep_start = time.time()\n        data_dict = [{\"image\": tomo_arr}]\n        tomo_ds = CacheDataset(data=data_dict, transform=inference_transforms, cache_rate=1.0)\n        volume_tensor = tomo_ds[0][\"image\"].unsqueeze(0).to(\"cuda\")  # (1,1,X,Y,Z)\n        prep_end = time.time()\n        print(f\"[Timer] Dataset prep time: {prep_end - prep_start:.3f} sec\")\n\n        # 3) Sliding Window Inference\n        #    -> predictor=ensemble_tta_predictor plugs in the ensemble + TTA\n        infer_start = time.time()\n        out_logits = sliding_window_inference(\n            inputs=volume_tensor,\n            roi_size=(128, 128, 128),\n            sw_batch_size=6,\n            predictor=ensemble_tta_predictor,  # <-- ensemble + TTA happens here\n            overlap=0.25,\n            mode=\"gaussian\"\n        )\n        infer_end = time.time()\n        print(f\"[Timer] SW Inference(Ensemble+TTA) time: {infer_end - infer_start:.3f} sec\")\n\n        # 4) Softmax, then argmax\n        post_start = time.time()\n        out_probs = torch.softmax(out_logits, dim=1)  # (1,7,X,Y,Z)\n        out_probs_np = out_probs[0].cpu().numpy()     # (7, X, Y, Z)\n        reconstructed_mask = np.argmax(out_probs_np, axis=0)  # (X, Y, Z)\n        post_end = time.time()\n        print(f\"[Timer] Postprocess(softmax+argmax) time: {post_end - post_start:.3f} sec\")\n\n        # 5) Connected components per label -> centroid extraction\n        cc_start = time.time()\n        location = {}\n        for c in classes:\n            cc = cc3d.connected_components(reconstructed_mask == c)\n            stats = cc3d.statistics(cc)\n            \n            # label 0 is background; real objects start at [1:]\n            zyx = stats[\"centroids\"][1:] * 10.012444  # scale voxel indices by the voxel size\n            zyx_large = zyx[stats[\"voxel_counts\"][1:] > BLOB_THRESHOLD]\n            xyz = np.ascontiguousarray(zyx_large[:, ::-1])\n            location[id_to_name[c]] = xyz\n        cc_end = time.time()\n        print(f\"[Timer] Connected components + centroids time: {cc_end - cc_start:.3f} sec\")\n\n        # 6) Convert to a DataFrame and store\n        df = dict_to_df(location, run.name)\n        location_df.append(df)\n\n        run_end = time.time()\n        print(f\"[Timer] Single run total time: {run_end - run_start:.3f} sec\\n\")\n\n    # Combine the results from all runs\n    location_df = pd.concat(location_df)\n\n# Stop timing the whole pipeline\ntotal_end = time.time()\nprint(f\"Total pipeline time: {total_end - total_start:.3f} sec\")\n# Extrapolate from the 3 local runs to the ~500 runs in the hidden test set\nprint(f'estimated predict time is {(total_end - total_start)/3*500:.4f} seconds')","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:41:34.269314Z","iopub.execute_input":"2025-02-01T13:41:34.270270Z","iopub.status.idle":"2025-02-01T13:44:41.763013Z","shell.execute_reply.started":"2025-02-01T13:41:34.270235Z","shell.execute_reply":"2025-02-01T13:44:41.762228Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"location_df.insert(loc=0, column='id', value=np.arange(len(location_df)))\nlocation_df.to_csv(\"submission.csv\", 
index=False)","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:44:41.765562Z","iopub.execute_input":"2025-02-01T13:44:41.765800Z","iopub.status.idle":"2025-02-01T13:44:41.783409Z","shell.execute_reply.started":"2025-02-01T13:44:41.765780Z","shell.execute_reply":"2025-02-01T13:44:41.782594Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!ls","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:44:41.784374Z","iopub.execute_input":"2025-02-01T13:44:41.784618Z","iopub.status.idle":"2025-02-01T13:44:42.006936Z","shell.execute_reply.started":"2025-02-01T13:44:41.784587Z","shell.execute_reply":"2025-02-01T13:44:42.005960Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"!cp -r /kaggle/input/hengck-czii-cryo-et-01/* .","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:44:42.008218Z","iopub.execute_input":"2025-02-01T13:44:42.008574Z","iopub.status.idle":"2025-02-01T13:44:42.925997Z","shell.execute_reply.started":"2025-02-01T13:44:42.008540Z","shell.execute_reply":"2025-02-01T13:44:42.924976Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"from czii_helper import *\nfrom dataset import *\nfrom scipy.optimize import linear_sum_assignment\nimport matplotlib.pyplot as plt","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:44:42.926972Z","iopub.execute_input":"2025-02-01T13:44:42.927206Z","iopub.status.idle":"2025-02-01T13:44:42.934161Z","shell.execute_reply.started":"2025-02-01T13:44:42.927186Z","shell.execute_reply":"2025-02-01T13:44:42.933506Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"import os\nif os.getenv('KAGGLE_IS_COMPETITION_RERUN'):\n    MODE = 'submit'\nelse:\n    MODE = 'local'\n\n\nvalid_dir = '/kaggle/input/czii-cryo-et-object-identification/train'\nvalid_id = ['TS_6_4', ]\n\ndef do_one_eval(truth, predict, threshold):\n    P=len(predict)\n    T=len(truth)\n\n    if P==0:\n        hit=[[],[]]\n        miss=np.arange(T).tolist()\n        fp=[]\n        metric = [P,T,len(hit[0]),len(miss),len(fp)]\n        return hit, fp, miss, metric\n\n    if T==0:\n        hit=[[],[]]\n        fp=np.arange(P).tolist()\n        miss=[]\n        metric = [P,T,len(hit[0]),len(miss),len(fp)]\n        return hit, fp, miss, metric\n\n    #---\n    distance = predict.reshape(P,1,3)-truth.reshape(1,T,3)\n    distance = distance**2\n    distance = distance.sum(axis=2)\n    distance = np.sqrt(distance)\n    p_index, t_index = linear_sum_assignment(distance)\n\n    valid = distance[p_index, t_index] <= threshold\n    p_index = p_index[valid]\n    t_index = t_index[valid]\n    hit = [p_index.tolist(), t_index.tolist()]\n    miss = np.arange(T)\n    miss = miss[~np.isin(miss,t_index)].tolist()\n    fp = np.arange(P)\n    fp = fp[~np.isin(fp,p_index)].tolist()\n\n    metric = [P,T,len(hit[0]),len(miss),len(fp)]  # for LB metric F-beta computation\n    return hit, fp, miss, metric\n\n\ndef compute_lb(submit_df, overlay_dir):\n    valid_id = list(submit_df['experiment'].unique())\n    print(valid_id)\n\n    eval_df = []\n    for id in valid_id:\n        truth = read_one_truth(id, overlay_dir)  # = f'{valid_dir}/overlay/ExperimentRuns'\n        id_df = submit_df[submit_df['experiment'] == id]\n        for p in PARTICLE:\n            p = dotdict(p)\n            print('\\r', id, p.name, end='', flush=True)\n            xyz_truth = truth[p.name]\n            xyz_predict = id_df[id_df['particle_type'] == p.name][['x', 'y', 
'z']].values\n            hit, fp, miss, metric = do_one_eval(xyz_truth, xyz_predict, p.radius * 0.5)\n            eval_df.append(dotdict(\n                id=id, particle_type=p.name,\n                P=metric[0], T=metric[1], hit=metric[2], miss=metric[3], fp=metric[4],\n            ))\n    print('')\n    eval_df = pd.DataFrame(eval_df)\n    gb = eval_df.groupby('particle_type').agg('sum').drop(columns=['id'])\n    gb.loc[:, 'precision'] = gb['hit'] / gb['P']\n    gb.loc[:, 'precision'] = gb['precision'].fillna(0)\n    gb.loc[:, 'recall'] = gb['hit'] / gb['T']\n    gb.loc[:, 'recall'] = gb['recall'].fillna(0)\n    # F-beta with beta=4: (1 + beta^2) * P * R / (beta^2 * P + R) = 17*P*R / (16*P + R)\n    gb.loc[:, 'f-beta4'] = 17 * gb['precision'] * gb['recall'] / (16 * gb['precision'] + gb['recall'])\n    gb.loc[:, 'f-beta4'] = gb['f-beta4'].fillna(0)\n\n    gb = gb.sort_values('particle_type').reset_index(drop=False)\n    # https://www.kaggle.com/competitions/czii-cryo-et-object-identification/discussion/544895\n    gb.loc[:, 'weight'] = [1, 0, 2, 1, 2, 1]\n    lb_score = (gb['f-beta4'] * gb['weight']).sum() / gb['weight'].sum()\n    return gb, lb_score\n\n\n# debug\nif 1:\n    if MODE=='local':\n    #if 1:\n        submit_df = pd.read_csv(\n           'submission.csv'\n            # '/kaggle/input/hengck-czii-cryo-et-weights-01/submission.csv'\n        )\n        gb, lb_score = compute_lb(submit_df, f'{valid_dir}/overlay/ExperimentRuns')\n        print(gb)\n        print('lb_score:', lb_score)\n        print('')\n\n\n        # show one run ----------------------------------\n        fig = plt.figure(figsize=(18, 8))\n\n        id = valid_id[0]\n        truth = read_one_truth(id, overlay_dir=f'{valid_dir}/overlay/ExperimentRuns')\n\n        submit_df = submit_df[submit_df['experiment']==id]\n        for p in PARTICLE:\n            p = dotdict(p)\n            xyz_truth = truth[p.name]\n            xyz_predict = submit_df[submit_df['particle_type']==p.name][['x','y','z']].values\n            # NB: the plot uses the full particle radius as the match threshold,\n            # while the LB metric above uses 0.5 * radius\n            hit, fp, miss, _ = do_one_eval(xyz_truth, xyz_predict, p.radius)\n            print(id, p.name)\n            print('\\t num truth   :',len(xyz_truth) )\n            print('\\t num predict :',len(xyz_predict) )\n            print('\\t num hit  :',len(hit[0]) )\n            print('\\t num fp   :',len(fp) )\n            print('\\t num miss :',len(miss) )\n\n            ax = fig.add_subplot(2, 3, p.label, projection='3d')\n            if hit[0]:\n                pt = xyz_predict[hit[0]]\n                ax.scatter(pt[:, 0], pt[:, 1], pt[:, 2], alpha=0.5, color='r')\n                pt = xyz_truth[hit[1]]\n                ax.scatter(pt[:,0], pt[:,1], pt[:,2], s=80, facecolors='none', edgecolors='r')\n            if fp:\n                pt = xyz_predict[fp]\n                ax.scatter(pt[:, 0], pt[:, 1], pt[:, 2], alpha=1, color='k')\n            if miss:\n                pt = xyz_truth[miss]\n                ax.scatter(pt[:, 0], pt[:, 1], pt[:, 2], s=160, alpha=1, facecolors='none', edgecolors='k')\n\n            ax.set_title(f'{p.name} ({p.difficulty})')\n\n        plt.tight_layout()\n        plt.show()\n        \n        #--- \n        
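# (Plot legend: red dots = hit predictions, red hollow circles = matched\n        # truth points, black dots = false positives, black hollow circles =\n        # missed truth points.)\n        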
zz=0","metadata":{"trusted":true,"execution":{"iopub.status.busy":"2025-02-01T13:44:42.935057Z","iopub.execute_input":"2025-02-01T13:44:42.935348Z","iopub.status.idle":"2025-02-01T13:44:44.295506Z","shell.execute_reply.started":"2025-02-01T13:44:42.935318Z","shell.execute_reply":"2025-02-01T13:44:44.294520Z"}},"outputs":[],"execution_count":null},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null},{"cell_type":"code","source":"","metadata":{"trusted":true},"outputs":[],"execution_count":null}]}