Skip to content

litmus.litmusclass

litmus.py

Contains the main litmus object class, which acts as a user-friendly interface to the statistical models and fitting procedures. In future versions, this will also give access to the GUI.

TODO: - Re-do this entire class to take multiple models instead of multiple lightcurves. - Possibly add HDF5 saving for chain output. - Maybe add save_litmus() with pickling? - Improve handling of the "fitting method inherit" feature, especially during the refactor / redo.

LITMUS(fitproc: fitting_procedure = None)

Bases: logger

A front-facing UI class for interfacing with the fitting procedures.

Source code in litmus/litmusclass.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def __init__(self, fitproc: fitting_procedure = None):
    """
    Builds a LITMUS object around a fitting procedure.

    :param fitproc: The fitting procedure to use. Its stat_model becomes self.model. If None, a
        fitting_methods.hessian_scan wrapped around a models.GP_simple() model is constructed instead.

    If fitproc has already been run, its samples are drawn immediately and added to the
    ChainConsumer object self.C.
    """

    logger.__init__(self)
    # ----------------------------

    if fitproc is None:
        self.msg_err("Didn't set a statistical model, using GP_simple")
        self.model = models.GP_simple()

        self.msg_err("Didn't set a fitting method, using hessian scan")

        fitproc = fitting_methods.hessian_scan(stat_model=self.model)

    # The model is always taken from the fitting procedure so the two cannot disagree
    self.model = fitproc.stat_model
    self.fitproc = fitproc

    # ----------------------------
    self.lightcurves = []  # lightcurves added via add_lightcurve(), indexed by fit() / prefit()
    self.data = None  # set by prefit() / fit() from model.lc_to_data()

    self.Nsamples = 50_000  # default number of samples drawn from the prior / fitting procedure
    self.samples = {}
    self.prior_samples = self.model.prior_sample(self.Nsamples)
    self.C = ChainConsumer()

    self.C.set_override(ChainConfig(smooth=0, linewidth=2, plot_cloud=True, shade_alpha=0.5))

    # self.C.add_chain(Chain(samples=DataFrame.from_dict(self.prior_samples), name="Prior", color='gray'))
    if self.fitproc.has_run:
        # Draw the samples exactly once; a second draw would only overwrite the first
        self.samples = self.fitproc.get_samples(self.Nsamples)
        # The lightcurve indices used by a pre-run fit are unknown here, so no "i-j" label is possible
        self.C.add_chain(Chain(samples=DataFrame.from_dict(self.samples), name="Lightcurves (pre-run)"))
        self.msg_err("Warning! LITMUS object built on pre-run fitting_procedure. May have unexpected behaviour.")

    return
model = fitproc.stat_model instance-attribute
fitproc = fitproc instance-attribute
lightcurves = [] instance-attribute
data = None instance-attribute
Nsamples = 50000 instance-attribute
samples = {} instance-attribute
prior_samples = self.model.prior_sample(self.Nsamples) instance-attribute
C = ChainConsumer() instance-attribute
add_lightcurve(lc: lightcurve)

Add a lightcurve 'lc' to the LITMUS object

Source code in litmus/litmusclass.py
93
94
95
96
97
98
def add_lightcurve(self, lc: lightcurve):
    """
    Appends the lightcurve 'lc' to the end of this LITMUS object's list of lightcurves.

    :param lc: lightcurve to store. It can later be referred to by its index in self.lightcurves
    """
    self.lightcurves.extend((lc,))
    return None
remove_lightcurve(i: int) -> None

Remove lightcurve of index 'i' from the LITMUS object

Source code in litmus/litmusclass.py
100
101
102
103
104
105
106
107
108
109
110
def remove_lightcurve(self, i: int) -> None:
    """
    Remove lightcurve of index 'i' from the LITMUS object.

    Negative indices count from the end, as for a Python list. Out-of-range indices (in either
    direction) are reported via msg_err and otherwise ignored.

    :param i: index of the lightcurve to remove
    """
    N = len(self.lightcurves)

    # Check both bounds: `i < N` alone lets e.g. i = -N-1 through to `del`, which then raises IndexError
    if -N <= i < N:
        del self.lightcurves[i]
    else:
        self.msg_err("Tried to delete lightcurve %i but only have %i lightcurves. Skipping" % (i, N))
    return
prefit(i=0, j=1)

Runs the fitting method's pre-fit step on lightcurves i and j, and stores the converted data in self.data.

