"""Defines jax-compatible structures, their conversion to/from ``Structure``, and their gradient monitors."""
from __future__ import annotations
from typing import Union
import numpy as np
import pydantic.v1 as pd
from jax.tree_util import register_pytree_node_class
from tidy3d.components.data.monitor_data import FieldData, PermittivityData
from tidy3d.components.geometry.utils import GeometryType
from tidy3d.components.medium import MediumType
from tidy3d.components.monitor import FieldMonitor
from tidy3d.components.structure import Structure
from tidy3d.components.types import TYPE_TAG_STR, Bound
from tidy3d.constants import C_0
from .base import JaxObject
from .geometry import JAX_GEOMETRY_MAP, JaxBox, JaxGeometryType
from .medium import JAX_MEDIUM_MAP, JaxMediumType
# Per structure field, the lookup used by ``AbstractJaxStructure.from_structure`` to find the
# jax class for a given tidy3d component type.
GEO_MED_MAPPINGS = dict(geometry=JAX_GEOMETRY_MAP, medium=JAX_MEDIUM_MAP)
class AbstractJaxStructure(Structure, JaxObject):
    """A :class:`.Structure` registered with jax.

    Shared base of the jax structure variants. Subclasses set ``_differentiable_fields``
    to the subset of ``("geometry", "medium")`` that holds jax-compatible objects; any
    field not listed there holds a regular ``tidy3d`` object and is passed through as-is.
    """

    _tidy3d_class = Structure

    # which of "geometry" or "medium" is differentiable for this class
    _differentiable_fields = ()

    geometry: Union[JaxGeometryType, GeometryType]
    medium: Union[JaxMediumType, MediumType]

    @pd.validator("medium", always=True)
    def _check_2d_geometry(cls, val, values):
        """Override validator checking 2D geometry, which triggers unnecessarily for gradients."""
        return val

    def _validate_web_adjoint(self) -> None:
        """Run validators for this component, only if using ``tda.web.run()``.

        Only the differentiable fields are validated; static fields are skipped.
        """
        if "geometry" in self._differentiable_fields:
            self.geometry._validate_web_adjoint()
        if "medium" in self._differentiable_fields:
            self.medium._validate_web_adjoint()

    @property
    def jax_fields(self):
        """The ``geometry`` and ``medium`` fields keyed by name.

        Both are returned regardless of which ones are differentiable for this class.
        """
        return {"geometry": self.geometry, "medium": self.medium}

    @property
    def exclude_fields(self):
        """Fields to exclude from the self dict (``type`` plus the ``jax_fields`` keys)."""
        return {"type", *self.jax_fields}

    def to_structure(self) -> Structure:
        """Convert :class:`.JaxStructure` instance to :class:`.Structure`.

        Differentiable fields are converted with ``to_tidy3d()``; static fields are reused.
        """
        self_dict = self.dict(exclude=self.exclude_fields)
        for key, component in self.jax_fields.items():
            if key in self._differentiable_fields:
                self_dict[key] = component.to_tidy3d()
            else:
                self_dict[key] = component
        return Structure.parse_obj(self_dict)

    @classmethod
    def from_structure(cls, structure: Structure) -> JaxStructure:
        """Convert :class:`.Structure` to :class:`.JaxStructure`.

        Raises
        ------
        KeyError
            If a differentiable field holds a component type that has no registered jax
            equivalent in ``GEO_MED_MAPPINGS``.
        """
        struct_dict = structure.dict(exclude={"type"})
        jax_fields = {"geometry": structure.geometry, "medium": structure.medium}
        for key, component in jax_fields.items():
            if key in cls._differentiable_fields:
                type_map = GEO_MED_MAPPINGS[key]
                try:
                    jax_type = type_map[type(component)]
                except KeyError as err:
                    # same exception type as before, but say which component failed
                    raise KeyError(
                        f"No jax-compatible {key} type is registered for "
                        f"'{type(component).__name__}'."
                    ) from err
                struct_dict[key] = jax_type.from_tidy3d(component)
            else:
                struct_dict[key] = component
        return cls.parse_obj(struct_dict)

    def make_grad_monitors(self, freqs: list[float], name: str) -> FieldMonitor:
        """Return gradient monitor associated with this object."""
        if "geometry" not in self._differentiable_fields:
            # make a fake JaxBox to be able to call .make_grad_monitors
            rmin, rmax = self.geometry.bounds
            geometry = JaxBox.from_bounds(rmin=rmin, rmax=rmax)
        else:
            geometry = self.geometry
        return geometry.make_grad_monitors(freqs=freqs, name=name)

    def _get_medium_params(
        self,
        grad_data_eps: PermittivityData,
    ) -> dict[str, float]:
        """Compute params in the material of this structure.

        Returns a dict with:

        - ``"eps_in"``: the medium's ``eps_model`` evaluated at the highest frequency in
          ``grad_data_eps.eps_xx.f``.
        - ``"wvl_mat"``: the free-space wavelength at that frequency divided by a reference
          index, which is ``sqrt`` of the largest real part of ``eps_in`` floored at 1.
        """
        freq_max = float(max(grad_data_eps.eps_xx.f))
        eps_in = self.medium.eps_model(frequency=freq_max)
        # Floor at 1 *before* the square root: a negative real permittivity (e.g. a metal)
        # would otherwise produce NaN and a numpy "invalid value" RuntimeWarning.
        eps_real_max = np.max(np.real(eps_in))
        ref_ind = np.sqrt(max(1.0, eps_real_max))
        wvl_free_space = C_0 / freq_max
        wvl_mat = wvl_free_space / ref_ind
        return {"wvl_mat": wvl_mat, "eps_in": eps_in}

    def geometry_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        grad_data_eps: PermittivityData,
        sim_bounds: Bound,
        eps_out: complex,
        num_proc: int = 1,
    ) -> JaxGeometryType:
        """Compute the VJP for the structure geometry.

        Delegates to ``self.geometry.store_vjp`` with the in-material wavelength and
        permittivity from ``_get_medium_params``.
        """
        medium_params = self._get_medium_params(grad_data_eps=grad_data_eps)
        return self.geometry.store_vjp(
            grad_data_fwd=grad_data_fwd,
            grad_data_adj=grad_data_adj,
            grad_data_eps=grad_data_eps,
            sim_bounds=sim_bounds,
            wvl_mat=medium_params["wvl_mat"],
            eps_out=eps_out,
            eps_in=medium_params["eps_in"],
            num_proc=num_proc,
        )

    def medium_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        grad_data_eps: PermittivityData,
        sim_bounds: Bound,
    ) -> JaxMediumType:
        """Compute the VJP for the structure medium.

        ``grad_data_eps`` is only used to derive the in-material wavelength; the geometry's
        ``inside`` function is passed so the medium can restrict itself to this structure.
        """
        medium_params = self._get_medium_params(grad_data_eps=grad_data_eps)
        return self.medium.store_vjp(
            grad_data_fwd=grad_data_fwd,
            grad_data_adj=grad_data_adj,
            sim_bounds=sim_bounds,
            wvl_mat=medium_params["wvl_mat"],
            inside_fn=self.geometry.inside,
        )

    def store_vjp(
        self,
        grad_data_fwd: FieldData,
        grad_data_adj: FieldData,
        grad_data_eps: PermittivityData,
        sim_bounds: Bound,
        eps_out: complex,
        num_proc: int = 1,
    ) -> JaxStructure:
        """Returns the gradient of the structure parameters given forward and adjoint field data.

        Only the fields listed in ``_differentiable_fields`` are replaced by their VJPs; a
        class with no differentiable fields returns ``self`` unchanged.
        """
        # nothing to differentiate for this class
        if not self._differentiable_fields:
            return self

        vjp_dict = {}

        if "geometry" in self._differentiable_fields:
            vjp_dict["geometry"] = self.geometry_vjp(
                grad_data_fwd=grad_data_fwd,
                grad_data_adj=grad_data_adj,
                grad_data_eps=grad_data_eps,
                sim_bounds=sim_bounds,
                eps_out=eps_out,
                num_proc=num_proc,
            )

        if "medium" in self._differentiable_fields:
            vjp_dict["medium"] = self.medium_vjp(
                grad_data_fwd=grad_data_fwd,
                grad_data_adj=grad_data_adj,
                grad_data_eps=grad_data_eps,
                sim_bounds=sim_bounds,
            )

        return self.updated_copy(**vjp_dict)
