# Source code for tidy3d.plugins.autograd.invdes.filters

from __future__ import annotations

import abc
from functools import lru_cache, partial
from typing import Annotated, Callable, Iterable, Tuple, Union

import numpy as np
import pydantic.v1 as pd
from numpy.typing import NDArray

import tidy3d as td
from tidy3d.components.base import Tidy3dBaseModel
from tidy3d.components.types import TYPE_TAG_STR

from ..functions import convolve
from ..types import KernelType, PaddingType
from ..utilities import get_kernel_size_px, make_kernel


class AbstractFilter(Tidy3dBaseModel, abc.ABC):
    """An abstract class for creating and applying convolution filters.

    Subclasses provide :meth:`get_kernel`; calling an instance convolves an input
    array with that kernel using the configured padding mode.
    """

    kernel_size: Union[pd.PositiveInt, Tuple[pd.PositiveInt, ...]] = pd.Field(
        ..., title="Kernel Size", description="Size of the kernel in pixels for each dimension."
    )
    normalize: bool = pd.Field(
        True, title="Normalize", description="Whether to normalize the kernel so that it sums to 1."
    )
    padding: PaddingType = pd.Field(
        "reflect", title="Padding", description="The padding mode to use."
    )

    @classmethod
    def from_radius_dl(
        cls, radius: Union[float, Tuple[float, ...]], dl: Union[float, Tuple[float, ...]], **kwargs
    ) -> AbstractFilter:
        """Create a filter from radius and grid spacing.

        Parameters
        ----------
        radius : Union[float, Tuple[float, ...]]
            The radius of the kernel. Can be a scalar or a tuple.
        dl : Union[float, Tuple[float, ...]]
            The grid spacing. Can be a scalar or a tuple.
        **kwargs
            Additional keyword arguments to pass to the filter constructor.

        Returns
        -------
        AbstractFilter
            An instance of the filter.
        """
        kernel_size = get_kernel_size_px(radius=radius, dl=dl)
        return cls(kernel_size=kernel_size, **kwargs)

    @staticmethod
    @abc.abstractmethod
    def get_kernel(size_px: Iterable[int], normalize: bool) -> NDArray:
        """Get the kernel for the filter.

        Parameters
        ----------
        size_px : Iterable[int]
            Size of the kernel in pixels for each dimension.
        normalize : bool
            Whether to normalize the kernel so that it sums to 1.

        Returns
        -------
        np.ndarray
            The kernel.
        """

    def __call__(self, array: NDArray) -> NDArray:
        """Apply the filter to an input array.

        Singleton dimensions are squeezed away before filtering and restored
        afterwards, so the output has the same shape as the input.

        Parameters
        ----------
        array : np.ndarray
            The input array to filter.

        Returns
        -------
        np.ndarray
            The filtered array.

        Raises
        ------
        ValueError
            If ``kernel_size`` has more than one entry and the number of entries
            does not match the number of non-singleton dimensions of ``array``.
        """
        original_shape = array.shape
        squeezed_array = np.squeeze(array)
        ndim = squeezed_array.ndim

        # Plain Python ints give clean, hashable cache keys for get_kernel.
        size_px = tuple(int(s) for s in np.atleast_1d(self.kernel_size))
        if len(size_px) != ndim:
            if len(size_px) != 1:
                # Previously a mismatched tuple was repeated ndim times, yielding a
                # kernel size with len(size_px) * ndim entries; fail loudly instead.
                raise ValueError(
                    f"'kernel_size' has {len(size_px)} entries but the squeezed input "
                    f"array has {ndim} dimensions; provide a single size or one size "
                    "per dimension."
                )
            # A single size is broadcast to every dimension of the squeezed array.
            size_px = size_px * ndim

        kernel = self.get_kernel(size_px, self.normalize)
        convolved_array = convolve(squeezed_array, kernel, padding=self.padding)
        return np.reshape(convolved_array, original_shape)


class ConicFilter(AbstractFilter):
    """Convolution filter built on the ``"conic"`` kernel from :func:`make_kernel`."""

    @staticmethod
    @lru_cache(maxsize=1)
    def get_kernel(size_px: Iterable[int], normalize: bool) -> NDArray:
        """Get the conic kernel.

        Results are memoized for the single most recent argument set, so
        ``size_px`` must be hashable (e.g. a tuple), and a repeated call with
        identical arguments returns the very same array object.

        See Also
        --------
        :func:`~filters.AbstractFilter.get_kernel` for full method documentation.
        """
        kernel_kind = "conic"
        return make_kernel(size=size_px, normalize=normalize, kernel_type=kernel_kind)