Source code in litmus/litmusclass.py
114
115
116
117
118
119
120
121
122
def prefit(self, i=0, j=1) -> None:
    """
    Runs the fitting method's pre-fit step on lightcurves i and j.

    The lightcurve pair is also converted into the stat model's data format and stored in self.data.
    Unlike fit(), no samples are drawn and no chain is added.

    :param i: index of the first lightcurve in self.lightcurves
    :param j: index of the second lightcurve in self.lightcurves
    """

    lc_1, lc_2 = self.lightcurves[i], self.lightcurves[j]
    self.data = self.model.lc_to_data(lc_1, lc_2)

    self.fitproc.prefit(lc_1, lc_2)
fit(i=0, j=1) -> None

Performs the full fit for the chosen stats model and fitting method.

Source code in litmus/litmusclass.py
124
125
126
127
128
129
130
131
132
133
134
135
def fit(self, i=0, j=1) -> None:
    """
    Fits lightcurves i and j with the stat model and fitting method.

    The fitting procedure is run on the pair and self.Nsamples samples are drawn into self.samples.
    Those samples are then added to self.C as a chain labelled "Lightcurves i-j".

    :param i: index of the first lightcurve in self.lightcurves
    :param j: index of the second lightcurve in self.lightcurves
    """
    first, second = self.lightcurves[i], self.lightcurves[j]
    self.data = self.model.lc_to_data(first, second)

    self.fitproc.fit(first, second)
    self.samples = self.fitproc.get_samples(self.Nsamples)

    frame = DataFrame.from_dict(self.samples)
    label = "Lightcurves %i-%i" % (i, j)
    self.C.add_chain(Chain(samples=frame, name=label))
save_chain(path: str = None, headings: bool = True) -> None

Saves the litmus's output chains to a .csv file at "path". If headings=True (default) then the names of the parameters will be written to the first row of the file.

todo - this needs updating
Source code in litmus/litmusclass.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
def save_chain(self, path: str = None, headings: bool = True) -> None:
    """
    Saves the litmus's output chains (self.samples) to a .csv file at "path".

    :param path: target file. If None, defaults to "./<model name>_<fitproc name>.csv".
        A ".csv" extension is appended if 'path' does not already end with one.
    :param headings: If True (default), the names of the parameters are written to the first row of the file

    Each subsequent row holds one sample, with one column per parameter.
    #todo - this needs updating
    """
    if path is None:
        path = "./%s_%s.csv" % (self.model.name, self.fitproc.name)
    # Apply the extension rule to user-supplied paths as well (the default already ends in .csv)
    if not str(path).endswith(".csv"):
        path = "%s.csv" % path

    rows = zip(*self.samples.values())
    with open(path, mode="w", newline="") as file:
        writer = csv.writer(file)
        # Write header
        if headings:
            writer.writerow(self.samples.keys())
        # Write rows
        writer.writerows(rows)
read_chain(path: str, header: _types.Iterable[str] | None = None)
todo needs updating
Source code in litmus/litmusclass.py
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
def read_chain(self, path: str, header: _types.Iterable[str] | None = None):
    """
    Loads a chain from the .csv file at 'path' into self.samples.

    :param path: path to a .csv file whose first row holds the column names (as read by pd.read_csv)
    :param header: (optional) names of the columns to load. If None, every column in the file is loaded.

    The chain is only accepted if every loaded column name is one of the stat model's paramnames().
    Otherwise an error is reported via msg_err and self.samples is left unchanged.

    #todo needs updating
    """
    # Reading the CSV into a DataFrame
    df = pd.read_csv(path)

    # list() rather than .copy(): any iterable of names (tuple, generator, ...) is accepted
    keys = list(df.columns) if header is None else list(header)

    # Converting DataFrame to dictionary of numpy arrays
    out = {col: df[col].to_numpy() for col in keys}

    if out.keys() <= set(self.fitproc.stat_model.paramnames()):
        self.samples = out
        self.msg_run("Loaded chain /w headings", *keys)
    else:
        self.msg_err("Tried to load chain with different parameter names to model")
config(**kwargs)

Quick and easy way to pass arguments to the chainconsumer object: the keyword arguments are applied to it as a ChainConfig override.

Source code in litmus/litmusclass.py
176
177
178
179
180
181
def config(self, **kwargs):
    '''
    Passes keyword arguments to the chainconsumer object as a ChainConfig override.

    Every keyword in **kwargs is forwarded unchanged to ChainConfig(...). The resulting config is
    applied to self.C via set_override().
    '''
    override = ChainConfig(**kwargs)
    self.C.set_override(override)
