from warnings import warn
import astropy.units as u
import numpy as np
from astropy.table import Table
from astropy.time import Time
from loguru import logger as log
from pint import DMconst
from pint.exceptions import MissingParameter
from pint.models.parameter import MJDParameter, floatParameter, prefixParameter
from pint.models.timing_model import DelayComponent, MissingParameter, MissingTOAs
from pint.toa_select import TOASelect
from pint.utils import split_prefixed_name, taylor_horner, taylor_horner_deriv
# Unit of the chromatic measure CM and of its CMX offsets: pc cm^-3 MHz^-2.
cmu = u.pc * u.cm**-3 * u.MHz**-2
class Chromatic(DelayComponent):
    """A base chromatic timing model with a constant chromatic index.

    Subclasses contribute to the chromatic measure (CM) by appending callables
    to ``cm_value_funcs`` and register per-parameter CM derivatives with
    :meth:`register_cm_deriv_funcs`.
    """

    def __init__(self):
        super().__init__()
        # Callables ``f(toas)`` whose CM contributions are summed by cm_value().
        self.cm_value_funcs = []
        # Parameter name -> list of callables ``f(toas, param_name)`` giving
        # d(CM)/d(param); consumed by d_delay_d_cmparam().
        self.cm_deriv_funcs = {}
        # NOTE(review): not read by any method in this component; confirm
        # whether it is still needed.
        self.alpha_deriv_funcs = {}

    def _get_chromatic_freqs(self, toas):
        """Return the radio frequencies used for chromatic delays.

        Uses the parent model's barycentric radio frequencies. If that is not
        available (AttributeError), warn and fall back to the topocentric
        ``freq`` column of the TOA table.
        """
        try:
            return self._parent.barycentric_radio_freq(toas)
        except AttributeError:
            warn("Using topocentric frequency for chromatic delay!")
            return toas.table["freq"]

    def chromatic_time_delay(self, cm, alpha, freq):
        """Return the chromatic time delay for a set of frequencies.

        delay_chrom = cm * DMconst * (freq / 1 MHz)**(-alpha)

        Parameters
        ----------
        cm : `astropy.units.Quantity`
            Chromatic measure value(s).
        alpha : float or `astropy.units.Quantity`
            Chromatic index (dimensionless).
        freq : `astropy.units.Quantity`
            Observing frequencies.

        Returns
        -------
        `astropy.units.Quantity`
            The delay, converted to seconds.
        """
        cmdelay = cm * DMconst * (freq / u.MHz) ** (-alpha)
        return cmdelay.to(u.s)

    def chromatic_type_delay(self, toas):
        """Return the chromatic delay at ``toas`` using the model's TNCHROMIDX."""
        bfreq = self._get_chromatic_freqs(toas)
        cm = self.cm_value(toas)
        alpha = self._parent["TNCHROMIDX"].quantity
        return self.chromatic_time_delay(cm, alpha, bfreq)

    def cm_value(self, toas):
        """Compute modeled CM value at given TOAs.

        Parameters
        ----------
        toas : `TOAs` object or TOA table(TOAs.table)
            If given a TOAs object, it will use the whole TOA table in the
            `TOAs` object.

        Returns
        -------
        CM values at given TOAs in the unit of CM.
        """
        toas_table = toas if isinstance(toas, Table) else toas.table
        cm = np.zeros(len(toas_table)) * self._parent.CM.units
        # Sum the contributions of every registered CM component.
        for cm_f in self.cm_value_funcs:
            cm += cm_f(toas)
        return cm

    def d_delay_d_cmparam(self, toas, param_name, acc_delay=None):
        """Derivative of delay with respect to a CM parameter.

        Parameters
        ----------
        toas : `pint.TOAs` object.
            Input toas.
        param_name : str
            Derivative parameter name
        acc_delay : `astropy.quantity` or `numpy.ndarray`
            Accumulated delay values. This parameter is to keep the unified API,
            but not used in this function.

        Returns
        -------
        `astropy.units.Quantity`
            ``DMconst * d(CM)/d(param) * (freq / 1 MHz)**(-TNCHROMIDX)``.

        Raises
        ------
        KeyError
            If no CM derivative function is registered for ``param_name``.
        """
        bfreq = self._get_chromatic_freqs(toas)
        param_unit = getattr(self, param_name).units
        d_cm_d_cmparam = np.zeros(toas.ntoas) * cmu / param_unit
        # Chain rule: accumulate d(CM)/d(param) from every registered source.
        for df in self.cm_deriv_funcs[param_name]:
            d_cm_d_cmparam += df(toas, param_name)
        alpha = self._parent["TNCHROMIDX"].quantity
        return DMconst * d_cm_d_cmparam * (bfreq / u.MHz) ** (-alpha)

    def register_cm_deriv_funcs(self, func, param):
        """Register the derivative function in to the deriv_func dictionaries.

        Registering the same function twice for a parameter is a no-op.

        Parameters
        ----------
        func : callable
            Calculates the derivative
        param : str
            Name of parameter the derivative is with respect to
        """
        pn = self.match_param_aliases(param)
        funcs = self.cm_deriv_funcs.setdefault(pn, [])
        if func not in funcs:
            funcs.append(func)
