Source code for tidy3d.plugins.smatrix.ports.wave

"""Class and custom data array for representing a scattering matrix wave port."""

from __future__ import annotations

from typing import Optional, Union

import numpy as np
import pydantic.v1 as pd

from tidy3d.components.base import cached_property, skip_if_fields_missing
from tidy3d.components.data.data_array import FreqDataArray, FreqModeDataArray
from tidy3d.components.data.monitor_data import ModeData
from tidy3d.components.data.sim_data import SimulationData
from tidy3d.components.geometry.base import Box
from tidy3d.components.geometry.bound_ops import bounds_contains
from tidy3d.components.grid.grid import Grid
from tidy3d.components.monitor import ModeMonitor
from tidy3d.components.simulation import Simulation
from tidy3d.components.source.field import ModeSource, ModeSpec
from tidy3d.components.source.time import GaussianPulse
from tidy3d.components.structure import MeshOverrideStructure
from tidy3d.components.types import Axis, Direction, FreqArray
from tidy3d.constants import fp_eps
from tidy3d.exceptions import ValidationError
from tidy3d.plugins.microwave import CurrentIntegralTypes, ImpedanceCalculator, VoltageIntegralTypes
from tidy3d.plugins.mode import ModeSolver

from .base_terminal import AbstractTerminalPort

DEFAULT_WAVE_PORT_NUM_CELLS = 5
MIN_WAVE_PORT_NUM_CELLS = 3