plot_lightcurves(model_no: int = 0, Nsamples: int = 1, Tspan: None | list[float, float] = None, Nplot: int = 1024, dir: str | None = None, show: bool = True) -> matplotlib.figure.Figure()

Plots the interpolated lightcurves for one of the fitted models

Parameters:

Name Type Description Default
model_no int

Which model to plot the lightcurves for

0
Nsamples int

Number of posterior samples to draw from when plotting

1
Tspan None | list[float, float]

Span of time values to plot over. If None, will use the max / min times of lc_1 and lc_2

None
Nplot int

Number of points in the interpolated lightcurve

1024
dir str | None

If not None, will save to this filepath

None
show bool

If True, will plt.show() the plot

True
Source code in litmus/litmusclass.py
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
def plot_lightcurves(self, model_no: int = 0, Nsamples: int = 1, Tspan: None | list[float, float] = None,
                     Nplot: int = 1024,
                     dir: str | None = None, show: bool = True) -> matplotlib.figure.Figure:
    """
    Plots the interpolated lightcurves for one of the fitted models
    :param model_no: Which model to plot the lightcurves for
    :parameter Nsamples: Number of posterior samples to draw from when plotting
    :parameter Tspan: Span of time values to plot over. If None, will use the max / min times of lc_1 and lc_2
    :parameter Nplot: Number of points in the interpolated lightcurve
    :parameter dir: If not None, will save to this filepath
    :parameter show: If True, will plt.show() the plot

    NOTE: not yet implemented. Currently this reports an error via msg_err and returns an empty
    figure. All arguments are ignored.
    """
    # The return annotation is the Figure *class*. Writing `Figure()` would build a figure at import time.
    self.msg_err("plot_lightcurve() not yet implemented")
    fig = plt.figure()
    return fig
plot_parameters(model_no: int | None = None, Nsamples: int = None, CC_kwargs: dict = {}, show: bool = True, prior_extents: bool = False, dir: str | None = None) -> matplotlib.figure.Figure

Creates a nicely formatted chainconsumer plot of the parameters

Parameters:

Name Type Description Default
model_no int | None

Which model to plot the lightcurves for. If None, will plot for all

None
Nsamples int

Number of posterior samples to draw from when plotting

None
CC_kwargs dict

Keyword arguments to pass to the chainconsumer PlotConfig

{}
show bool

If True, will show the plot

True
prior_extents bool

If True, will use the model prior range for the axes limits (Defaults to false if multiple models used)

False
dir str | None

If not None, will save to this filepath. Returns the matplotlib figure. (todo - refactor for multi-model implementation)

None
Source code in litmus/litmusclass.py
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
def plot_parameters(self, model_no: int | None = None, Nsamples: int = None, CC_kwargs: dict | None = None,
                    show: bool = True,
                    prior_extents: bool = False, dir: str | None = None) -> matplotlib.figure.Figure:
    """
    Creates a nicely formatted chainconsumer plot of the parameters
    :param model_no: Which model to plot the lightcurves for. If None, will plot for all.
        NOTE(review): currently unused - only the model in self.model is plotted
    :param Nsamples: Number of posterior samples to draw from when plotting. If None (or equal to
        self.Nsamples), the existing chains in self.C / self.samples are reused. Otherwise a fresh
        set of samples is drawn from the fitting procedure into a new ChainConsumer.
    :parameter CC_kwargs: Keyword arguments to pass to the chainconsumer PlotConfig
    :parameter show: If True, will show the plot
    :parameter prior_extents: If True, will use the model prior range for the axes limits (Defaults to false if multiple models used)
    :parameter dir: If not None, will save to this filepath
    Returns the matplotlib figure. Free parameters whose samples have zero spread are left out;
    if none remain, an empty figure is returned.

    # todo - refactor for multi-model implementation
    """
    if CC_kwargs is None:
        CC_kwargs = {}

    if Nsamples is not None and Nsamples != self.Nsamples:
        C = ChainConsumer()
        # CC_kwargs are plot settings, not sampler options, so they are not forwarded here
        samples = self.fitproc.get_samples(Nsamples)
        C.add_chain(Chain(samples=DataFrame.from_dict(samples), name='samples'))
    else:
        C = self.C
        samples = self.samples

    if prior_extents:
        _config = PlotConfig(extents=self.model.prior_ranges, summarise=True,
                             **CC_kwargs)
    else:
        _config = PlotConfig(summarise=True,
                             **CC_kwargs)
    C.plotter.set_config(_config)

    # Leave out parameters with zero spread in the samples actually being plotted.
    # np.ptp() is used because NumPy 2.0 removed the ndarray.ptp() method.
    params_toplot = [param for param in self.model.free_params() if np.ptp(samples[param]) != 0]
    if len(params_toplot) == 0:
        fig = plt.figure()
        if dir is not None:
            fig.savefig(dir)
        if show: plt.show()
        return fig

    try:
        fig = C.plotter.plot(columns=params_toplot)
    except Exception as err:
        # Best-effort: report the failure and still hand back a figure
        self.msg_err("plot_parameters() failed to plot: %s" % repr(err))
        fig = plt.figure()
        fig.text(0.5, 0.5, "Something wrong with plotter")
    fig.tight_layout()

    # Save this specific figure (plt.savefig would target whichever figure is "current")
    if dir is not None:
        fig.savefig(dir)
    if show: fig.show()

    return fig
