{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Using callbacks in objectives\n", "\n", "This notebook explains how to use a callback in an objective function. For details on the `Callback` class, see the [API reference](../../api/data_fits/callbacks.rst). Potential use cases include:\n", "- Plotting some outputs at each iteration of the optimization\n", "- Saving internal variables to plot once the optimization is complete\n", "\n", "Some objectives have \"internal callbacks\" that are not intended to be user-facing. These are standard callbacks used to plot the results of an optimization with `DataFit.plot_fit_results()`. For user-facing callbacks, create your own callback object and call its methods directly for plotting, as demonstrated in this notebook." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating a custom callback" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To implement a custom callback, create a class that inherits from `iwp.callbacks.Callback` and overrides the hook methods you need, such as `on_objective_build`, `on_run_iteration` and `on_datafit_finish` (used in the example below). See the documentation for `iwp.callbacks.Callback` for more information on the available hooks and the `logs` each one receives."
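, "\n", "\n", "To make the hook pattern concrete, the next cell is a plain-Python sketch with no dependency on `ionworkspipeline`. It is a hypothetical illustration, not the library's implementation: `SketchCallback`, `IterationRecorder` and `toy_fit` are invented names, and the keys in each `logs` dictionary are assumptions. A toy optimizer calls `on_objective_build` once, `on_run_iteration` after every cost evaluation and `on_datafit_finish` at the end, so a subclass only needs to override the hooks it cares about." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hypothetical sketch of the callback pattern, NOT ionworkspipeline's internals.\n", "# The hook names mirror those overridden in MyCallback; the driver loop and the\n", "# contents of each `logs` dictionary are invented for illustration.\n", "\n", "\n", "class SketchCallback:\n", "    \"\"\"Base class: every hook is a no-op that subclasses may override.\"\"\"\n", "\n", "    def on_objective_build(self, logs):\n", "        pass\n", "\n", "    def on_run_iteration(self, logs):\n", "        pass\n", "\n", "    def on_datafit_finish(self, logs):\n", "        pass\n", "\n", "\n", "class IterationRecorder(SketchCallback):\n", "    \"\"\"Records the inputs seen at each iteration and the final result.\"\"\"\n", "\n", "    def __init__(self):\n", "        self.history = []\n", "        self.final = None\n", "\n", "    def on_run_iteration(self, logs):\n", "        self.history.append(logs[\"inputs\"])\n", "\n", "    def on_datafit_finish(self, logs):\n", "        self.final = logs\n", "\n", "\n", "def toy_fit(callbacks, candidates):\n", "    \"\"\"Toy 'optimizer': evaluate each candidate and fire the hooks.\"\"\"\n", "    for cb in callbacks:\n", "        cb.on_objective_build({\"data\": None})\n", "    best = None\n", "    for x in candidates:\n", "        cost = (x - 3.0) ** 2  # toy cost with its minimum at x = 3\n", "        for cb in callbacks:\n", "            cb.on_run_iteration({\"inputs\": {\"x\": x}, \"cost\": cost})\n", "        if best is None or cost < best[1]:\n", "            best = (x, cost)\n", "    for cb in callbacks:\n", "        cb.on_datafit_finish({\"inputs\": {\"x\": best[0]}, \"cost\": best[1]})\n", "    return best\n", "\n", "\n", "recorder = IterationRecorder()\n", "result = toy_fit([recorder], [1.0, 2.0, 3.0, 4.0])\n", "print(len(recorder.history), result)  # prints: 4 (3.0, 0.0)"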
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import ionworkspipeline as iwp\n", "import matplotlib.pyplot as plt\n", "import pybamm\n", "import numpy as np\n", "import pandas as pd" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class MyCallback(iwp.callbacks.Callback):\n", "    def __init__(self):\n", "        super().__init__()\n", "        # Implement our own iteration counter\n", "        self.iter = 0\n", "\n", "    def on_objective_build(self, logs):\n", "        # Store the data passed in the build logs\n", "        self.data_ = logs[\"data\"]\n", "\n", "    def on_run_iteration(self, logs):\n", "        # Print some information at each iteration\n", "        inputs = logs[\"inputs\"]\n", "        V_model = logs[\"outputs\"][\"Voltage [V]\"]\n", "        V_data = self.data_[\"Voltage [V]\"]\n", "\n", "        # Calculate the RMSE; note this is not necessarily the cost function used in the optimization\n", "        rmse = np.sqrt(np.nanmean((V_model - V_data) ** 2))\n", "\n", "        print(f\"Iteration: {self.iter}, Inputs: {inputs}, RMSE: {rmse}\")\n", "        self.iter += 1\n", "\n", "    def on_datafit_finish(self, logs):\n", "        # Keep the final logs so the fit can be plotted afterwards\n", "        self.fit_results_ = logs\n", "\n", "    def plot_fit_results(self):\n", "        \"\"\"\n", "        Plot the fit results.\n", "        \"\"\"\n", "        data = self.data_\n", "        fit = self.fit_results_[\"outputs\"]\n", "\n", "        fit_results = {\n", "            \"data\": (data[\"Time [s]\"], data[\"Voltage [V]\"]),\n", "            \"fit\": (fit[\"Time [s]\"], fit[\"Voltage [V]\"]),\n", "        }\n", "\n", "        markers = {\"data\": \"o\", \"fit\": \"--\"}\n", "        colors = {\"data\": \"k\", \"fit\": \"tab:red\"}\n", "        fig, ax = plt.subplots()\n", "        for name, (t, V) in fit_results.items():\n", "            ax.plot(\n", "                t,\n", "                V,\n", "                markers[name],\n", "                label=name,\n", "                color=colors[name],\n", "                mfc=\"none\",\n", "                linewidth=2,\n", "            )\n", "        ax.grid(alpha=0.5)\n", "        ax.set_xlabel(\"Time [s]\")\n", "        ax.set_ylabel(\"Voltage [V]\")\n", "        ax.legend()\n", "\n", "        return fig, ax" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "To use this callback, we generate synthetic data for a current-driven experiment and fit an SPM (single particle model) using the `CurrentDriven` objective." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model = pybamm.lithium_ion.SPM()\n", "parameter_values = iwp.ParameterValues(\"Chen2020\")\n", "t = np.linspace(0, 3600, 1000)\n", "sim = iwp.Simulation(model, parameter_values=parameter_values, t_eval=t, t_interp=t)\n", "sim.solve()\n", "data = pd.DataFrame(\n", "    {x: sim.solution[x].entries for x in [\"Time [s]\", \"Current [A]\", \"Voltage [V]\"]}\n", ")\n", "\n", "# In this example we only fit the diffusivity in the positive electrode\n", "parameters = {\n", "    \"Positive particle diffusivity [m2.s-1]\": iwp.Parameter(\"D_s\", initial_value=1e-15),\n", "}\n", "\n", "# Create the callback\n", "callback = MyCallback()\n", "objective = iwp.objectives.CurrentDriven(\n", "    data, options={\"model\": model}, callbacks=callback\n", ")\n", "current_driven = iwp.DataFit(objective, parameters=parameters)\n", "\n", "# Drop the fitted parameters from the parameter set so that their true values\n", "# are not accidentally passed in as the starting point\n", "params_for_pipeline = {k: v for k, v in parameter_values.items() if k not in parameters}\n", "\n", "_ = current_driven.run(params_for_pipeline)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we use the callback object we created to plot the results at the end of the optimization." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "_ = callback.plot_fit_results()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Cost logger\n", "\n", "The `DataFit` class has an internal cost logger that records the value of the cost function at each iteration, which is useful for monitoring the progress of the optimization. The cost logger can be accessed using the `cost_logger` attribute of the `DataFit` object." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "By default, the cost logger tracks the cost function value. `DataFit.plot_trace` can be used to plot the progress at the end of the optimization." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "objective = iwp.objectives.CurrentDriven(data, options={\"model\": model})\n", "current_driven = iwp.DataFit(objective, parameters=parameters)\n", "_ = current_driven.run(params_for_pipeline)\n", "_ = current_driven.plot_trace()" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "The cost logger can be changed by passing the `cost_logger` argument to `DataFit`. The following cell passes a cost logger that plots the cost function and parameter values every 10 seconds during the optimization." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "current_driven = iwp.DataFit(\n", "    objective,\n", "    parameters=parameters,\n", "    cost_logger=iwp.data_fits.CostLogger(plot_every=10),\n", ")\n", "# Run the fit so that the cost logger has something to plot\n", "_ = current_driven.run(params_for_pipeline)" ] } ],
"metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 2 }