class ChromaticCM(Chromatic):
    """Simple chromatic delay model with a constant chromatic index.

    This model uses Taylor expansion to model CM variation over time. It
    can also be used for a constant CM.

    Fitting for the chromatic index is not supported because the fit is too
    unstable when fit simultaneously with the DM.

    Parameters supported:

    .. paramtable::
        :class: pint.models.chromatic_model.ChromaticCM
    """

    register = True
    category = "chromatic_constant"

    def __init__(self):
        super().__init__()
        self.add_param(
            floatParameter(
                name="CM",
                units=cmu,
                value=0.0,
                description="Chromatic measure",
                long_double=True,
                convert_tcb2tdb=False,
            )
        )
        self.add_param(
            prefixParameter(
                name="CM1",
                units=cmu / u.year,
                description="First order time derivative of the chromatic measure",
                unit_template=self.CM_derivative_unit,
                description_template=self.CM_derivative_description,
                type_match="float",
                long_double=True,
                convert_tcb2tdb=False,
            )
        )
        self.add_param(
            MJDParameter(
                name="CMEPOCH",
                description="Epoch of CM measurement",
                time_scale="tdb",
                convert_tcb2tdb=False,
            )
        )
        self.add_param(
            floatParameter(
                name="TNCHROMIDX",
                units=u.dimensionless_unscaled,
                value=4.0,
                description="Chromatic measure index",
                long_double=True,
                convert_tcb2tdb=False,
            )
        )
        self.cm_value_funcs += [self.base_cm]
        self.delay_funcs_component += [self.constant_chromatic_delay]

    def setup(self):
        """Register delay and CM derivatives for CM and every CMn term."""
        super().setup()
        base_cms = list(self.get_prefix_mapping_component("CM").values())
        base_cms += ["CM"]
        for cm_name in base_cms:
            self.register_deriv_funcs(self.d_delay_d_cmparam, cm_name)
            self.register_cm_deriv_funcs(self.d_cm_d_CMs, cm_name)

    def validate(self):
        """Validate the CM parameters input.

        If CM1 is non-zero and CMEPOCH is unset, CMEPOCH defaults to PEPOCH.

        Raises
        ------
        MissingParameter
            If CM1 is non-zero and neither CMEPOCH nor PEPOCH is set.
        """
        super().validate()
        # If CM1 is set, we need CMEPOCH
        if (
            self.CM1.value is not None
            and self.CM1.value != 0.0
            and self.CMEPOCH.value is None
        ):
            if self._parent.PEPOCH.value is not None:
                self.CMEPOCH.value = self._parent.PEPOCH.value
            else:
                raise MissingParameter(
                    "Chromatic",
                    "CMEPOCH",
                    "CMEPOCH or PEPOCH is required if CM1 or higher are set",
                )

    def CM_derivative_unit(self, n):
        """Unit string for the n'th time derivative of CM."""
        return f"pc cm^-3 MHz^-2 / yr^{n:d}" if n else "pc cm^-3 MHz^-2"

    def CM_derivative_description(self, n):
        """Description string for the n'th time derivative of CM."""
        return f"{n:d}'th time derivative of the chromatic measure"

    def get_CM_terms(self):
        """Return a list of CM term values in the model: [CM, CM1, ..., CMn]"""
        return [self.CM.quantity] + self._parent.get_prefix_list("CM", start_index=1)

    def base_cm(self, toas):
        """Evaluate the CM Taylor series at ``toas``.

        The expansion variable is ``tdbld - CMEPOCH`` in years. When every
        derivative term is zero it is taken as zero, so CMEPOCH is not needed.

        Raises
        ------
        ValueError
            If a derivative term is non-zero but CMEPOCH is unset.
        """
        cm_terms = self.get_CM_terms()
        if any(cmi.value != 0 for cmi in cm_terms[1:]):
            CMEPOCH = self.CMEPOCH.value
            if CMEPOCH is None:
                # Should be ruled out by validate()
                raise ValueError(
                    f"CMEPOCH not set but some derivatives are not zero: {cm_terms}"
                )
            dt = (toas["tdbld"] - CMEPOCH) * u.day
            dt_value = dt.to_value(u.yr)
        else:
            dt_value = np.zeros(len(toas), dtype=np.longdouble)
        cm = taylor_horner(dt_value, [c.value for c in cm_terms])
        return cm * cmu

    def alpha_value(self, toas):
        """Return the (constant) chromatic index TNCHROMIDX for each TOA."""
        # The chromatic index parameter of this component is TNCHROMIDX;
        # there is no CMIDX parameter.
        return np.ones(len(toas)) * self.TNCHROMIDX.quantity

    def constant_chromatic_delay(self, toas, acc_delay=None):
        """This is a wrapper function for interacting with the TimingModel class"""
        return self.chromatic_type_delay(toas)

    def print_par(self, format="pint"):
        """Return par-file lines: CM and its derivatives first, then the rest."""
        prefix_cm = list(self.get_prefix_mapping_component("CM").values())
        cms = ["CM"] + prefix_cm
        result = "".join(getattr(self, cm).as_parfile_line(format=format) for cm in cms)
        if hasattr(self, "components"):
            all_params = self.components["ChromaticCM"].params
        else:
            all_params = self.params
        for pm in all_params:
            if pm not in cms:
                result += getattr(self, pm).as_parfile_line(format=format)
        return result

    def d_cm_d_CMs(self, toas, param_name, acc_delay=None):
        """Derivatives of CM wrt the CM taylor expansion coefficients.

        Evaluates ``taylor_horner`` with a unit coefficient at the order of
        ``param_name`` (0 for CM, n for CMn) and zeros elsewhere.

        Raises
        ------
        ValueError
            If CMEPOCH is unset while a derivative term is non-zero.
        """
        par = getattr(self, param_name)
        if param_name == "CM":
            order = 0
        else:
            _, _, order = split_prefixed_name(param_name)
        cms = self.get_CM_terms()
        cm_terms = np.longdouble(np.zeros(len(cms)))
        cm_terms[order] = np.longdouble(1.0)
        if self.CMEPOCH.value is None:
            if any(t.value != 0 for t in cms[1:]):
                # Should be ruled out by validate()
                raise ValueError(f"CMEPOCH is not set but {param_name} is not zero")
            CMEPOCH = 0
        else:
            CMEPOCH = self.CMEPOCH.value
        dt = (toas["tdbld"] - CMEPOCH) * u.day
        dt_value = dt.to_value(u.yr)
        return taylor_horner(dt_value, cm_terms) * (cmu / par.units)

    def change_cmepoch(self, new_epoch):
        """Change CMEPOCH to a new value and update CM accordingly.

        Parameters
        ----------
        new_epoch: float MJD (in TDB) or `astropy.Time` object
            The new CMEPOCH value.

        Raises
        ------
        ValueError
            If CMEPOCH is unset while some CM derivative is non-zero.
        """
        if isinstance(new_epoch, Time):
            new_epoch = Time(new_epoch, scale="tdb", precision=9)
        else:
            new_epoch = Time(new_epoch, scale="tdb", format="mjd", precision=9)
        # A leading zero term lets taylor_horner_deriv(..., deriv_order=n + 1)
        # reproduce the n'th coefficient re-expanded about the new epoch.
        cmterms = [0.0 * u.Unit("")] + self.get_CM_terms()
        if self.CMEPOCH.value is None:
            if any(d.value != 0 for d in cmterms[2:]):
                # Should be ruled out by validate()
                raise ValueError(
                    f"CMEPOCH not set but some CM derivatives are not zero: {cmterms}"
                )
            self.CMEPOCH.value = new_epoch
        cmepoch_ld = self.CMEPOCH.quantity.tdb.mjd_long
        dt = (new_epoch.tdb.mjd_long - cmepoch_ld) * u.day
        for n in range(len(cmterms) - 1):
            cur_deriv = self.CM if n == 0 else getattr(self, f"CM{n}")
            cur_deriv.value = taylor_horner_deriv(
                dt.to(u.yr), cmterms, deriv_order=n + 1
            )
        self.CMEPOCH.value = new_epoch