lag_plot(Nsamples: int = None, show: bool = True, extras: bool = True, prior_extents=False, dir: str | None = None) -> matplotlib.figure.Figure

Creates a nicely formatted chainconsumer plot of the marginalized lag plot

Parameters:

Name Type Description Default
Nsamples int

Number of posterior samples to draw from when plotting

None
show bool

If True, will show the plot

True
extras bool

If True, will add any fitting method specific extras to the plot

True
dir str | None

If not None, will save to this filepath

None
prior_extents

If True, will use the model prior range for the axes limits (defaults to False if multiple models are used). Returns the matplotlib figure.

False
Source code in litmus/litmusclass.py
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
def lag_plot(self, Nsamples: int = None, show: bool = True, extras: bool = True, prior_extents=False,
             dir: str | None = None, ) -> matplotlib.figure.Figure:
    """
    Creates a nicely formatted chainconsumer plot of the marginalized lag plot
    :param Nsamples: Number of posterior samples to draw from when plotting. If None (or equal to
        self.Nsamples), the existing chains in self.C are reused
    :parameter show: If True, will show the plot
    :parameter extras: If True, will add any fitting method specific extras to the plot. Currently only
        hessian_scan has extras: its normalised lag distribution from _get_slices('lags', 'logZ'),
        plus markers at y=0 for self.fitproc.lags (red) and the slice positions (black)
    :parameter dir: If not None, will save to this filepath
    :parameter prior_extents: If True, will use the model prior range for the x-axis limits

    Returns the matplotlib figure, or None if the model has no 'lag' parameter
    """
    if 'lag' not in self.model.free_params():
        self.msg_err("Can't plot lags for a model without lags.")
        return

    if Nsamples is not None and Nsamples != self.Nsamples:
        C = ChainConsumer()
        samples = self.fitproc.get_samples(Nsamples)
        C.add_chain(Chain(samples=DataFrame.from_dict(samples), name="lags"))
    else:
        C = self.C

    _config = PlotConfig(extents=self.model.prior_ranges, summarise=True)
    C.plotter.set_config(_config)
    fig = C.plotter.plot_distributions(columns=['lag'], figsize=(8, 4))
    ax = fig.axes[0]
    if prior_extents: ax.set_xlim(*self.model.prior_ranges['lag'])
    # Pin the current y-limits so the overlays drawn below don't rescale the axis
    ax.set_ylim(*ax.get_ylim())
    fig.tight_layout()

    ax.grid()

    # Method specific plotting of fun stuff
    if extras and isinstance(self.fitproc, fitting_methods.hessian_scan):
        X, logY = self.fitproc._get_slices('lags', 'logZ')

        if self.fitproc.interp_scale == 'linear':
            # Normalise exp(logY) to unit area over X; subtracting the max first avoids overflow
            Y = np.exp(logY - logY.max())
            Y /= np.trapz(Y, X)
            ax.plot(X, Y)

        elif self.fitproc.interp_scale == 'log':
            # Interpolate in log-space onto a dense grid spanning the lag prior, then normalise
            Xterp = np.linspace(*self.model.prior_ranges['lag'], self.Nsamples)
            logYterp = np.interp(Xterp, X, logY, left=logY[0], right=logY[-1])
            Yterp = np.exp(logYterp - logYterp.max())
            Yterp /= np.trapz(Yterp, Xterp)
            ax.plot(Xterp, Yterp)

        # Draw on this figure's axes explicitly; plt.scatter targets whichever axes is "current"
        ax.scatter(self.fitproc.lags, np.zeros_like(self.fitproc.lags), c='red', s=20)
        ax.scatter(X, np.zeros_like(X), c='black', s=20)

    if dir is not None:
        fig.savefig(dir)
    if show: fig.show()
    return fig
