Using callbacks in objectives
This notebook explains how to use a callback in an objective function. For details on the Callback class, see the API reference. Potential use cases include:
- Plotting some outputs at each iteration of the optimization
- Saving internal variables to plot once the optimization is complete
Some objectives have "internal callbacks" that are not intended to be user-facing. These are standard callbacks that can be used to plot the results of an optimization via DataFit.plot_fit_results(). For user-facing callbacks, users should create their own callback objects and call them directly for plotting, as demonstrated in this notebook.
Creating a custom callback
To implement a custom callback, create a class that inherits from iwp.callbacks.Callback and implements the relevant callback methods. See the documentation for iwp.callbacks.Callback for more information on the available methods and their expected inputs.
import ionworkspipeline as iwp
import matplotlib.pyplot as plt
import pybamm
import numpy as np
import pandas as pd
class MyCallback(iwp.callbacks.Callback):
    def __init__(self):
        super().__init__()
        # Implement our own iteration counter
        self.iter = 0

    def on_objective_build(self, logs):
        self.data_ = logs["data"]

    def on_run_iteration(self, logs):
        # Print some information at each iteration
        inputs = logs["inputs"]
        V_model = logs["outputs"]["Voltage [V]"]
        V_data = self.data_["Voltage [V]"]
        # Calculate the RMSE; note this is not necessarily the cost function
        # used in the optimization
        rmse = np.sqrt(np.nanmean((V_model - V_data) ** 2))
        print(f"Iteration: {self.iter}, Inputs: {inputs}, RMSE: {rmse}")
        self.iter += 1

    def on_datafit_finish(self, logs):
        self.fit_results_ = logs

    def plot_fit_results(self):
        """
        Plot the fit results.
        """
        data = self.data_
        fit = self.fit_results_["outputs"]
        fit_results = {
            "data": (data["Time [s]"], data["Voltage [V]"]),
            "fit": (fit["Time [s]"], fit["Voltage [V]"]),
        }
        markers = {"data": "o", "fit": "--"}
        colors = {"data": "k", "fit": "tab:red"}
        fig, ax = plt.subplots()
        for name, (t, V) in fit_results.items():
            ax.plot(
                t,
                V,
                markers[name],
                label=name,
                color=colors[name],
                mfc="none",
                linewidth=2,
            )
        ax.grid(alpha=0.5)
        ax.set_xlabel("Time [s]")
        ax.set_ylabel("Voltage [V]")
        ax.legend()
        return fig, ax
To use this callback, we generate synthetic data for a current-driven experiment and fit an SPM using the CurrentDriven objective.
model = pybamm.lithium_ion.SPM()
parameter_values = iwp.ParameterValues("Chen2020")
t = np.linspace(0, 3600, 1000)
sim = iwp.Simulation(model, parameter_values=parameter_values, t_eval=t, t_interp=t)
sim.solve()
data = pd.DataFrame(
    {x: sim.solution[x].entries for x in ["Time [s]", "Current [A]", "Voltage [V]"]}
)
# In this example we just fit the diffusivity in the positive electrode
parameters = {
    "Positive particle diffusivity [m2.s-1]": iwp.Parameter("D_s", initial_value=1e-15),
}
# Create the callback
callback = MyCallback()
objective = iwp.objectives.CurrentDriven(
    data, options={"model": model}, callbacks=callback
)
current_driven = iwp.DataFit(objective, parameters=parameters)
# make sure we're not accidentally initializing with the correct values by passing
# them in
params_for_pipeline = {k: v for k, v in parameter_values.items() if k not in parameters}
results = current_driven.run(params_for_pipeline)
Iteration: 0, Inputs: {'D_s': 1.0}, RMSE: 0.15909623870900977
Iteration: 1, Inputs: {'D_s': 1.0}, RMSE: 0.15909623870900977
Iteration: 2, Inputs: {'D_s': 2.0}, RMSE: 0.06447715147451864
Iteration: 3, Inputs: {'D_s': 0.0}, RMSE: 9999999996.444777
Iteration: 4, Inputs: {'D_s': 1.5000000000157114}, RMSE: 0.10181491874307463
Iteration: 5, Inputs: {'D_s': 2.25}, RMSE: 0.05119216091391999
Iteration: 6, Inputs: {'D_s': 2.35}, RMSE: 0.046565506353694656
Iteration: 7, Inputs: {'D_s': 2.45}, RMSE: 0.04225747901488892
Iteration: 8, Inputs: {'D_s': 2.5500000000000003}, RMSE: 0.0382369172522125
Iteration: 9, Inputs: {'D_s': 2.6500000000000004}, RMSE: 0.03447343630137423
Iteration: 10, Inputs: {'D_s': 2.79142135623731}, RMSE: 0.02954660806606048
Iteration: 11, Inputs: {'D_s': 2.89142135623731}, RMSE: 0.026302653676140705
Iteration: 12, Inputs: {'D_s': 2.99142135623731}, RMSE: 0.023247876819045724
Iteration: 13, Inputs: {'D_s': 3.13284271247462}, RMSE: 0.019212560615569665
Iteration: 14, Inputs: {'D_s': 3.33284271247462}, RMSE: 0.014012802298124988
Iteration: 15, Inputs: {'D_s': 3.43284271247462}, RMSE: 0.011610464754846384
Iteration: 16, Inputs: {'D_s': 3.5328427124746202}, RMSE: 0.009326575370409363
Iteration: 17, Inputs: {'D_s': 3.67426406871193}, RMSE: 0.006283353730401495
[IDAS ERROR] IDASolve
At t = 0 and h = 1e-09, the corrector convergence failed repeatedly or with |h| = hmin.
Iteration: 18, Inputs: {'D_s': 3.8273539842695685}, RMSE: 0.003212959554824122
Iteration: 19, Inputs: {'D_s': 3.9273539842695686}, RMSE: 0.0013219099026026695
Iteration: 20, Inputs: {'D_s': 4.027353984269569}, RMSE: 0.0004853044690214084
Iteration: 21, Inputs: {'D_s': 4.168775340506878}, RMSE: 0.0029058988942563626
Iteration: 22, Inputs: {'D_s': 4.092118921623178}, RMSE: 0.0016110869071908407
Iteration: 23, Inputs: {'D_s': 4.055618610280126}, RMSE: 0.000979193838405659
Iteration: 24, Inputs: {'D_s': 4.002353984269568}, RMSE: 4.172274197829926e-05
Iteration: 25, Inputs: {'D_s': 3.9773539842695684}, RMSE: 0.00040826019451056344
Iteration: 26, Inputs: {'D_s': 4.012353984269568}, RMSE: 0.000219439159215988
Iteration: 27, Inputs: {'D_s': 3.9923539842695686}, RMSE: 0.0001382617906831725
Iteration: 28, Inputs: {'D_s': 4.007301201272821}, RMSE: 0.00012954888597199697
Iteration: 29, Inputs: {'D_s': 3.9998539842695684}, RMSE: 8.413118630636302e-06
Iteration: 30, Inputs: {'D_s': 4.004822665489002}, RMSE: 8.545122472113428e-05
Iteration: 31, Inputs: {'D_s': 4.001353984269568}, RMSE: 2.4364576863860712e-05
Iteration: 32, Inputs: {'D_s': 4.000353984269568}, RMSE: 9.229044043049878e-06
Iteration: 33, Inputs: {'D_s': 3.9989397707071945}, RMSE: 2.1405422284535674e-05
Iteration: 34, Inputs: {'D_s': 4.00085321410753}, RMSE: 1.6097266099670858e-05
Iteration: 35, Inputs: {'D_s': 4.000103984269567}, RMSE: 7.612978785028621e-06
Iteration: 36, Inputs: {'D_s': 3.9998969557065753}, RMSE: 8.10800159009572e-06
Iteration: 37, Inputs: {'D_s': 4.000203984269567}, RMSE: 8.00294186716923e-06
Iteration: 38, Inputs: {'D_s': 4.000003984269568}, RMSE: 7.636885044022464e-06
Iteration: 39, Inputs: {'D_s': 4.0001539831065465}, RMSE: 7.759029476603605e-06
Iteration: 40, Inputs: {'D_s': 4.000078984269567}, RMSE: 7.582382508561305e-06
Iteration: 41, Inputs: {'D_s': 4.000128983606386}, RMSE: 7.673335512834914e-06
Iteration: 42, Inputs: {'D_s': 4.000093984269568}, RMSE: 7.599905497846605e-06
Iteration: 43, Inputs: {'D_s': 4.000113984269567}, RMSE: 7.634037532077793e-06
Iteration: 44, Inputs: {'D_s': 4.000108984263288}, RMSE: 7.62299090181715e-06
Iteration: 45, Inputs: {'D_s': 4.0001014842695675}, RMSE: 7.6083619484917106e-06
Iteration: 46, Inputs: {'D_s': 4.000098984269568}, RMSE: 7.604005234620987e-06
Iteration: 47, Inputs: {'D_s': 4.000095448735662}, RMSE: 7.598288839090439e-06
Iteration: 48, Inputs: {'D_s': 4.000090448735662}, RMSE: 7.5949252383595335e-06
Iteration: 49, Inputs: {'D_s': 4.000092948735662}, RMSE: 7.598392648276307e-06
Iteration: 50, Inputs: {'D_s': 4.000097216502251}, RMSE: 7.601081816214098e-06
Iteration: 51, Inputs: {'D_s': 4.000094448735662}, RMSE: 7.600598618435101e-06
Iteration: 52, Inputs: {'D_s': 4.000096332618892}, RMSE: 7.599669013083384e-06
Iteration: 53, Inputs: {'D_s': 4.000095448735662}, RMSE: 7.598288839090439e-06
Now we use the results to plot the fit at the end of the optimization.
_ = results.plot_fit_results()

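Because our custom callback stored the data in on_objective_build and the final logs in on_datafit_finish, the same figure can also be produced by calling the callback object directly. A minimal sketch, reusing the MyCallback instance created above:

# Plot the fit using the plot_fit_results method we defined on MyCallback
fig, ax = callback.plot_fit_results()
ax.set_title("Fit plotted directly from the custom callback")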
Cost logger
The DataFit class has an internal "cost logger" attribute that can be used to log and visualize the cost function during optimization, which is useful for monitoring the progress of the optimization. The cost logger stores the cost function value at each iteration and can be accessed using the cost_logger attribute of the DataFit object.
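For example, after a fit has been run, the logger can be retrieved from the DataFit object. A short sketch, assuming only the cost_logger attribute described above (the exact contents of the logger depend on the CostLogger implementation):

# Retrieve the cost logger attached to the DataFit object from the first example
cost_logger = current_driven.cost_logger
print(type(cost_logger))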
By default, the cost logger tracks the cost function value. DataFit.plot_trace can be used to plot the progress at the end of the optimization.
objective = iwp.objectives.CurrentDriven(data, options={"model": model})
current_driven = iwp.DataFit(objective, parameters=parameters)
_ = current_driven.run(params_for_pipeline)
_ = current_driven.plot_trace()
[IDAS ERROR] IDASolve
At t = 0 and h = 1e-09, the corrector convergence failed repeatedly or with |h| = hmin.

The cost logger can be changed by passing the cost_logger argument to the DataFit object. For example, the following shows how to pass a cost logger that plots the cost function and parameter values every 10 seconds.
current_driven = iwp.DataFit(
    objective,
    parameters=parameters,
    cost_logger=iwp.data_fits.CostLogger(plot_every=10),
)
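With this cost logger attached, running the fit as before updates the plots of the cost and parameter values every 10 seconds while the optimization progresses. A sketch reusing the inputs defined earlier:

# Run the fit; the CostLogger now produces plots periodically during the run
_ = current_driven.run(params_for_pipeline)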