[docs] class WavePort(AbstractTerminalPort, Box): """Class representing a single wave port""" direction: Direction = pd.Field( ..., title="Direction", description="'+' or '-', defining which direction is considered 'input'.", ) mode_spec: ModeSpec = pd.Field( ModeSpec(), title="Mode Specification", description="Parameters to feed to mode solver which determine modes measured by monitor.", ) mode_index: pd.NonNegativeInt = pd.Field( 0, title="Mode Index", description="Index into the collection of modes returned by mode solver. " " Specifies which mode to inject using this source. " "If larger than ``mode_spec.num_modes``, " "``num_modes`` in the solver will be set to ``mode_index + 1``.", ) voltage_integral: Optional[VoltageIntegralTypes] = pd.Field( None, title="Voltage Integral", description="Definition of voltage integral used to compute voltage and the characteristic impedance.", ) current_integral: Optional[CurrentIntegralTypes] = pd.Field( None, title="Current Integral", description="Definition of current integral used to compute current and the characteristic impedance.", ) num_grid_cells: Optional[int] = pd.Field( DEFAULT_WAVE_PORT_NUM_CELLS, ge=MIN_WAVE_PORT_NUM_CELLS, title="Number of Grid Cells", description="Number of mesh grid cells in the transverse plane of the `WavePort`. " "Used in generating the suggested list of :class:`.MeshOverrideStructure` objects. " "Must be greater than or equal to 3. When set to `None`, no grid refinement is performed.", ) def _mode_voltage_coefficients(self, mode_data: ModeData) -> FreqModeDataArray: """Calculates scaling coefficients to convert mode amplitudes to the total port voltage. """ mode_data = mode_data._isel(mode_index=[self.mode_index]) if self.voltage_integral is None: current_coeffs = self.current_integral.compute_current(mode_data) voltage_coeffs = 2 * np.abs(mode_data.flux) / np.conj(current_coeffs) else: voltage_coeffs = self.voltage_integral.compute_voltage(mode_data) return voltage_coeffs.squeeze() def _mode_current_coefficients(self, mode_data: ModeData) -> FreqModeDataArray: """Calculates scaling coefficients to convert mode amplitudes to the total port current. """ mode_data = mode_data._isel(mode_index=[self.mode_index]) if self.current_integral is None: voltage_coeffs = self.voltage_integral.compute_voltage(mode_data) current_coeffs = (2 * np.abs(mode_data.flux) / voltage_coeffs).conj() else: current_coeffs = self.current_integral.compute_current(mode_data) return current_coeffs.squeeze() @cached_property def injection_axis(self) -> Axis: """Injection axis of the port.""" return self.size.index(0.0) @cached_property def transverse_axes(self) -> tuple[Axis, Axis]: """Transverse axes of the port.""" _, trans_axes = Box.pop_axis([0, 1, 2], self.injection_axis) return trans_axes @cached_property def _mode_monitor_name(self) -> str: """Return the name of the :class:`.ModeMonitor` associated with this port.""" return f"{self.name}_mode"
[docs] def to_source( self, source_time: GaussianPulse, snap_center: Optional[float] = None ) -> ModeSource: """Create a mode source from the wave port.""" center = list(self.center) if snap_center: center[self.injection_axis] = snap_center return ModeSource( center=center, size=self.size, source_time=source_time, mode_spec=self.mode_spec, mode_index=self.mode_index, direction=self.direction, name=self.name, )
[docs] def to_monitors( self, freqs: FreqArray, snap_center: Optional[float] = None, grid: Grid = None ) -> list[ModeMonitor]: """The wave port uses a :class:`.ModeMonitor` to compute the characteristic impedance and the port voltages and currents.""" center = list(self.center) if snap_center: center[self.injection_axis] = snap_center mode_mon = ModeMonitor( center=self.center, size=self.size, freqs=freqs, name=self._mode_monitor_name, colocate=False, mode_spec=self.mode_spec, store_fields_direction=self.direction, ) return [mode_mon]
[docs] def to_mode_solver(self, simulation: Simulation, freqs: FreqArray) -> ModeSolver: """Helper to create a :class:`.ModeSolver` instance.""" mode_solver = ModeSolver( simulation=simulation, plane=self.geometry, mode_spec=self.mode_spec, freqs=freqs, direction=self.direction, colocate=False, ) return mode_solver
[docs] def compute_voltage(self, sim_data: SimulationData) -> FreqDataArray: """Helper to compute voltage across the port.""" mode_data = sim_data[self._mode_monitor_name] voltage_coeffs = self._mode_voltage_coefficients(mode_data) amps = mode_data.amps fwd_amps = amps.sel(direction="+").squeeze() bwd_amps = amps.sel(direction="-").squeeze() return voltage_coeffs * (fwd_amps + bwd_amps)
[docs] def compute_current(self, sim_data: SimulationData) -> FreqDataArray: """Helper to compute current flowing through the port.""" mode_data = sim_data[self._mode_monitor_name] current_coeffs = self._mode_current_coefficients(mode_data) amps = mode_data.amps fwd_amps = amps.sel(direction="+").squeeze() bwd_amps = amps.sel(direction="-").squeeze() # In ModeData, fwd_amps and bwd_amps are not relative to # the direction fields are stored sign = 1.0 if self.direction == "-": sign = -1.0 return sign * current_coeffs * (fwd_amps - bwd_amps)
[docs] def compute_port_impedance( self, sim_mode_data: Union[SimulationData, ModeData] ) -> FreqModeDataArray: """Helper to compute impedance of port. The port impedance is computed from the transmission line mode, which should be TEM or at least quasi-TEM.""" impedance_calc = ImpedanceCalculator( voltage_integral=self.voltage_integral, current_integral=self.current_integral ) if isinstance(sim_mode_data, SimulationData): mode_data = sim_mode_data[self._mode_monitor_name] else: mode_data = sim_mode_data # Filter out unwanted modes to reduce impedance computation effort mode_data = mode_data._isel(mode_index=[self.mode_index]) impedance_array = impedance_calc.compute_impedance(mode_data) return impedance_array
[docs] def to_mesh_overrides(self) -> list[MeshOverrideStructure]: """Creates a list of :class:`.MeshOverrideStructure` for mesh refinement in the transverse plane of the port. The mode source requires at least 3 grid cells in the transverse dimensions, so these mesh overrides will be added to the simulation to ensure that this requirement is satisfied. """ dl = [None] * 3 for trans_axis in self.transverse_axes: dl[trans_axis] = self.size[trans_axis] / self.num_grid_cells return [ MeshOverrideStructure( geometry=Box(center=self.center, size=self.size), dl=dl, shadow=False, priority=-1, ) ]
@pd.validator("voltage_integral", "current_integral") def _validate_path_integrals_within_port(cls, val, values): """Raise ``ValidationError`` when the supplied path integrals are not within the port bounds.""" center = values["center"] size = values["size"] box = Box(center=center, size=size) if val and not bounds_contains( box.bounds, val.bounds, fp_eps, np.finfo(np.float32).smallest_normal ): raise ValidationError( f"'{cls.__name__}' must be setup with all path integrals defined within the bounds " f"of the port. Path bounds are '{val.bounds}', but port bounds are '{box.bounds}'." ) return val @pd.validator("current_integral", always=True) @skip_if_fields_missing(["voltage_integral"]) def _check_voltage_or_current(cls, val, values): """Raise validation error if both ``voltage_integral`` and ``current_integral`` were not provided.""" if values.get("voltage_integral") is None and val is None: raise ValidationError( "At least one of 'voltage_integral' or 'current_integral' must be provided." ) return val @pd.validator("current_integral", always=True) def _validate_current_integral_sign(cls, val, values): """ Validate that the sign of ``current_integral`` matches the port direction. """ if val is None: return val direction = values.get("direction") name = values.get("name") if val.sign != direction: raise ValidationError( f"'current_integral' sign must match the '{name}' direction '{direction}'." ) return val