diagnostic_plots(dir: str | None = None, show: bool = False, **kwargs)

Generates a diagnostic plot window

Parameters:

Name Type Description Default
dir str | None

If not None, will save to this filepath

None
show bool

If True, will show the plot. If dir is not None, will plt.savefig to the filepath 'dir' with **kwargs.

False
Source code in litmus/litmusclass.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
def diagnostic_plots(self, dir: str | None = None, show: bool = False, **kwargs):
    """
    Generates a diagnostic plot window
    :param dir: If not None, will save to this filepath
    :param show: If True, will show the plot

    If dir!=None, will plt.savefig to the filepath 'dir' with **kwargs
    """
    if hasattr(self.fitproc, "diagnostics"):
        self.fitproc.diagnostics()
    else:
        self.msg_err("diagnostic_plots() not yet implemented for fitting method %s" % (self.fitproc.name))

    if dir is not None:
        plt.savefig(dir, **kwargs)

    if show: plt.show()

    return

suppress_stdout()

Source code in litmus/_utils.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
@contextmanager
def suppress_stdout():
    """
    Context manager that discards everything written to the stdout file descriptor inside the block.

    The redirection is done at the file-descriptor level (os.dup2). Output written straight to the
    stdout descriptor is therefore hidden too, not just text printed via Python's sys.stdout. The
    original descriptor is restored on exit, even if the body raises.
    """
    stdout_fd = sys.stdout.fileno()
    # Flush pending Python-level output first, so text printed *before* the block
    # is not later flushed into devnull
    sys.stdout.flush()
    # Duplicate the original stdout file descriptor to restore later
    original_stdout_fd = os.dup(stdout_fd)
    try:
        # Open devnull file and redirect stdout to it
        with open(os.devnull, 'w') as devnull:
            os.dup2(devnull.fileno(), stdout_fd)
            try:
                yield
            finally:
                try:
                    # Push output buffered inside the block into devnull *before* restoring
                    sys.stdout.flush()
                finally:
                    # Restore original stdout from the duplicated file descriptor
                    os.dup2(original_stdout_fd, stdout_fd)
    finally:
        # Close the duplicated file descriptor, even if redirection itself failed
        os.close(original_stdout_fd)

isiter(x: any) -> bool

Checks to see if an object is iterable

Source code in litmus/_utils.py
50
51
52
53
54
55
56
57
58
59
60
61
def isiter(x: any) -> bool:
    """
    Checks to see if an object is iterable.

    Special case: for a plain dict, returns True only if the dict's first value has more than one
    element. An empty dict, or a first value with no len(), gives False.

    :param x: object to check
    :return: bool
    """
    if type(x) == dict:
        if len(x) == 0:
            return False
        first_value = next(iter(x.values()))
        try:
            return len(first_value) > 1
        except TypeError:
            # Scalar (unsized) values cannot be iterated over
            return False
    try:
        iter(x)
    except Exception:
        # Best-effort: anything iter() rejects counts as non-iterable, without swallowing
        # KeyboardInterrupt / SystemExit as a bare except would
        return False
    else:
        return True

isiter_dict(DICT: dict) -> bool

Like isiter, but for a dictionary. Checks only the first element of DICT.

Source code in litmus/_utils.py
64
65
66
67
68
69
70
71
72
73
def isiter_dict(DICT: dict) -> bool:
    """
    Like isiter, but for a dictionary. Reports whether the value stored under the first key of DICT
    is iterable, as judged by isiter. Only that first entry is inspected.

    :param DICT: dictionary to check. An empty dict raises IndexError.
    :return: bool
    """
    first_key = list(DICT)[0]
    return True if isiter(DICT[first_key]) else False

dict_dim(DICT: dict) -> (int, int)

Returns the number of keys in a dictionary and the length of its first element.

Source code in litmus/_utils.py
76
77
78
79
80
81
82
83
84
85
def dict_dim(DICT: dict) -> (int, int):
    """
    Returns the dimensions of a dictionary as (number of keys, length of the first entry).

    If the first entry is not iterable (per isiter_dict), its length is reported as 1.

    :param DICT: dictionary to measure
    :return: tuple (nkeys, length)
    """
    first_is_iterable = isiter_dict(DICT)
    nkeys = len(list(DICT.keys()))
    if not first_is_iterable:
        return (nkeys, 1)
    first_entry = DICT[list(DICT.keys())[0]]
    return (nkeys, len(first_entry))