class ChromaticCMX(Chromatic):
    """This class provides a CMX model - piecewise-constant chromatic variations with constant
    chromatic index.

    This model lets the user specify time ranges and fit for a different CMX value in each time range.

    It should be used in combination with the `ChromaticCM` model. Specifically, TNCHROMIDX must be
    set.

    Parameters supported:

    .. paramtable::
        :class: pint.models.chromatic_model.ChromaticCMX
    """

    register = True
    category = "chromatic_cmx"

    def __init__(self):
        super().__init__()
        # Default free range CMX_0001 with value 0 and unset bounds.
        self.add_CMX_range(None, None, cmx=0, frozen=False, index=1)
        self.cm_value_funcs += [self.cmx_cm]
        self.set_special_params(["CMX_0001", "CMXR1_0001", "CMXR2_0001"])
        self.delay_funcs_component += [self.CMX_chromatic_delay]

    @staticmethod
    def _cmx_mjd_as_float(mjd):
        """Convert an MJD given as `Time` or `Quantity` to a number; pass others (incl. None) through."""
        if isinstance(mjd, Time):
            return mjd.mjd
        if isinstance(mjd, u.quantity.Quantity):
            return mjd.value
        return mjd

    @staticmethod
    def _cmx_as_float(cmx):
        """Convert a CMX value given as `Quantity` to a number in CM units (pc cm^-3 MHz^-2)."""
        if isinstance(cmx, u.quantity.Quantity):
            return cmx.to_value(cmu)
        return cmx

    @staticmethod
    def _check_cmx_bounds(mjd_start, mjd_end):
        """Raise ValueError unless both bounds are unset or both set with start <= end."""
        if mjd_end is not None and mjd_start is not None:
            if mjd_end < mjd_start:
                raise ValueError("Starting MJD is greater than ending MJD.")
        elif mjd_start != mjd_end:
            raise ValueError("Only one MJD bound is set.")

    def _add_cmx_params(self, i, cmx, mjd_start, mjd_end, frozen):
        """Add the CMX_/CMXR1_/CMXR2_ parameter triplet for the 4-digit label ``i``."""
        self.add_param(
            prefixParameter(
                name=f"CMX_{i}",
                units=cmu,
                value=cmx,
                description="Chromatic measure variation",
                parameter_type="float",
                frozen=frozen,
                convert_tcb2tdb=False,
            )
        )
        self.add_param(
            prefixParameter(
                name=f"CMXR1_{i}",
                units="MJD",
                description="Beginning of CMX interval",
                parameter_type="MJD",
                time_scale="utc",
                value=mjd_start,
                convert_tcb2tdb=False,
            )
        )
        self.add_param(
            prefixParameter(
                name=f"CMXR2_{i}",
                units="MJD",
                description="End of CMX interval",
                parameter_type="MJD",
                time_scale="utc",
                value=mjd_end,
                convert_tcb2tdb=False,
            )
        )

    def add_CMX_range(self, mjd_start, mjd_end, index=None, cmx=0, frozen=True):
        """Add CMX range to a chromatic model with specified start/end MJDs and CMX value.

        Parameters
        ----------
        mjd_start : float or astropy.quantity.Quantity or astropy.time.Time
            MJD for beginning of CMX event.
        mjd_end : float or astropy.quantity.Quantity or astropy.time.Time
            MJD for end of CMX event.
        index : int, None
            Integer label for CMX event. If None, will increment largest used index by 1.
        cmx : float or astropy.quantity.Quantity
            Change in CM during CMX event.
        frozen : bool
            Indicates whether CMX will be fit.

        Returns
        -------
        index : int
            Index that has been assigned to new CMX event.

        Raises
        ------
        ValueError
            If the bounds are inconsistent or the index is already in use.
        """
        if index is None:
            # One past the largest index in use (1 if there are none).
            dct = self.get_prefix_mapping_component("CMX_")
            index = max(dct, default=0) + 1
        i = f"{int(index):04d}"
        mjd_start = self._cmx_mjd_as_float(mjd_start)
        mjd_end = self._cmx_mjd_as_float(mjd_end)
        self._check_cmx_bounds(mjd_start, mjd_end)
        if int(index) in self.get_prefix_mapping_component("CMX_"):
            raise ValueError(
                f"Index '{index}' is already in use in this model. Please choose another."
            )
        self._add_cmx_params(i, self._cmx_as_float(cmx), mjd_start, mjd_end, frozen)
        self.setup()
        self.validate()
        return index

    def add_CMX_ranges(self, mjd_starts, mjd_ends, indices=None, cmxs=0, frozens=True):
        """Add CMX ranges to a chromatic model with specified start/end MJDs and CMXs.

        All ranges are checked before any parameter is added, so an invalid
        entry leaves the model unchanged.

        Parameters
        ----------
        mjd_starts : iterable of float or astropy.quantity.Quantity or astropy.time.Time
            MJD for beginning of CMX event.
        mjd_ends : iterable of float or astropy.quantity.Quantity or astropy.time.Time
            MJD for end of CMX event.
        indices : iterable of int, None
            Integer label for CMX event. If None, will increment largest used index by 1.
        cmxs : iterable of float or astropy.quantity.Quantity, or float or astropy.quantity.Quantity
            Change in CM during CMX event.
        frozens : iterable of bool or bool
            Indicates whether CMX will be fit.

        Returns
        -------
        indices : list
            Indices that has been assigned to new CMX events

        Raises
        ------
        ValueError
            If argument lengths disagree, bounds are inconsistent, or an index
            is already in use (including duplicates within this call).
        """
        n_ranges = len(mjd_starts)
        if n_ranges != len(mjd_ends):
            raise ValueError(
                f"Number of mjd_start values {len(mjd_starts)} must match number of mjd_end values {len(mjd_ends)}"
            )
        if indices is None:
            indices = [None] * n_ranges
        cmxs = np.atleast_1d(cmxs)
        if len(cmxs) == 1:
            cmxs = np.repeat(cmxs, n_ranges)
        if len(cmxs) != n_ranges:
            raise ValueError(
                f"Number of mjd_start values {len(mjd_starts)} must match number of cmx values {len(cmxs)}"
            )
        frozens = np.atleast_1d(frozens)
        if len(frozens) == 1:
            frozens = np.repeat(frozens, n_ranges)
        if len(frozens) != n_ranges:
            raise ValueError(
                f"Number of mjd_start values {len(mjd_starts)} must match number of frozen values {len(frozens)}"
            )

        dct = self.get_prefix_mapping_component("CMX_")
        used = set(dct)
        last_index = max(dct, default=0)
        # First pass: resolve indices and check every range before touching the model.
        pending = []
        for mjd_start, mjd_end, index, cmx, frozen in zip(
            mjd_starts, mjd_ends, indices, cmxs, frozens
        ):
            if index is None:
                # Auto-number, skipping indices claimed earlier in this call.
                last_index += 1
                while last_index in used:
                    last_index += 1
                index = last_index
            elif int(index) in used:
                raise ValueError(
                    f"Attempting to insert CMX_{int(index):04d} but it already exists"
                )
            used.add(int(index))
            mjd_start = self._cmx_mjd_as_float(mjd_start)
            mjd_end = self._cmx_mjd_as_float(mjd_end)
            self._check_cmx_bounds(mjd_start, mjd_end)
            pending.append((index, mjd_start, mjd_end, self._cmx_as_float(cmx), frozen))

        # Second pass: add the parameters.
        added_indices = []
        for index, mjd_start, mjd_end, cmx, frozen in pending:
            i = f"{int(index):04d}"
            log.trace(f"Adding CMX_{i} from MJD {mjd_start} to MJD {mjd_end}")
            self._add_cmx_params(i, cmx, mjd_start, mjd_end, frozen)
            added_indices.append(index)
        self.setup()
        self.validate()
        return added_indices

    def remove_CMX_range(self, index):
        """Removes all CMX parameters associated with a given index/list of indices.

        Parameters
        ----------
        index : float, int, list, tuple, set, np.ndarray
            Number or collection of numbers corresponding to CMX indices to be removed from model.

        Raises
        ------
        TypeError
            If ``index`` is not a number or a supported collection.
        """
        if isinstance(index, (int, float, np.integer, np.floating)):
            indices = [index]
        elif isinstance(index, (list, tuple, set, np.ndarray)):
            indices = index
        else:
            raise TypeError(
                f"index must be a float, int, set, list, or array - not {type(index)}"
            )
        for idx in indices:
            index_rf = f"{int(idx):04d}"
            for prefix in ["CMX_", "CMXR1_", "CMXR2_"]:
                self.remove_param(prefix + index_rf)
        self.validate()

    def get_indices(self):
        """Returns an array of integers corresponding to CMX parameters.

        Returns
        -------
        inds : np.ndarray
            Array of CMX indices in model.
        """
        inds = [int(p.split("_")[-1]) for p in self.params if "CMX_" in p]
        return np.array(inds)

    def setup(self):
        """Register delay and CM derivatives for every CMX_ parameter."""
        super().setup()
        for prefix_par in self.get_params_of_type("prefixParameter"):
            if prefix_par.startswith("CMX_"):
                self.register_deriv_funcs(self.d_delay_d_cmparam, prefix_par)
                self.register_cm_deriv_funcs(self.d_cm_d_CMX, prefix_par)

    def validate(self):
        """Validate the CMX parameters.

        Raises ValueError when the CMX_/CMXR1_/CMXR2_ index sets differ and
        logs a warning for every pair of overlapping ranges.
        """
        super().validate()
        CMX_mapping = self.get_prefix_mapping_component("CMX_")
        CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_")
        CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_")
        if CMX_mapping.keys() != CMXR1_mapping.keys():
            # FIXME: report mismatch
            raise ValueError(
                "CMX_ parameters do not "
                "match CMXR1_ parameters. "
                "Please check your prefixed parameters."
            )
        if CMX_mapping.keys() != CMXR2_mapping.keys():
            raise ValueError(
                "CMX_ parameters do not "
                "match CMXR2_ parameters. "
                "Please check your prefixed parameters."
            )
        r1 = np.zeros(len(CMX_mapping))
        r2 = np.zeros(len(CMX_mapping))
        indices = np.zeros(len(CMX_mapping), dtype=np.int32)
        for j, index in enumerate(CMX_mapping):
            start = getattr(self, f"CMXR1_{index:04d}").quantity
            end = getattr(self, f"CMXR2_{index:04d}").quantity
            # Ranges with unset bounds stay at (0, 0) and never overlap.
            if start is not None and end is not None:
                r1[j] = start.mjd
                r2[j] = end.mjd
                indices[j] = index
        # Iterate the same mapping used to fill r1/r2 so that position j
        # always refers to the same CMX index.
        for j, index in enumerate(CMX_mapping):
            if np.any((r1[j] > r1) & (r1[j] < r2)):
                k = np.where((r1[j] > r1) & (r1[j] < r2))[0]
                for kk in k.flatten():
                    log.warning(
                        f"Start of CMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with CMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})"
                    )
            if np.any((r2[j] > r1) & (r2[j] < r2)):
                k = np.where((r2[j] > r1) & (r2[j] < r2))[0]
                for kk in k.flatten():
                    log.warning(
                        f"End of CMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with CMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})"
                    )

    def validate_toas(self, toas):
        """Raise MissingTOAs listing every free CMX_ parameter whose range contains no TOAs."""
        CMX_mapping = self.get_prefix_mapping_component("CMX_")
        CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_")
        CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_")
        # TOA MJDs do not depend on the range; compute them once.
        mjds = toas.get_mjds()
        bad_parameters = []
        for k in CMXR1_mapping:
            if self._parent[CMX_mapping[k]].frozen:
                continue
            b = self._parent[CMXR1_mapping[k]].quantity.mjd * u.d
            e = self._parent[CMXR2_mapping[k]].quantity.mjd * u.d
            if not np.any((b <= mjds) & (mjds < e)):
                bad_parameters.append(CMX_mapping[k])
        if bad_parameters:
            raise MissingTOAs(bad_parameters)

    def cmx_cm(self, toas):
        """Return the piecewise-constant CMX contribution to CM at each TOA.

        Each TOA whose ``mjd_float`` is selected by a CMXR1_/CMXR2_ range gets
        that range's CMX_ value added; other TOAs get zero.
        """
        condition = {}
        tbl = toas.table
        if not hasattr(self, "cmx_toas_selector"):
            self.cmx_toas_selector = TOASelect(is_range=True)
        CMX_mapping = self.get_prefix_mapping_component("CMX_")
        CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_")
        CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_")
        for epoch_ind in CMX_mapping.keys():
            r1 = getattr(self, CMXR1_mapping[epoch_ind]).quantity
            r2 = getattr(self, CMXR2_mapping[epoch_ind]).quantity
            condition[CMX_mapping[epoch_ind]] = (r1.mjd, r2.mjd)
        select_idx = self.cmx_toas_selector.get_select_index(
            condition, tbl["mjd_float"]
        )
        # Get CMX delays
        cm = np.zeros(len(tbl)) * self._parent.CM.units
        for k, v in select_idx.items():
            cm[v] += getattr(self, k).quantity
        return cm

    def CMX_chromatic_delay(self, toas, acc_delay=None):
        """This is a wrapper function for interacting with the TimingModel class"""
        return self.chromatic_type_delay(toas)

    def d_cm_d_CMX(self, toas, param_name, acc_delay=None):
        """Derivative of CM wrt one CMX_ parameter: 1 inside its range, 0 elsewhere (dimensionless)."""
        tbl = toas.table
        if not hasattr(self, "cmx_toas_selector"):
            self.cmx_toas_selector = TOASelect(is_range=True)
        param = getattr(self, param_name)
        cmx_index = param.index
        CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_")
        CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_")
        r1 = getattr(self, CMXR1_mapping[cmx_index]).quantity
        r2 = getattr(self, CMXR2_mapping[cmx_index]).quantity
        condition = {param_name: (r1.mjd, r2.mjd)}
        select_idx = self.cmx_toas_selector.get_select_index(
            condition, tbl["mjd_float"]
        )
        cmx = np.zeros(len(tbl))
        for v in select_idx.values():
            cmx[v] = 1.0
        return cmx * u.dimensionless_unscaled

    def print_par(self, format="pint"):
        """Return par-file lines for each CMX range, ordered by index."""
        result = ""
        CMX_mapping = self.get_prefix_mapping_component("CMX_")
        CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_")
        CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_")
        for ii in sorted(CMX_mapping.keys()):
            result += getattr(self, CMX_mapping[ii]).as_parfile_line(format=format)
            result += getattr(self, CMXR1_mapping[ii]).as_parfile_line(format=format)
            result += getattr(self, CMXR2_mapping[ii]).as_parfile_line(format=format)
        return result