@register_pytree_node_class
class JaxStructure(AbstractJaxStructure, JaxObject):
    """A :class:`.Structure` registered with jax.

    Both ``geometry`` and ``medium`` are jax-compatible and differentiable.
    """

    geometry: JaxGeometryType = pd.Field(
        ...,
        title="Geometry",
        description="Geometry of the structure, which is jax-compatible.",
        jax_field=True,
        discriminator=TYPE_TAG_STR,
    )

    medium: JaxMediumType = pd.Field(
        ...,
        title="Medium",
        description="Medium of the structure, which is jax-compatible.",
        jax_field=True,
        discriminator=TYPE_TAG_STR,
    )

    _differentiable_fields = ("medium", "geometry")
@register_pytree_node_class
class JaxStructureStaticMedium(AbstractJaxStructure, JaxObject):
    """A :class:`.Structure` registered with jax whose geometry is differentiable.

    The medium is a regular ``tidy3d`` medium and is passed through untouched.
    """

    # the geometry is the only jax-traced field of this variant
    _differentiable_fields = ("geometry",)

    geometry: JaxGeometryType = pd.Field(
        ...,
        discriminator=TYPE_TAG_STR,
        jax_field=True,
        title="Geometry",
        description="Geometry of the structure, which is jax-compatible.",
    )

    medium: MediumType = pd.Field(
        ...,
        discriminator=TYPE_TAG_STR,
        jax_field=False,
        title="Medium",
        description=(
            "Regular ``tidy3d`` medium of the structure, non differentiable. "
            "Supports dispersive materials."
        ),
    )
@register_pytree_node_class
class JaxStructureStaticGeometry(AbstractJaxStructure, JaxObject):
    """A :class:`.Structure` registered with jax whose medium is differentiable.

    The geometry is a regular ``tidy3d`` geometry and is passed through untouched.
    """

    # the medium is the only jax-traced field of this variant
    _differentiable_fields = ("medium",)

    geometry: GeometryType = pd.Field(
        ...,
        discriminator=TYPE_TAG_STR,
        jax_field=False,
        title="Geometry",
        description=(
            "Regular ``tidy3d`` geometry of the structure, non differentiable. "
            "Supports angled sidewalls and other complex geometries."
        ),
    )

    medium: JaxMediumType = pd.Field(
        ...,
        discriminator=TYPE_TAG_STR,
        jax_field=True,
        title="Medium",
        description="Medium of the structure, which is jax-compatible.",
    )
# Any of the jax-registered structure variants defined in this module.
JaxStructureType = Union[
    JaxStructure,
    JaxStructureStaticMedium,
    JaxStructureStaticGeometry,
]