dict_pack(DICT: dict, keys=None, recursive=True, H=None, d0={}) -> np.array

Packs a dictionary into an array format

Parameters:

Name Type Description Default
DICT dict

the dict to pack

required
keys

the order in which to index the keyed elements. If none, will use DICT.keys(). Can be partial

None
recursive

whether to recurse into nested dicts

True
H

Matrix to scale parameters by

None
d0

Value to offset by before packing

{}

Returns:

Type Description
array

(nkeys x len_array) np.array object, X = H (d - d0)

Source code in litmus/_utils.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def dict_pack(DICT: dict, keys=None, recursive=True, H=None, d0=None) -> np.array:
    """
    Packs a dictionary into an array format
    :param DICT: the dict to pack
    :param keys: the order in which to index the keyed elements. If none, will use DICT.keys(). Can be partial
    :param recursive: whether to recurse into nested dicts (detected from the type of DICT's first value)
    :param H: Matrix to scale parameters by.
        NOTE(review): accepted but currently NOT applied - the result is (d - d0) only.
    :param d0: Value to offset by before packing. Keys of DICT missing from d0 are offset by 0.
        For nested dicts, d0[key] may be a dict of per-entry offsets, or a scalar applied to every entry.
        d0 itself is never modified.
    :return: (nkeys x len_array) np.array object

    X = (d-d0)
    """

    nokeys = keys is None
    keys = list(DICT.keys()) if nokeys else list(keys)

    # Build a fresh offset table so neither the caller's d0 nor a shared default is mutated
    offsets = {key: 0.0 for key in keys if key in DICT}
    offsets |= (d0 if d0 is not None else {})

    if recursive and type(list(DICT.values())[0]) == dict:
        inner_keys = keys if not nokeys else None
        rows = []
        for key in keys:
            offset = offsets[key]
            # A nested entry is a dict, so its offset must be a dict too (a scalar applies to every entry)
            inner_d0 = offset if isinstance(offset, dict) else {k: offset for k in DICT[key]}
            rows.append(dict_pack(DICT[key], keys=inner_keys, recursive=recursive, d0=inner_d0))
        out = np.array(rows)
    else:
        if isiter(DICT[keys[0]]):
            length = dict_dim(DICT)[1]  # loop-invariant, computed once
            out = np.array([[DICT[key][i] - offsets[key] for i in range(length)] for key in keys])
        else:
            out = np.array([DICT[key] - offsets[key] for key in keys])

    return (out)

dict_unpack(X: np.array, keys: [str], recursive=True, Hinv=None, x0=None) -> np.array

Unpacks an array into a dict

Parameters:

Name Type Description Default
X array

Array to unpack

required
keys [str]

keys to unpack with

required

Returns:

Type Description
array

Hinv(X) + x0

Source code in litmus/_utils.py
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def dict_unpack(X: np.array, keys: [str], recursive=True, Hinv=None, x0=None) -> dict:
    """
    Unpacks an array into a dict
    :param X: Array to unpack
    :param keys: keys to unpack with
    :param recursive: if True and X[0] is iterable (per isiter), each row X[i] is itself unpacked
        into a dict stored under keys[i]. Hinv and x0 are not applied in that case.
    :param Hinv: (optional) matrix applied to X before unpacking; must have len(keys) rows
    :param x0: (optional) offset added after Hinv is applied
    :return: dict mapping keys[i] -> (Hinv X + x0)[i]

    Hinv(X) + x0
    """
    if Hinv is not None: assert Hinv.shape[0] == len(keys), "Size of H must be equal to number of keys in dict_unpack"

    if recursive and isiter(X[0]):
        out = {key: dict_unpack(X[i], keys, recursive) for i, key in enumerate(list(keys))}
    else:
        # Copy so the returned values never alias the caller's array
        X = X.copy()
        if Hinv is not None:
            X = np.dot(Hinv, X)
        if x0 is not None:
            # Out-of-place add: `+=` would extend a Python list instead of adding to it, and
            # fails when an integer array is offset by floats
            X = X + x0
        out = {key: X[i] for i, key in enumerate(list(keys))}

    return (out)

dict_sortby(A: dict, B: dict, match_only=True) -> dict

Sorts dict A to match keys of dict B.

Parameters:

Name Type Description Default
A dict

Dict to be sorted

required
B dict

Dict whose keys will provide the ordering

required
match_only

