import numpy as np
import jax.numpy as jnp
import jax_dataclasses as jdc
from jaxtyping import Float, Array
from quantumspectra_2024.common.absorption import (
AbsorptionModel as Model,
AbsorptionSpectrum,
)
from quantumspectra_2024.models.mlj.MLJComputation import calculate_mlj_spectrum
@jdc.pytree_dataclass(kw_only=True)
class MLJModel(Model):
    """A two-state two-mode MLJ model for absorption spectra.

    Parameters
    ----------
    start_energy : float
        absorption spectrum's starting energy (wavenumbers).
    end_energy : float
        absorption spectrum's ending energy (wavenumbers).
    num_points : int
        absorption spectrum's number of points (unitless).
    temperature_kelvin : float
        system's temperature (Kelvin).
    energy_gap : float
        energy gap between the two states (wavenumbers).
    disorder_meV : float
        disorder in the system (meV).
    basis_size : int
        size of basis set (unitless). Defaults to 20.
    mode_frequencies : Float[Array, "2"]
        frequency per mode (wavenumbers).
    mode_couplings : Float[Array, "2"]
        excited state coupling per mode.

    Raises
    ------
    ValueError
        on construction, if `mode_frequencies` or `mode_couplings` does not
        have exactly two entries.
    """

    # NOTE: `start_energy`, `end_energy` and `num_points` are read through
    # `self` in `get_absorption` but are not declared here; presumably they
    # are fields of the base `Model` -- confirm against its definition.

    #: system's temperature (Kelvin).
    temperature_kelvin: float

    #: energy gap between the two states (wavenumbers).
    energy_gap: float

    #: disorder in the system (meV).
    disorder_meV: float

    #: size of basis set (unitless).
    basis_size: int = 20

    #: frequency per mode (wavenumbers), must have exactly two.
    mode_frequencies: Float[Array, "2"]

    #: excited state coupling per mode, must have exactly two.
    mode_couplings: Float[Array, "2"]

    def get_absorption(self) -> AbsorptionSpectrum:
        """Compute the absorption spectrum for the model.

        See docs in `MLJComputation` to see how this is done.

        Returns
        -------
        AbsorptionSpectrum
            the model's parameterized absorption spectrum.
        """
        # get absorption spectrum sample energies (x values)
        sample_points = jnp.linspace(
            float(self.start_energy), float(self.end_energy), int(self.num_points)
        )

        # get low and high frequency modes
        lower_frequency, lower_coupling, higher_frequency, higher_coupling = (
            self.get_low_high_frequency_modes()
        )

        # calculate absorption spectrum; every scalar is cast to a plain
        # Python number and the sample grid to a NumPy array before the call
        spectrum = calculate_mlj_spectrum(
            energy_gap=float(self.energy_gap),
            high_freq_frequency=float(higher_frequency),
            high_freq_coupling=float(higher_coupling),
            low_freq_frequency=float(lower_frequency),
            low_freq_coupling=float(lower_coupling),
            temperature_kelvin=float(self.temperature_kelvin),
            disorder_meV=float(self.disorder_meV),
            basis_size=int(self.basis_size),
            sample_points=np.array(sample_points),
        )

        # return as AbsorptionSpectrum dataclass
        return AbsorptionSpectrum(
            energies=sample_points,
            intensities=jnp.array(spectrum),
        )

    def get_low_high_frequency_modes(self) -> tuple[float, float, float, float]:
        """Extracts the low and high frequency modes and mode couplings from the model.

        Model was already verified to have exactly two modes.
        This function sorts the modes by frequency and returns the lowest and
        highest frequency modes, each paired with the coupling of that mode.
        This is useful for the `calculate_mlj_spectrum` function, which expects
        the low and high frequency modes explicitly.

        Returns
        -------
        tuple[float, float, float, float]
            low frequency, low coupling, high frequency, high coupling.
        """
        mode_frequencies = np.array(self.mode_frequencies)
        mode_couplings = np.array(self.mode_couplings)

        # order both arrays by ascending frequency so that each coupling stays
        # attached to its own mode
        sorted_frequency_indices = np.argsort(mode_frequencies)
        sorted_frequencies = mode_frequencies[sorted_frequency_indices]
        sorted_couplings = mode_couplings[sorted_frequency_indices]

        # first entry is the low-frequency mode, last entry the high-frequency one
        lower_frequency = sorted_frequencies[0]
        lower_coupling = sorted_couplings[0]
        higher_frequency = sorted_frequencies[-1]
        higher_coupling = sorted_couplings[-1]

        return lower_frequency, lower_coupling, higher_frequency, higher_coupling

    def apply_electric_field(
        self,
        field_strength: float,
        field_delta_dipole: float,
        field_delta_polarizability: float,
    ) -> "MLJModel":
        """Applies an electric field to the model. Returns a new instance of the model.

        Only `energy_gap` is changed; it is shifted by the negated sum of a
        term linear in the field strength and a term quadratic in it.

        Parameters
        ----------
        field_strength : float
            the strength of the electric field.
        field_delta_dipole : float
            the change in dipole moment due to the electric field.
        field_delta_polarizability : float
            the change in polarizability due to the electric field.

        Returns
        -------
        MLJModel
            the model with the electric field applied.
        """
        # NOTE(review): 1679.0870295 and 559.91 look like unit-conversion
        # factors into the energy unit of `energy_gap`; their derivation and
        # the expected input units are not stated here -- confirm.
        dipole_energy_change = field_delta_dipole * field_strength * 1679.0870295
        polarizability_energy_change = (
            0.5 * (field_strength**2) * field_delta_polarizability * 559.91
        )
        field_energy_change = -1 * (dipole_energy_change + polarizability_energy_change)

        return jdc.replace(
            self,
            energy_gap=self.energy_gap + field_energy_change,
        )

    def _verify_modes(self) -> None:
        """Verifies that the model has exactly two modes.

        Raises
        ------
        ValueError
            if the model does not have exactly two modes.
        """
        if len(self.mode_frequencies) != 2 or len(self.mode_couplings) != 2:
            raise ValueError("The MLJ model requires exactly two modes.")

    def __post_init__(self) -> None:
        """Validate the mode arrays as soon as the dataclass is constructed."""
        self._verify_modes()