class CircularFilter(AbstractFilter):
    """Convolution filter built on the ``"circular"`` kernel from :func:`make_kernel`."""

    @staticmethod
    @lru_cache(maxsize=1)
    def get_kernel(size_px: Iterable[int], normalize: bool) -> NDArray:
        """Get the circular kernel.

        Results are memoized for the single most recent argument set, so
        ``size_px`` must be hashable (e.g. a tuple), and a repeated call with
        identical arguments returns the very same array object.

        See Also
        --------
        :func:`~filters.AbstractFilter.get_kernel` for full method documentation.
        """
        kernel_kind = "circular"
        return make_kernel(size=size_px, normalize=normalize, kernel_type=kernel_kind)


def _get_kernel_size(
    radius: Union[float, Tuple[float, ...]],
    dl: Union[float, Tuple[float, ...]],
    size_px: Union[int, Tuple[int, ...]],
) -> Tuple[int, ...]:
    """Determine the kernel size based on the provided radius, grid spacing, or size in pixels.

    Parameters
    ----------
    radius : Union[float, Tuple[float, ...]]
        The radius of the kernel. Can be a scalar or a tuple.
    dl : Union[float, Tuple[float, ...]]
        The grid spacing. Can be a scalar or a tuple.
    size_px : Union[int, Tuple[int, ...]]
        The size of the kernel in pixels for each dimension. Can be a scalar or a tuple.

    Returns
    -------
    Tuple[int, ...]
        The size of the kernel in pixels for each dimension.

    Raises
    ------
    ValueError
        If neither ``size_px`` nor both ``radius`` and ``dl`` are provided.
    """
    if size_px is not None:
        if radius is not None and dl is not None:
            td.log.warning(
                "Both 'size_px' and 'radius' and 'dl' are provided. 'size_px' will take precedence."
            )
        return (size_px,) if np.isscalar(size_px) else tuple(size_px)
    elif radius is not None and dl is not None:
        kernel_size = get_kernel_size_px(radius=radius, dl=dl)
        return (kernel_size,) if np.isscalar(kernel_size) else tuple(kernel_size)
    else:
        raise ValueError("Either 'size_px' or both 'radius' and 'dl' must be provided.")


def make_filter(
    radius: Union[float, Tuple[float, ...]] = None,
    dl: Union[float, Tuple[float, ...]] = None,
    *,
    size_px: Union[int, Tuple[int, ...]] = None,
    normalize: bool = True,
    padding: PaddingType = "reflect",
    filter_type: KernelType,
) -> Callable[[NDArray], NDArray]:
    """Create a filter function based on the specified kernel type and size.

    Parameters
    ----------
    radius : Union[float, Tuple[float, ...]] = None
        The radius of the kernel. Can be a scalar or a tuple.
    dl : Union[float, Tuple[float, ...]] = None
        The grid spacing. Can be a scalar or a tuple.
    size_px : Union[int, Tuple[int, ...]] = None
        The size of the kernel in pixels for each dimension. Can be a scalar or a tuple.
    normalize : bool = True
        Whether to normalize the kernel so that it sums to 1.
    padding : PaddingType = "reflect"
        The padding mode to use.
    filter_type : KernelType
        The type of kernel to create (``circular`` or ``conic``).

    Returns
    -------
    Callable[[np.ndarray], np.ndarray]
        A function that applies the created filter to an input array.

    Raises
    ------
    ValueError
        If neither ``size_px`` nor both ``radius`` and ``dl`` are provided, or if
        ``filter_type`` is not ``"conic"`` or ``"circular"``.
    """
    kernel_size = _get_kernel_size(radius, dl, size_px)

    if filter_type == "conic":
        filter_class = ConicFilter
    elif filter_type == "circular":
        filter_class = CircularFilter
    else:
        # Report the accepted *values* of filter_type, not the class names.
        raise ValueError(
            f"Unsupported filter_type: {filter_type}. Must be one of 'conic' or 'circular'."
        )

    return filter_class(kernel_size=kernel_size, normalize=normalize, padding=padding)
make_conic_filter = partial(make_filter, filter_type="conic")
make_conic_filter.__doc__ = """make_filter() with a default filter_type value of ``conic``.

See Also
--------
:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size.
"""

make_circular_filter = partial(make_filter, filter_type="circular")
make_circular_filter.__doc__ = """make_filter() with a default filter_type value of ``circular``.

See Also
--------
:func:`~filters.make_filter` : Function to create a filter based on the specified kernel type and size.
"""

# Union of the concrete filter models, discriminated on the ``TYPE_TAG_STR`` field.
FilterType = Annotated[Union[ConicFilter, CircularFilter], pd.Field(discriminator=TYPE_TAG_STR)]