If true, returns only for keys common to both A and B. Else, append un-sorted entries to end

True

Returns:

Type Description
dict

{key: A[key] for key in B if key in A}

Source code in litmus/_utils.py
147
148
149
150
151
152
153
154
155
156
157
158
159
def dict_sortby(A: dict, B: dict, match_only=True) -> dict:
    """
    Re-orders dict A so that its keys follow the order of the keys of B.

    :param A: Dict to be sorted
    :param B: Dict (or any iterable of keys) whose key order is used
    :param match_only: If True, keep only the keys present in both A and B. Otherwise, the keys of A
        that are absent from B are appended afterwards, in A's own order
    :return: a new, re-ordered dict. A is not modified
    """
    shared = [key for key in B if key in A]
    ordered = dict(zip(shared, (A[key] for key in shared)))
    if match_only:
        return ordered
    ordered.update({key: value for key, value in A.items() if key not in B})
    return ordered

dict_extend(A: dict, B: dict = None) -> dict

Extends all single-length entries of a dict to match the length of a non-singular element

Parameters:

Name Type Description Default
A dict

Dictionary whose elements are to be extended

required
B dict

(optional) the dict to extend by, equivalent to dict_extend(A|B)

None

Returns:

Type Description
dict

Dict A with any singleton elements extended to the longest entry in A or B

Source code in litmus/_utils.py
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def dict_extend(A: dict, B: dict = None) -> dict:
    """
    Extends all single-length entries of a dict to match the length of a non-singular element
    :param A: Dictionary whose elements are to be extended
    :param B: (optional) a dict merged over A before extending (B's values win), equivalent to dict_extend(A|B)
    :return: A copy of A (merged with B) in which every non-iterable entry (per isiter) is repeated
        into an np.array matching the common length of the iterable entries. A itself is not modified.
        If there are no iterable entries, or nothing to extend, the merged copy is returned as-is.
    """

    out = A.copy()
    if B is not None: out |= B

    # Classify each entry once
    is_iterable = {key: isiter(out[key]) for key in out}
    to_extend = [key for key in out if not is_iterable[key]]
    to_leave = [key for key in out if is_iterable[key]]

    if len(to_extend) == 0: return out
    if len(to_leave) == 0: return out

    N = len(out[to_leave[0]])
    for key in to_leave[1:]:
        assert len(out[key]) == N, "Tried to dict_extend() a dictionary with inhomogeneous lengths"

    for key in to_extend:
        # Read the merged value: the key may exist only in B, and B overrides A
        out[key] = np.array([out[key]] * N)

    return (out)

dict_combine(X: [dict]) -> {str: [float]}

Combines an array, list etc. of dictionaries into a dictionary of arrays

Parameters:

Name Type Description Default
X [dict]

1D Iterable of dicts

required

Returns:

Type Description
{str: [float]}

Dict of 1D iterables

Source code in litmus/_utils.py
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
def dict_combine(X: [dict]) -> {str: [float]}:
    """
    Combines an array, list etc. of dictionaries into a dictionary of arrays

    The keys are taken from X[0]. Each value is collected into a float np.array of length len(X).

    :param X: 1D Iterable of dicts, all containing the keys of X[0]
    :return: Dict of 1D np.arrays. An empty X gives an empty dict.
    """

    N = len(X)
    if N == 0:
        return {}
    keys = X[0].keys()

    out = {key: np.zeros(N) for key in keys}
    for n, entry in enumerate(X):
        for key in keys:
            out[key][n] = entry[key]
    return (out)

dict_divide(X: dict) -> [dict]

Splits dict of arrays into array of dicts. Opposite of dict_combine

Parameters:

Name Type Description Default
X dict

Dict of 1D iterables

required

Returns:

Type Description
[dict]

1D Iterable of dicts

Source code in litmus/_utils.py
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def dict_divide(X: dict) -> [dict]:
    """
    Splits dict of arrays into array of dicts. Opposite of dict_combine

    The number of output dicts is the length of the first entry of X.

    :param X: Dict of 1D iterables of equal length
    :return: list of dicts, one per index. An empty X gives an empty list.
    """
    if not X:
        return []

    N = len(next(iter(X.values())))

    return [{key: X[key][i] for key in X} for i in range(N)]

dict_split(X: dict, keys: [str]) -> (dict, dict)

Splits a dict in two based on keys

Parameters:

Name Type Description Default
X dict

Dict to be split into A,B

required
keys [str]

Keys to be present in A, but not in B

required

Returns:

Type Description
(dict, dict)

