"""DM variations expressed as a sum of sinusoids."""
from warnings import warn
import astropy.units as u
import numpy as np
from loguru import logger as log
from pint import DMconst, dmu
from pint.exceptions import MissingParameter
from pint.models.dispersion_model import Dispersion
from pint.models.parameter import MJDParameter, prefixParameter
class DMWaveX(Dispersion):
    """
    Fourier representation of DM variations.

    Used for decomposition of DM noise into a series of sine/cosine components
    with the amplitudes as fitted parameters.

    Parameters supported:

    .. paramtable::
        :class: pint.models.dmwavex.DMWaveX

    To set up a DMWaveX model, users can use the `pint.utils` function
    `dmwavex_setup()` with either a list of frequencies or a number of
    harmonics of a base frequency derived from the data timespan (see that
    function for the exact definition).

    The DM contribution at TDB time ``t`` (MJD) is::

        DM(t) = sum_k DMWXSIN_k * sin(2 pi DMWXFREQ_k (t - DMWXEPOCH))
                    + DMWXCOS_k * cos(2 pi DMWXFREQ_k (t - DMWXEPOCH))

    so the ``DMWXFREQ_`` parameters are cyclic frequencies in 1/d; the factor
    of 2 pi is applied by the model and is not part of the stored value.
    """

    register = True
    category = "dmwavex"

    def __init__(self):
        super().__init__()
        self.add_param(
            MJDParameter(
                name="DMWXEPOCH",
                description="Reference epoch for Fourier representation of DM noise",
                time_scale="tdb",
                tcb2tdb_scale_factor=u.Quantity(1),
            )
        )
        # Seed the component with a single (index 1) sine/cosine pair and mark
        # its parameters as special params.
        self.add_dmwavex_component(0.1, index=1, dmwxsin=0, dmwxcos=0, frozen=False)
        self.set_special_params(["DMWXFREQ_0001", "DMWXSIN_0001", "DMWXCOS_0001"])
        self.dm_value_funcs += [self.dmwavex_dm]
        self.delay_funcs_component += [self.dmwavex_delay]

    @staticmethod
    def _strip_units(value, unit):
        """Return ``value`` expressed in ``unit`` as a bare number if it is a
        Quantity; otherwise return it unchanged (assumed already in ``unit``)."""
        if isinstance(value, u.quantity.Quantity):
            return value.to_value(unit)
        return value

    def _next_dmwavex_index(self):
        """Return one more than the largest DMWXFREQ_ index in use (1 if none)."""
        used = self.get_prefix_mapping_component("DMWXFREQ_")
        # `default` avoids the ValueError that max()/np.max raise on no keys.
        return int(max(used, default=0)) + 1

    def add_dmwavex_component(
        self, dmwxfreq, index=None, dmwxsin=0, dmwxcos=0, frozen=True
    ):
        """
        Add a single DMWaveX component.

        Parameters
        ----------
        dmwxfreq : float or astropy.units.Quantity
            Frequency of the component (interpreted as 1/d if a float).
        index : int, None
            Integer label for the component. If None, the largest index
            currently in use is incremented by 1 (1 if there are none).
        dmwxsin : float or astropy.units.Quantity
            Sine amplitude (interpreted in DM units if a float).
        dmwxcos : float or astropy.units.Quantity
            Cosine amplitude (interpreted in DM units if a float).
        frozen : bool
            Frozen flag applied to the new DMWXSIN_ and DMWXCOS_ parameters.
            The DMWXFREQ_ parameter is added without an explicit frozen flag.

        Returns
        -------
        index : int
            Index that has been assigned to the new component.

        Raises
        ------
        ValueError
            If ``index`` is already used by this model.
        """
        if index is None:
            index = self._next_dmwavex_index()
        index = int(index)
        if index in self.get_prefix_mapping_component("DMWXFREQ_"):
            raise ValueError(
                f"Index '{index}' is already in use in this model. Please choose another"
            )
        i = f"{index:04d}"
        dmwxfreq = self._strip_units(dmwxfreq, 1 / u.d)
        dmwxsin = self._strip_units(dmwxsin, dmu)
        dmwxcos = self._strip_units(dmwxcos, dmu)
        self.add_param(
            prefixParameter(
                name=f"DMWXFREQ_{i}",
                description="Component frequency for Fourier representation of DM noise",
                units="1/d",
                value=dmwxfreq,
                parameter_type="float",
                tcb2tdb_scale_factor=u.Quantity(1),
            )
        )
        self.add_param(
            prefixParameter(
                name=f"DMWXSIN_{i}",
                description="Sine amplitudes for Fourier representation of DM noise",
                units=dmu,
                value=dmwxsin,
                frozen=frozen,
                parameter_type="float",
                tcb2tdb_scale_factor=DMconst,
            )
        )
        self.add_param(
            prefixParameter(
                name=f"DMWXCOS_{i}",
                description="Cosine amplitudes for Fourier representation of DM noise",
                units=dmu,
                value=dmwxcos,
                frozen=frozen,
                parameter_type="float",
                tcb2tdb_scale_factor=DMconst,
            )
        )
        self.setup()
        self.validate()
        return index

    def add_dmwavex_components(
        self, dmwxfreqs, indices=None, dmwxsins=0, dmwxcoses=0, frozens=True
    ):
        """
        Add several DMWaveX components with the specified frequencies.

        Parameters
        ----------
        dmwxfreqs : sized iterable of float or astropy.units.Quantity
            Frequencies of the new components (interpreted as 1/d if floats).
        indices : iterable of (int or None), None
            Integer labels for the new components. If None, every component is
            auto-labelled; individual None entries are auto-labelled too.
            Auto-labels continue upward from the largest index in use, skipping
            indices already in the model or requested explicitly in ``indices``.
        dmwxsins : float, astropy.units.Quantity, or iterable of those
            Sine amplitudes; a single value is used for every component.
        dmwxcoses : float, astropy.units.Quantity, or iterable of those
            Cosine amplitudes; a single value is used for every component.
        frozens : bool or iterable of bool
            Frozen flags for the DMWXSIN_/DMWXCOS_ parameters; a single value
            is used for every component.

        Returns
        -------
        indices : list of int
            Indices that have been assigned to the new components, in order.

        Raises
        ------
        ValueError
            If the lengths of the inputs disagree, if an explicit index is
            already in the model, or if an explicit index is repeated. All
            checks run before any parameter is added.
        """
        n_comp = len(dmwxfreqs)
        if indices is None:
            indices = [None] * n_comp
        else:
            indices = list(indices)
            if len(indices) != n_comp:
                raise ValueError(
                    f"Number of base frequencies {n_comp} doesn't match number of indices {len(indices)}"
                )

        dmwxsins = np.atleast_1d(dmwxsins)
        dmwxcoses = np.atleast_1d(dmwxcoses)
        if len(dmwxsins) == 1:
            dmwxsins = np.repeat(dmwxsins, n_comp)
        if len(dmwxcoses) == 1:
            dmwxcoses = np.repeat(dmwxcoses, n_comp)
        if len(dmwxsins) != n_comp:
            raise ValueError(
                f"Number of base frequencies {n_comp} doesn't match number of sine amplitudes {len(dmwxsins)}"
            )
        if len(dmwxcoses) != n_comp:
            raise ValueError(
                f"Number of base frequencies {n_comp} doesn't match number of cosine amplitudes {len(dmwxcoses)}"
            )
        frozens = np.atleast_1d(frozens)
        if len(frozens) == 1:
            frozens = np.repeat(frozens, n_comp)
        if len(frozens) != n_comp:
            raise ValueError(
                "Number of base frequencies must match number of frozen values"
            )

        # Resolve every index up front so nothing is added if any is invalid.
        existing = self.get_prefix_mapping_component("DMWXFREQ_")
        explicit = [int(ix) for ix in indices if ix is not None]
        for ix in explicit:
            if ix in existing:
                raise ValueError(
                    f"Attempting to insert DMWXFREQ_{ix:04d} but it already exists"
                )
        if len(set(explicit)) != len(explicit):
            raise ValueError(f"Duplicate DMWaveX indices requested: {explicit}")
        used = set(existing) | set(explicit)
        last_index = int(max(existing, default=0))
        resolved = []
        for ix in indices:
            if ix is None:
                # Increment from the current maximum, skipping taken labels so an
                # auto-label can never collide with a later explicit index.
                last_index += 1
                while last_index in used:
                    last_index += 1
                used.add(last_index)
                resolved.append(last_index)
            else:
                resolved.append(int(ix))

        for index, dmwxfreq, dmwxsin, dmwxcos, frozen in zip(
            resolved, dmwxfreqs, dmwxsins, dmwxcoses, frozens
        ):
            i = f"{index:04d}"
            dmwxfreq = self._strip_units(dmwxfreq, u.d**-1)
            dmwxsin = self._strip_units(dmwxsin, dmu)
            dmwxcos = self._strip_units(dmwxcos, dmu)
            log.trace(f"Adding DMWXSIN_{i} and DMWXCOS_{i} at frequency DMWXFREQ_{i}")
            self.add_param(
                prefixParameter(
                    name=f"DMWXFREQ_{i}",
                    description="Component frequency for Fourier representation of DM noise",
                    units="1/d",
                    value=dmwxfreq,
                    parameter_type="float",
                    tcb2tdb_scale_factor=u.Quantity(1),
                )
            )
            self.add_param(
                prefixParameter(
                    name=f"DMWXSIN_{i}",
                    description="Sine amplitude for Fourier representation of DM noise",
                    units=dmu,
                    value=dmwxsin,
                    parameter_type="float",
                    frozen=frozen,
                    tcb2tdb_scale_factor=DMconst,
                )
            )
            self.add_param(
                prefixParameter(
                    name=f"DMWXCOS_{i}",
                    description="Cosine amplitude for Fourier representation of DM noise",
                    units=dmu,
                    value=dmwxcos,
                    parameter_type="float",
                    frozen=frozen,
                    tcb2tdb_scale_factor=DMconst,
                )
            )
        self.setup()
        self.validate()
        return resolved

    def remove_dmwavex_component(self, index):
        """
        Remove all DMWaveX parameters associated with one or more indices.

        Parameters
        ----------
        index : int, float, numpy scalar, list, tuple, set, np.ndarray
            Index or collection of indices whose DMWXFREQ_, DMWXSIN_ and
            DMWXCOS_ parameters are removed from the model.

        Raises
        ------
        TypeError
            If ``index`` is not one of the accepted types.
        """
        if isinstance(index, (int, float, np.integer, np.floating)):
            indices = [index]
        elif isinstance(index, np.ndarray):
            # atleast_1d makes 0-d arrays iterable as well.
            indices = np.atleast_1d(index)
        elif isinstance(index, (list, tuple, set)):
            indices = index
        else:
            raise TypeError(
                f"index must be a float, int, set, list, tuple, or array - not {type(index)}"
            )
        for idx in indices:
            suffix = f"{int(idx):04d}"
            for prefix in ("DMWXFREQ_", "DMWXSIN_", "DMWXCOS_"):
                self.remove_param(prefix + suffix)
        self.validate()

    def get_indices(self):
        """
        Return the integer indices of the DMWaveX components, taken from the
        names of the DMWXFREQ_ parameters.

        Returns
        -------
        inds : np.ndarray
            Array of DMWaveX indices in the model.
        """
        inds = [
            int(name.split("_")[-1])
            for name in self.params
            if name.startswith("DMWXFREQ_")
        ]
        return np.array(inds)

    def setup(self):
        """Register delay and DM derivative functions for every DMWXSIN_/DMWXCOS_
        parameter and cache the list of component indices."""
        super().setup()
        for prefix_par in self.get_params_of_type("prefixParameter"):
            if prefix_par.startswith("DMWXSIN_"):
                self.register_deriv_funcs(self.d_delay_d_dmparam, prefix_par)
                self.register_dm_deriv_funcs(self.d_dm_d_DMWXSIN, prefix_par)
            if prefix_par.startswith("DMWXCOS_"):
                self.register_deriv_funcs(self.d_delay_d_dmparam, prefix_par)
                self.register_dm_deriv_funcs(self.d_dm_d_DMWXCOS, prefix_par)
        # NOTE: despite the name, this holds component *indices*, not frequencies.
        self.dmwavex_freqs = list(
            self.get_prefix_mapping_component("DMWXFREQ_").keys()
        )
        self.num_dmwavex_freqs = len(self.dmwavex_freqs)

    def validate(self):
        """
        Validate the DMWaveX parameters.

        Raises
        ------
        ValueError
            If the DMWXFREQ_/DMWXSIN_/DMWXCOS_ index sets differ, or if any
            DMWXFREQ_ is zero or unset.
        MissingParameter
            If DMWXEPOCH is unset, the component has a parent model, and
            the parent's PEPOCH is also unset.

        Warns if any DMWXFREQ_ is negative. If DMWXEPOCH is unset, it is
        filled from the parent model's PEPOCH when available.
        """
        super().validate()
        self.setup()
        freq_map = self.get_prefix_mapping_component("DMWXFREQ_")
        sin_map = self.get_prefix_mapping_component("DMWXSIN_")
        cos_map = self.get_prefix_mapping_component("DMWXCOS_")
        # Equal key sets also imply equal numbers of parameters.
        if freq_map.keys() != sin_map.keys():
            raise ValueError(
                "DMWXFREQ_ parameters do not match DMWXSIN_ parameters. "
                "Please check your prefixed parameters"
            )
        if freq_map.keys() != cos_map.keys():
            raise ValueError(
                "DMWXFREQ_ parameters do not match DMWXCOS_ parameters. "
                "Please check your prefixed parameters"
            )
        if sin_map.keys() != cos_map.keys():
            raise ValueError(
                "DMWXSIN_ parameters do not match DMWXCOS_ parameters. "
                "Please check your prefixed parameters"
            )
        for index in freq_map:
            freq_par = getattr(self, f"DMWXFREQ_{index:04d}")
            if freq_par.quantity is None or freq_par.value == 0:
                raise ValueError(
                    f"DMWXFREQ_{index:04d} is zero or None. Please check your prefixed parameters"
                )
            if freq_par.value < 0.0:
                warn(f"Frequency DMWXFREQ_{index:04d} is negative")
        if self.DMWXEPOCH.value is None and self._parent is not None:
            if self._parent.PEPOCH.value is None:
                raise MissingParameter(
                    "DMWXEPOCH or PEPOCH are required if DMWaveX is being used"
                )
            self.DMWXEPOCH.quantity = self._parent.PEPOCH.quantity

    def validate_toas(self, toas):
        return super().validate_toas(toas)

    def _dmwavex_dt(self, toas):
        """Return ``tdbld - DMWXEPOCH`` for each TOA as a Quantity in days."""
        return toas.table["tdbld"].data * u.d - self.DMWXEPOCH.value * u.d

    @staticmethod
    def _dmwavex_arg(freq, dt):
        """Return the bare sinusoid argument ``2*pi*freq*dt``.

        ``to_value`` (rather than ``.value``) ensures the product is converted
        to a true dimensionless number even if the units do not cancel exactly.
        """
        return (2.0 * np.pi * freq * dt).to_value(u.dimensionless_unscaled)

    def dmwavex_dm(self, toas):
        """Return the summed DMWaveX DM contribution at each TOA (in ``dmu``)."""
        total_dm = np.zeros(toas.ntoas) * dmu
        freq_names = self.get_prefix_mapping_component("DMWXFREQ_")
        sin_names = self.get_prefix_mapping_component("DMWXSIN_")
        cos_names = self.get_prefix_mapping_component("DMWXCOS_")
        dt = self._dmwavex_dt(toas)
        for idx, freq_name in freq_names.items():
            arg = self._dmwavex_arg(getattr(self, freq_name).quantity, dt)
            dmwxsin = getattr(self, sin_names[idx]).quantity
            dmwxcos = getattr(self, cos_names[idx]).quantity
            total_dm += dmwxsin * np.sin(arg) + dmwxcos * np.cos(arg)
        return total_dm

    def dmwavex_delay(self, toas, acc_delay=None):
        return self.dispersion_type_delay(toas)

    def d_dm_d_DMWXSIN(self, toas, param, acc_delay=None):
        """Derivative of the DM with respect to the DMWXSIN_ parameter ``param``."""
        par = getattr(self, param)
        freq = getattr(self, f"DMWXFREQ_{int(par.index):04d}").quantity
        arg = self._dmwavex_arg(freq, self._dmwavex_dt(toas))
        return np.sin(arg) * dmu / par.units

    def d_dm_d_DMWXCOS(self, toas, param, acc_delay=None):
        """Derivative of the DM with respect to the DMWXCOS_ parameter ``param``."""
        par = getattr(self, param)
        freq = getattr(self, f"DMWXFREQ_{int(par.index):04d}").quantity
        arg = self._dmwavex_arg(freq, self._dmwavex_dt(toas))
        return np.cos(arg) * dmu / par.units