from typing import List

import torch


def check_input(x: torch.Tensor):
    """Assert that ``x`` lives on a CUDA device and is contiguous in memory.

    The device check runs first, so a tensor failing both conditions
    reports the CUDA failure.

    Args:
        x: Tensor to validate.

    Raises:
        AssertionError: If ``x`` is not a CUDA tensor, or is not contiguous.
            As with every ``assert``, the checks vanish under ``python -O``.
    """
    on_gpu = x.is_cuda
    assert on_gpu, f"{str(x)} must be a CUDA Tensor"

    dense_layout = x.is_contiguous()
    assert dense_layout, f"{str(x)} must be contiguous"


def check_dim(d, x: torch.Tensor):
    """Assert that ``x`` has exactly ``d`` dimensions.

    Args:
        d: Expected number of dimensions.
        x: Tensor whose rank is checked.

    Raises:
        AssertionError: If ``x.dim()`` differs from ``d``.
    """
    rank = x.dim()
    assert rank == d, f"{str(x)} must be a {d}D tensor"


def check_shape(a: torch.Tensor, b: torch.Tensor):
    """Assert that ``a`` and ``b`` have the same rank and the same sizes.

    Args:
        a: First tensor.
        b: Second tensor.

    Raises:
        AssertionError: If the tensors differ in number of dimensions, or
            in the extent of any dimension.
    """
    assert a.dim() == b.dim(), "tensors should have same dim"
    # Ranks are equal here, so walking both sizes in lockstep visits every
    # dimension exactly once.
    for extent_a, extent_b in zip(a.size(), b.size()):
        assert extent_a == extent_b, (
            f"tensors shape mismatch, {a.size()} and {b.size()}"
        )


def check_device(tensors: List[torch.Tensor]):
    """Assert that every tensor in ``tensors`` is on the same device.

    The first tensor's device is the reference; each remaining tensor is
    compared against it. An empty collection passes trivially instead of
    raising ``IndexError``.

    Args:
        tensors: Tensors to compare. May be empty.

    Raises:
        AssertionError: If any tensor's device differs from the first one's.
    """
    # Use len() rather than truthiness so that sequence-like inputs whose
    # truth value is ambiguous are still handled.
    if len(tensors) == 0:
        return

    device = tensors[0].device
    # Skip the first element: comparing it with itself is always true.
    for t in tensors[1:]:
        assert t.device == device, (
            f"All tensors should be on the same device, but got {device} and {t.device}"
        )