tuple of dicts (A,B)

Source code in litmus/_utils.py
223
224
225
226
227
228
229
230
231
232
233
234
235
def dict_split(X: dict, keys: [str]) -> (dict, dict):
    """
    Splits a dict in two based on keys

    :param X: Dict to be split into A,B
    :param keys: Keys to be present in A, but not in B. A single string is treated as one key.
        An empty iterable gives A = {} and B = a copy of X.
    :return: tuple of dicts (A,B). A follows the order of 'keys'; B follows the order of X.
    """
    assert type(X) is dict, "input to dict_split() must be of type dict"
    # A bare string is one key, not a sequence of one-character keys
    if isinstance(keys, str):
        keys = [keys]
    try:
        keys = list(keys)
    except TypeError:
        keys = None
    assert keys is not None and all(isinstance(key, str) for key in keys), "in dict_split() keys must be list of strings"

    A = {key: X[key] for key in keys}
    key_set = set(keys)  # O(1) membership instead of scanning 'keys' for every entry of X
    B = {key: value for key, value in X.items() if key not in key_set}
    return (A, B)

pack_function(func, packed_keys: [str], fixed_values: dict = {}, invert: bool = False, jit: bool = False, H: np.array = None, d0: dict = {}) -> _types.FunctionType

Re-arranges a function that takes dict arguments to take array-like arguments instead, so as to be autograd friendly. Takes a function f(D: dict, *args, **kwargs) and returns f(X, D2, *args, **kwargs), where D2 is all elements of D not listed in 'packed_keys' or fixed_values.

Parameters:

Name Type Description Default
func

Function to be unpacked

required
packed_keys [str]

Keys in 'D' to be packed in an array

required
fixed_values dict

Elements of 'D' to be fixed

{}
invert bool

If true, will 'flip' the function upside down

False
jit bool

If true, will 'jit' the function

False
H array

(optional) scaling matrix to reparameterize the packed parameters with

None
d0 dict

(optional) If given, will center the reparameterized function at d0

{}
Source code in litmus/_utils.py
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def pack_function(func, packed_keys: ['str'], fixed_values: dict = {}, invert: bool = False, jit: bool = False,
                  H: np.array = None, d0: dict = {}) -> _types.FunctionType:
    """
    Re-arranges a function that takes dict arguments to take array-like arguments instead, so as to be autograd friendly.

    Takes a function f(D: dict, *args, **kwargs) and returns g(X, D2={}, *args, **kwargs), where:
      - the values packed from X are H @ (X - x0), keyed in the order of packed_keys,
        with x0 = dict_pack(d0, packed_keys)
      - D2 supplies further elements of D not listed in 'packed_keys' or 'fixed_values'
      - fixed_values are merged in last, so they override both of the above

    :param func: Function to be unpacked
    :param packed_keys: Keys in 'D' to be packed in an array
    :param fixed_values: Elements of 'D' to be fixed
    :param invert: If true, will 'flip' the function upside down, i.e. return -f
    :param jit: If true, will wrap the function with jax.jit
    :param H: (optional) scaling matrix to reparameterize the packed parameters with. Must have
        len(packed_keys) rows. Defaults to the identity.
    :param d0: (optional) If given, will center the reparameterized function at d0. Keys missing from d0 default to 0.0
    :return: the re-arranged function
    """
    n_packed = len(packed_keys)
    if H is None:
        H = jnp.eye(n_packed)
    else:
        assert H.shape[0] == n_packed, "Scaling matrix H must be same length as packed_keys"

    centre = {key: 0.0 for key in packed_keys} | d0
    x0 = dict_pack(centre, packed_keys)

    sign = -1 if invert else 1

    def new_func(X, unpacked_params={}, *args, **kwargs):
        scaled = jnp.dot(H, X - x0)
        params = dict(zip(packed_keys, scaled))
        params |= unpacked_params
        params |= fixed_values
        return sign * func(params, *args, **kwargs)

    return jax.jit(new_func) if jit else new_func

randint()

Quick utility to generate a random integer

Source code in litmus/_utils.py
287
288
289
290
291
def randint():
    """
    Quick utility to generate a random integer in [0, sys.maxsize // 1024), using numpy's global RNG.

    :return: a Python int
    """
    # Request int64 explicitly. numpy's default integer is 32-bit on some platforms (e.g. Windows
    # with numpy < 2), where this upper bound raises "high is out of bounds".
    return int(np.random.randint(0, sys.maxsize // 1024, dtype=np.int64))