import os
import pickle
import warnings
from dataclasses import dataclass, field
import matplotlib.pyplot as plt
import numpy as np
from pandas import DataFrame, MultiIndex, Series
from mpl_toolkits.mplot3d import Axes3D
from pyrates import CircuitTemplate, clear
from matplotlib.collections import LineCollection
from mpl_toolkits.mplot3d.art3d import Line3DCollection
from typing import Union, Any, Optional, List
from .utility import get_solution_keys, get_branch_info, get_solution_variables, \
get_solution_params, get_lyapunov_exponents, parse_point_diagnostics
@dataclass
class Continuation:
    """Record describing one continuation tracked by `ODESystem`.

    It groups the per-continuation data that `ODESystem` also mirrors into its
    legacy ``auto_solutions``, ``results``, ``_results_map`` and ``_branches``
    dicts.

    Attributes
    ----------
    key
        PyCoBi's integer key for this continuation.
    name
        Name given via ``run(..., name=...)``; ``None`` when the continuation
        is unnamed.
    branch_id
        auto-07p ``BR`` identifier of the branch the continuation lives on.
        Several continuations can share one branch (e.g. an extension of an
        earlier run, or the reverse half of a bidirectional run that was merged
        into its forward half).
    icps
        Distinct continuation-parameter index tuples used to grow this branch,
        in first-seen order. A bidirectional run in PAR(4) yields ``[(4,)]``; a
        later extension of the same branch in PAR(5) appends ``(5,)``.
    auto_solution
        auto-07p ``bifDiag`` object holding the raw solution data. It cannot be
        pickled, so `ODESystem.to_file` leaves it out and instances restored via
        `ODESystem.from_file` carry ``None`` here.
    summary
        Parsed `DataFrame` summary of the continuation. It is ``None`` until
        `ODESystem.run` has built the summary and attached it.
    """
    # identity of the continuation
    key: int
    name: Optional[str]
    branch_id: int
    # growth history and payload; the list default must be per-instance
    icps: List[tuple] = field(default_factory=list)
    auto_solution: Any = None
    summary: Optional[DataFrame] = None
# auto-07p reserves a block of PAR slots for internal use (PAR(11) holds the
# period of periodic solutions and PAR(14) the integration time), so user
# parameters must skip that slot range. Read the canonical value from PyRates'
# FortranBackend (the layer that actually allocates PAR slots), with a
# hard-coded fallback for older PyRates that doesn't expose it. Keeping a
# single source of truth so PyCoBi's _var_map and PyRates' parnames stay in
# lockstep.
#
# NOTE(review): `ODESystem.__init__` treats this pair as a half-open range
# [lo, hi). With the fallback (10, 15), user parameters are assigned
# PAR(1)..PAR(9) and then PAR(15) onwards, so PAR(10)..PAR(14) are all
# skipped. That is one slot more than a strict PAR(11)..PAR(14) reservation.
# Confirm this matches how PyRates' FortranBackend interprets the same tuple.
try:
    from pyrates.backend.fortran.fortran_backend import FortranBackend as _FortranBackend
    _AUTO_BLOCKED_PAR_RANGE = _FortranBackend._AUTO_BLOCKED_PAR_RANGE
except (ImportError, AttributeError):
    # PyRates (or its fortran backend) is unavailable, or the installed version
    # predates the `_AUTO_BLOCKED_PAR_RANGE` attribute.
    _AUTO_BLOCKED_PAR_RANGE = (10, 15)
[docs]class ODESystem:
__slots__ = ["auto_solutions", "results", "_orig_dir", "dir", "_auto", "_last_cont", "_cont_num", "_results_map",
"_branches", "_bifurcation_styles", "_temp", "additional_attributes", "_eq", "_var_map",
"_var_map_inv", "continuations"]
blocked_indices = _AUTO_BLOCKED_PAR_RANGE
def __init__(self, eq_file: str, working_dir: str = None, auto_dir: str = None, init_cont: bool = False,
             params: list = None, state_vars: list = None, **kwargs) -> None:
    """Set up auto-07p, the working directory and PyCoBi's bookkeeping for one model.

    Parameters
    ----------
    eq_file
        Equation file used by this instance. It is handed to auto-07p as the ``e`` constant on the first
        `ODESystem.run` call.
    working_dir
        Directory in which all the fortran equation and auto-07p constant files are saved. If given, the process
        changes into it for the lifetime of the session; `close_session` changes back.
    auto_dir
        Installation directory of auto-07p. Only used when the ``AUTO_DIR`` environment variable is not set yet,
        in which case ``AUTO_DIR`` is set to it and its ``cmds``/``bin`` subdirectories are prepended to ``PATH``.
    init_cont
        If true, an initial-value integration with respect to time is performed at instantiation, using the
        equation file provided via the keyword argument `e=<fname>` (a file named `<fname>.f90` should exist in
        `working_dir`) and the auto constants provided via the keyword argument `c=<fname>` (a file named
        `c.<fname>` should exist in `working_dir`). Defaults to false — many use cases start from a pre-converged
        steady state and don't need an IVP, and time integration can be slow or fail to converge for stiff
        systems. Set to true to opt into the legacy behaviour.
    params
        Optional ordered list with names of all parameters in the model equations. Can be used to refer to model
        parameters.
    state_vars
        Optional ordered list that provides a name for each entry in the state vector of the model equations.
    kwargs
        Additional keyword arguments. A ``template`` entry (a PyRates `CircuitTemplate`) is popped and stored.
        When `init_cont=True`, the remaining entries are forwarded to `ODESystem.run` for the initial
        time-integration call. Ignored otherwise.

    Raises
    ------
    ValueError
        If ``AUTO_DIR`` is not set in the environment and no `auto_dir` was provided.
    """
    # make sure that auto-07p environment variables are set
    if 'AUTO_DIR' not in os.environ:
        if auto_dir is None:
            # NOTE(review): the message asks for a path to cmds/auto.env.sh, but `auto_dir` is used below as
            # the installation root (`<auto_dir>/cmds`, `<auto_dir>/bin`) — consider aligning the wording.
            raise ValueError('Auto-07p directory has not been set as environment variable. '
                             'Please provide path to cmds/auto.env.sh or set environment variable yourself.')
        else:
            auto_dir = auto_dir.replace('$HOME', '~')
            auto_dir = os.path.expanduser(auto_dir)
            os.environ['AUTO_DIR'] = auto_dir
            # prepend auto-07p's command and binary directories so its executables are found
            path = f"{auto_dir}/cmds:{auto_dir}/bin:{os.environ['PATH']}"
            os.environ['PATH'] = path
    # imported here (not at module level) because it depends on the environment prepared above
    import auto as a

    # open attributes
    self.auto_solutions = {}
    self.results = {}
    # `continuations` is the canonical store; `auto_solutions`, `results`,
    # `_results_map`, and `_branches` are kept as mirrors of its fields for
    # backward compatibility with external code that reads them directly.
    # Mirrors are populated by `_register_continuation` / `_record_summary`;
    # direct mirror writes by external code don't propagate back here.
    self.continuations = {}
    # cwd at construction time; `close_session` changes back to it
    self._orig_dir = os.getcwd()
    if working_dir:
        try:
            os.chdir(working_dir)
        except FileNotFoundError:
            # NOTE(review): os.chdir already resolves relative paths against the cwd, so this retry
            # targets the same location and appears redundant.
            os.chdir(f"{os.getcwd()}/{working_dir}")
    self.dir = os.getcwd()
    self.additional_attributes = {}

    # private attributes
    self._auto = a
    self._eq = eq_file
    self._last_cont = 0
    self._cont_num = 0
    self._results_map = {}
    self._branches = {}
    # marker/colour used when drawing each bifurcation type
    self._bifurcation_styles = {'LP': {'marker': 'v', 'color' : '#5D6D7E'},
                                'HB': {'marker': 'o', 'color': '#148F77'},
                                'CP': {'marker': 'd', 'color': '#5D6D7E'},
                                'PD': {'marker': 'h', 'color': '#5D6D7E'},
                                'BT': {'marker': 's', 'color': 'k'},
                                'GH': {'marker': 'o', 'color': '#148F77'}
                                }
    self._temp = kwargs.pop("template", None)

    # Build name <-> PAR/U index maps. Two paths feed into this code:
    #
    # (1) Hand-written .f90 + c.* without parnames/unames. The user
    #     passes `params=[...]`, `state_vars=[...]` so PyCoBi can
    #     translate user-facing names ("eta", "r") to auto-07p's
    #     internal "PAR(i)" / "U(i)" form for both inputs (ICP, UZR,
    #     UZSTOP keys) and outputs (DataFrame column relabelling).
    #
    # (2) PyRates-generated .f90 + c.* with parnames/unames. The user
    #     passes namespaced names ("p/qif_op/eta") via `from_template`,
    #     but auto-07p's solution exposes the bare local names ("eta")
    #     through `solution.coordnames`. Since the namespaced and bare
    #     names don't collide, every lookup misses and the input/output
    #     strings pass through unchanged — auto-07p resolves them via
    #     parnames/unames internally. `_var_map` is dead code on this
    #     path but doesn't get in the way.
    #
    # Per-entry storage is a (kind, idx) tuple — kind 'P' for parameters
    # (mapped to "PAR(idx)") or 'U' for state variables (mapped to
    # "U(idx)"). The "plot" string form is derived on demand by `_map_var`
    # rather than stored alongside the int. PAR(14) is auto-07p's reserved
    # time slot, always present.
    self._var_map = {"t": ("P", 14)}
    self._var_map_inv = {}
    if params:
        # Parameters are numbered from PAR(1). The first parameter that would land on blocked_indices[0]
        # is shifted forward by (hi - lo) slots, and so is every later one. With (10, 15) that means
        # PAR(1)..PAR(9), then PAR(15), PAR(16), ... — i.e. the half-open range [lo, hi) is skipped.
        increment = 1
        for i, key in enumerate(params):
            idx = i + increment
            if self.blocked_indices[0] <= idx <= self.blocked_indices[1]:
                idx -= increment
                increment += self.blocked_indices[1] - self.blocked_indices[0]
                idx += increment
            self._var_map[key] = ("P", idx)
    if state_vars:
        # state variables map 1:1 onto U(1), U(2), ...
        for i, key in enumerate(state_vars):
            self._var_map[key] = ("U", i + 1)
    # Only the "plot" string -> user-facing-name direction is ever read
    # (by `_create_summary`'s column remapping and `extract`). Derive each
    # plot string from the (kind, idx) tuple — `_map_var(name, "plot")`
    # produces the same form.
    for name in self._var_map:
        self._var_map_inv[self._map_var(name, "plot")] = name

    # perform initial continuation in time to ensure convergence to steady-state solution
    if init_cont:
        _ = self.run(ICP=[14], **kwargs)
def __getitem__(self, item):
    """Return the stored summary of a continuation.

    `item` may be a pyauto key as used in ``self.results`` or a name that was
    registered through ``run(..., name=...)``. The direct key lookup is tried
    first, then the name lookup. When both miss, a single `KeyError` listing
    every known name and key is raised, which is much easier to act on than
    the bare inner-dict error.
    """
    summaries = self.results
    try:
        return summaries[item]
    except KeyError:
        pass
    try:
        resolved_key = self._results_map[item]
        return summaries[resolved_key]
    except KeyError:
        raise KeyError(
            f"{item!r} is neither a registered continuation name nor a stored "
            f"pyauto key. Known names: {sorted(self._results_map.keys())}. "
            f"Known keys: {sorted(self.results.keys())}."
        )
@property
def pyrates_template(self):
    """The PyRates `CircuitTemplate` passed in as ``template`` at construction, or ``None``.

    `from_template` / `from_yaml` supply it automatically; an instance built directly from an equation file
    (or restored via `from_file`) has none.
    """
    template = self._temp
    return template
def close_session(self, clear_files: bool = False, **kwargs):
    """Finish working with this instance and change back to the original directory.

    Parameters
    ----------
    clear_files
        If true, first call PyRates' `clear` on the stored template (``self._temp``) to remove generated files.
    kwargs
        Extra keyword arguments forwarded to PyRates' `clear`; only used when `clear_files` is true.

    The switch back to ``self._orig_dir`` happens in a ``finally`` clause. A failure while clearing files
    therefore still propagates, but it no longer leaves the process stranded in the working directory.
    """
    try:
        if clear_files:
            clear(self._temp, **kwargs)
    finally:
        os.chdir(self._orig_dir)
@staticmethod
def reset_auto_state() -> None:
    """Clear auto-07p's per-process cross-run state (``parnames``, ``unames``).

    auto-07p's Python wrapper deliberately leaves the ``parnames`` /
    ``unames`` entries on its global runner intact between successive
    ``run()`` calls — see ``auto/runAUTO.py``::

        # do not completely replace existing constants data but
        # leave the special keys such as unames, parnames, etc, intact

    That's helpful when iterating on a single model, but it leaks across
    unrelated model loads. Concretely: an `ODESystem` for model A whose
    generated c.* declares ``unames = {1: 'r', 2: 'v'}`` populates the
    global runner; instantiating model B whose c.* declares no unames will
    inherit ``{1: 'r', 2: 'v'}`` and silently relabel B's DataFrame
    columns with A's state-variable names. The same applies to
    ``parnames``.

    Call this between unrelated model loads (typically in test teardown,
    or right before constructing a fresh `ODESystem` from a different
    model). No-ops if `auto` was never imported, if the runner has no
    ``options`` mapping, or if no ``constants`` entry is present.
    """
    runner = ODESystem._get_auto_runner()
    if runner is None:
        return
    try:
        constants = runner.options.get('constants')
    except AttributeError:
        # The runner lacks an `options` mapping (auto-07p internals changed):
        # honour the documented no-op contract instead of crashing.
        return
    if constants is None:
        return
    for key in ('parnames', 'unames'):
        constants[key] = None
@staticmethod
def _get_auto_runner():
    """Find auto-07p's process-wide ``runAUTO`` instance, or return ``None``.

    The package-level ``auto.run`` is a function whose globals contain a
    ``withrunner`` wrapper. That wrapper closes over the
    ``AUTOSimpleFunctions`` singleton, and the singleton's ``_runner``
    attribute is the runner that keeps the persisted ``constants``.

    Returns ``None`` when ``auto`` cannot be imported, when it has no
    ``run``, when ``withrunner`` is missing or has no closure, or when no
    closure cell holds an object with a ``_runner`` attribute. Callers must
    handle the ``None`` case.
    """
    try:
        import auto as auto_pkg
    except ImportError:
        return None

    run_fn = getattr(auto_pkg, 'run', None)
    if run_fn is None:
        return None

    wrapper = run_fn.__globals__.get('withrunner')
    if wrapper is None or wrapper.__closure__ is None:
        return None

    captured = (cell.cell_contents for cell in wrapper.__closure__)
    for candidate in captured:
        if hasattr(candidate, '_runner'):
            return candidate._runner
    return None
@classmethod
def from_yaml(cls, path: str, working_dir: str = None, auto_dir: str = None, init_cont: bool = False,
              init_kwargs: dict = None, analytical_jacobian: bool = True,
              auto_constants: Union[str, tuple, list] = ('ivp',), **kwargs):
    """Build an `ODESystem` from a PyRates YAML model definition.

    The YAML file is loaded into a `pyrates.CircuitTemplate`. All other arguments are handed unchanged to
    `ODESystem.from_template`, which generates the fortran/auto-07p files and constructs the instance.

    Parameters
    ----------
    path
        Full path to a YAML model definition file for a `pyrates.CircuitTemplate`.
    working_dir
        Directory in which the fortran equation file and the auto-07p constants files are written.
    auto_dir
        Installation directory of auto-07p.
    init_cont
        If true, run an IVP time integration against the freshly generated c.ivp file at instantiation.
        Defaults to false; see `__init__`.
    init_kwargs
        Keyword arguments for the `ODESystem.run` call that performs that time integration.
    analytical_jacobian
        If true (default), PyRates symbolically differentiates the vector field and emits DFDU/DFDP inside the
        generated `func` subroutine; the generated `c.*` file sets `JAC=1` so auto-07p uses the analytical
        Jacobian. Set to false to use auto-07p's finite-difference Jacobian instead (useful when symbolic
        differentiation is slow or yields unwieldy expressions).
    auto_constants
        Name, or iterable of names, of the auto-07p scenarios to emit `c.<name>` files for. See
        `from_template` for the recognised scenarios. Defaults to `('ivp',)` for backward compatibility.
    kwargs
        Forwarded to `pyrates.CircuitTemplate.get_run_func` (via `from_template`) when generating the files.

    Returns
    -------
    ODESystem
        The newly created instance.
    """
    template = CircuitTemplate.from_yaml(path)
    options = dict(working_dir=working_dir, auto_dir=auto_dir, init_cont=init_cont, init_kwargs=init_kwargs,
                   analytical_jacobian=analytical_jacobian, auto_constants=auto_constants)
    return cls.from_template(template, **options, **kwargs)
@classmethod
def from_template(cls, template: CircuitTemplate, working_dir: str = None, auto_dir: str = None,
                  init_cont: bool = False, init_kwargs: dict = None, analytical_jacobian: bool = True,
                  auto_constants: Union[str, tuple, list] = ('ivp',), **kwargs):
    """Instantiates `ODESystem` from a `pyrates.CircuitTemplate`.

    Parameters
    ----------
    template
        Instance of the class `pyrates.CircuitTemplate`.
    working_dir
        Directory in which all the fortran equation and auto-07p constant files are saved.
    auto_dir
        Installation directory of auto-07p.
    init_cont
        If true, an IVP time integration is run at instantiation against the freshly generated c.ivp file.
        Defaults to false — set to true to opt into the legacy behaviour. See `__init__` for details.
    init_kwargs
        Additional keyword arguments that will be provided to the `ODESystem.run` method for performing the time
        integration.
    analytical_jacobian
        If true (default), instruct PyRates to symbolically differentiate the vector field and emit DFDU/DFDP
        inside the generated `func` subroutine; the generated `c.*` file will set `JAC=1` so auto-07p uses the
        analytical Jacobian. Set to false to fall back to auto-07p's finite-difference Jacobian (useful when
        symbolic differentiation is slow or produces unwieldy expressions for the model at hand). Can be
        overridden on a per-continuation basis by passing `JAC=0` or `JAC=1` to `ODESystem.run`.
    auto_constants
        Name (or iterable of names) of auto-07p continuation scenarios PyRates should emit `c.<name>` files for.
        One file is written per requested scenario, each pre-configured with auto-07p constants appropriate for
        that mode. Recognised scenarios:

        * ``'ivp'`` — initial-value problem / time integration (``IPS=-2``). Required when ``init_cont=True``.
        * ``'eq'`` — equilibrium continuation in one parameter (``IPS=1``).
        * ``'lc'`` — limit-cycle continuation in one parameter, with PAR(11) as the period (``IPS=2``).
        * ``'bvp'`` — boundary-value problem (``IPS=4``).

        Pass e.g. ``auto_constants=('ivp', 'eq', 'lc')`` to set up all three at once and then switch scenarios on
        a per-call basis via ``ode.run(c='eq', ...)``. Auto-07p constants passed as kwargs (``NMX``, ``DSMAX``,
        ``UZSTOP``, ...) apply to every requested scenario; per-scenario overrides can be applied at run-time on
        the corresponding ``ODESystem.run`` call. Defaults to ``('ivp',)`` for backward compatibility.
    kwargs
        Additional keyword arguments provided to the `pyrates.CircuitTemplate.get_run_func` method that is used to
        generate the fortran equation file and the auto constants file that will be used to initialize `ODESystem`.

    Returns
    -------
    ODESystem
        `ODESystem` instance. Its ``_orig_dir`` is the working directory at the time of *this* call, so
        `close_session` changes back to the caller's directory rather than to `working_dir`.

    Raises
    ------
    ValueError
        If `init_cont` is true but ``'ivp'`` is not among the requested `auto_constants`.

    Notes
    -----
    If file generation or the construction of the instance fails after switching into `working_dir`, the
    previous working directory is restored before the exception propagates.
    """
    # normalise & validate auto_constants (before any I/O so an obviously
    # inconsistent combo errors out cleanly rather than as a stale
    # working-dir FileNotFoundError from chdir downstream).
    scenarios = (auto_constants,) if isinstance(auto_constants, str) else tuple(auto_constants)
    if init_cont and 'ivp' not in scenarios:
        raise ValueError(
            f"init_cont=True performs an IVP integration against c.ivp, but 'ivp' is missing from "
            f"auto_constants={scenarios!r}. Either include 'ivp' (e.g. auto_constants=('ivp', 'eq')) or "
            f"set init_cont=False."
        )

    # Remember the caller's directory *before* switching into working_dir. `__init__` only sees the post-chdir
    # cwd, so without this `close_session` would "return" to working_dir instead of the caller's directory.
    orig_dir = os.getcwd()

    # change working directory
    if working_dir:
        try:
            os.chdir(working_dir)
        except FileNotFoundError:
            os.chdir(f"{os.getcwd()}/{working_dir}")

    try:
        # preparations
        func_name = kwargs.pop("func_name", "vector_field")
        file_name = kwargs.pop("file_name", "system_equations")
        dt = kwargs.pop("step_size", 1e-3)
        solver = kwargs.pop("solver", "scipy")
        if init_kwargs is None:
            init_kwargs = {}

        # update circuit template variables
        if "node_vars" in kwargs:
            template.update_var(node_vars=kwargs.pop("node_vars"))
        if "edge_vars" in kwargs:
            template.update_var(edge_vars=kwargs.pop("edge_vars"))

        # generate fortran files
        prec = kwargs.pop("float_precision", "float64")
        _, _, params, state_vars = template.get_run_func(func_name, dt, file_name=file_name, backend="fortran",
                                                         float_precision=prec, auto=True,
                                                         auto_jac=analytical_jacobian, auto_constants=scenarios,
                                                         vectorize=False, solver=solver, **kwargs)

        # PyRates returns the full positional argument list for the run function,
        # which prepends some non-parameter args (state vector ``y``, derivative
        # ``dy``, time ``t``, optional history function ``hist`` for DDEs) before
        # the actual model parameters. Filter by name rather than slicing by
        # position — survives DDE models (where ``hist`` shifts the offset) and
        # is robust to upstream signature reordering.
        non_param_args = {'t', 'y', 'dy', 'hist'}
        param_names = tuple(p for p in params if p not in non_param_args)

        # initialize ODESystem
        instance = cls(auto_dir=auto_dir, init_cont=init_cont, c="ivp", eq_file=file_name, template=template,
                       params=param_names, state_vars=list(state_vars), **init_kwargs)
    except BaseException:
        # don't leave the process stranded in working_dir when generation/initialisation fails
        os.chdir(orig_dir)
        raise

    instance._orig_dir = orig_dir
    return instance
@classmethod
def from_file(cls, filename: str, auto_dir: str = None):
    """Restore an `ODESystem` from a pickle written by `to_file`.

    A blank instance is created first (no equation file, no initial continuation). Every entry of the pickled
    dict is then applied to it: where both the current attribute and the loaded value are dicts, the loaded
    entries are merged into the existing dict; otherwise the loaded value simply replaces the attribute. This
    covers both the results-only and the whole-state dumps. Finally ``continuations`` is rebuilt from the
    restored mirror dicts, since it is never pickled.

    Slots that are not stored on disk (the live `auto` module, the PyRates `CircuitTemplate`, the cwd snapshots)
    keep the values `__init__` gave them.

    .. warning:: Unpickling can execute arbitrary code. Only load files you created yourself or otherwise trust.

    Parameters
    ----------
    filename
        Path to the pickle file written by `ODESystem.to_file`.
    auto_dir
        Installation directory of `auto-07p`.

    Returns
    -------
    ODESystem
        The restored instance.
    """
    instance = cls('', auto_dir=auto_dir, init_cont=False)
    with open(filename, 'rb') as handle:
        state = pickle.load(handle)  # trusted input only — see docstring warning

    for attr_name, loaded in state.items():
        current = getattr(instance, attr_name, None)
        both_dicts = isinstance(current, dict) and isinstance(loaded, dict)
        if not both_dicts:
            setattr(instance, attr_name, loaded)
            continue
        current.update(loaded)

    instance._rebuild_continuations_from_mirrors()
    return instance
# Slots that to_file deliberately omits when saving the whole state.
# ``dir`` / ``_orig_dir`` are session-local cwd snapshots that should be
# re-derived on load. ``_auto`` is a live Python module (the auto-07p
# package) — pickle can't serialise modules, and __init__ re-imports it
# anyway. ``_temp`` is a PyRates CircuitTemplate that holds lambdified sympy
# functions which don't pickle; users that want the template on disk should
# pickle it separately or save the YAML path alongside. ``auto_solutions``
# contains auto-07p bifDiag objects that hold open BufferedReader handles on
# the fort.* files and refuse to pickle. ``continuations`` (the canonical
# store) holds the same bifDiag objects on each entry's `.auto_solution`
# field — also unpicklable; `from_file` rebuilds it via
# `_rebuild_continuations_from_mirrors`. ``_last_cont`` is skipped as well,
# so after loading it keeps the 0 that `__init__` assigns; `_cont_num` is
# what is saved to track the count.
# NOTE(review): `merge` and `_register_continuation` both store an int key in
# ``_last_cont``. If no other code path stores a solution object there, this
# slot could be pickled safely.
_PICKLE_EXCLUDE = frozenset({
    "dir", "_orig_dir", "_auto", "_temp",
    "auto_solutions", "_last_cont", "continuations",
})
def to_file(self, filename: str, results_only: bool = True, **kwargs) -> None:
    """Pickle the instance state to `filename`, replacing any existing file.

    Parameters
    ----------
    filename
        Destination path. It is created if missing and truncated if it already exists.
    results_only
        If true (default), write only the PyCoBi bookkeeping needed to rebuild DataFrames and plots without
        rerunning auto-07p: ``results``, ``_branches`` and ``_results_map``. If false, write every slot in
        ``__slots__`` except those listed in ``_PICKLE_EXCLUDE``, which cannot be pickled or should not be
        restored.
    kwargs
        Arbitrary metadata stored under ``additional_attributes``. In the full dump it replaces the instance's
        own ``additional_attributes``; `from_file` merges it back in.

    Returns
    -------
    None
    """
    if not results_only:
        payload = {slot: getattr(self, slot) for slot in self.__slots__ if slot not in self._PICKLE_EXCLUDE}
    else:
        payload = {'results': self.results, '_branches': self._branches, '_results_map': self._results_map}
    payload['additional_attributes'] = kwargs

    # 'wb' both creates a missing file and overwrites an existing one
    with open(filename, 'wb') as handle:
        pickle.dump(payload, handle)
def run(self, origin: Union[int, str, object] = None, starting_point: Union[str, int] = None, variables: list = None,
        params: list = None, get_stability: bool = True, get_period: bool = False, get_timeseries: bool = False,
        get_eigenvals: bool = False, get_lyapunov_exp: bool = False, reduce_limit_cycle: bool = True,
        bidirectional: bool = False, name: str = None, _reverse_direction: bool = False, **auto_kwargs) -> tuple:
    """
    Wraps auto-07p command `run` and stores requested solution details on instance.

    Parameters
    ----------
    origin
        Branch that contains the solution `starting_point`, from which the new continuation will be started.
        Accepted forms:

        * an integer pyauto key (plain ``int`` or any integral type such as NumPy integers),
        * a name previously registered via ``name=...``,
        * an auto solution object carrying a ``pycobi_key`` attribute.

        Defaults to the most recent continuation.
    starting_point
        Solution on the origin branch to start the new continuation from. Accepted forms:

        * Auto-07p label string — ``'EP'``, ``'LP1'``, ``'HB2'`` etc. The first two characters
          are the bifurcation type; the optional trailing integer disambiguates when the branch
          carries several solutions of the same type (1-based, defaults to 1). The IVP that
          `init_cont=True` runs produces ``'EP1'`` for the initial state and ``'EP2'`` for the
          converged steady state — use ``'EP2'`` when starting an equilibrium continuation from
          the IVP's terminal state.
        * Bare bifurcation type (no number) — ``'EP'`` is equivalent to ``'EP1'``.
        * Integer point index — a 1-based auto-07p point number on the branch (the ``PT``
          column of the printed table). Useful when auto produced unlabeled regular points
          that you want to continue from.
        * ``None`` — only valid on the very first call against a fresh `ODESystem` (when
          no prior continuation exists to extend); subsequent calls require an explicit
          starting point so PyCoBi knows which branch to extend.
    variables
        Keys of the state variables that should be recorded for each continuation recording step.
    params
        Keys of the parameters that should be recorded for each continuation recording step.
    get_stability
        If true, the stability of each solution will be stored in the results under the key 'stability'.
    get_period
        If true, the period of periodic solutions will be stored in the results under the key 'period'.
    get_timeseries
        If true, the time vector associated with the state variables of a periodic solution will be stored under the
        key 'time'.
    get_eigenvals
        If true, the eigenvalues (floquet multipliers) or steady-state (periodic) solutions will be stored under the
        key 'eigenvalues'.
    get_lyapunov_exp
        If true, the local lyapunov exponents of solutions will be stored under the key 'lyapunov'.
    reduce_limit_cycle
        If true, the values of each state variable will be reduced to the minimum and maximum for limit cycle
        solutions. Else, the state variable values will be stored for multiple discretized points along the limit
        cycle solution (number depends on the arguments passed to Auto).
    bidirectional
        If true, parameter continuation will be performed into both directions for a given continuation parameter.
        The reverse half uses the opposite step:

        * an explicit numeric `DS` is negated,
        * ``DS='-'`` becomes ``1e-3``,
        * an unspecified `DS` becomes ``'-'``.

        The reverse half is always merged into the forward continuation's key, and all summary options
        (including `reduce_limit_cycle`) apply to the merged result.
    name
        Name, under which the resulting solution branch will be accessible for future continuations.
    _reverse_direction
        Private flag set internally by the recursive call that `bidirectional=True` makes. Tells `run()` that
        this invocation is the reverse-direction half of a bidirectional continuation, so the result is merged
        into the forward branch rather than registered as a fresh continuation. Don't pass manually.
    auto_kwargs
        Additional keyword arguments to be passed to the auto command `run`. All auto-07p constants can be
        overridden here (e.g. `NMX`, `DSMAX`, `ICP`, ...). In particular, `JAC=0`/`JAC=1` overrides the
        Jacobian source for this single continuation: pass `JAC=0` to force finite-difference Jacobian even if
        `from_template` / `from_yaml` was instantiated with `analytical_jacobian=True`, and vice versa.
        `IRS` and `s` are rejected; use `starting_point` / `origin` instead.

    Returns
    -------
    tuple
        DataFrame with the results, auto solution branch object.

    Raises
    ------
    ValueError
        If `IRS` or `s` is passed, or if `starting_point` is missing while the most recent continuation key is
        greater than 0.
    """
    import numbers

    # auto call
    ###########

    # on the very first continuation, point auto-07p at the equation file
    if self._last_cont == 0 and self._last_cont not in self.auto_solutions:
        auto_kwargs["e"] = self._eq
    if 'IRS' in auto_kwargs or 's' in auto_kwargs:
        raise ValueError('Usage of keyword arguments `IRS` and `s` is disabled in pycobi. To start from a previous '
                         'solution, use the `starting_point` keyword argument (an auto-07p label such as "EP2" or '
                         'an integer point index) together with `origin` to select the branch that contains it.')
    if not starting_point and self._last_cont > 0:
        raise ValueError('A starting point is required for further continuation. Either provide a solution to start'
                         ' from via the `starting_point` keyword argument or create a fresh `ODESystem` instance.')

    # resolve `origin` to a pyauto key
    if origin is None:
        origin = self._last_cont
    elif isinstance(origin, str):
        origin = self._results_map[origin]
    elif isinstance(origin, numbers.Integral) and not isinstance(origin, bool):
        origin = int(origin)
    else:
        origin = origin.pycobi_key

    # call to auto
    auto_kwargs = self._map_auto_kwargs(auto_kwargs)
    solution = self._call_auto(starting_point, origin, **auto_kwargs)

    # extract information from auto solution
    ########################################

    # extract branch and solution info
    new_branch, new_icp = get_branch_info(solution)
    new_points = get_solution_keys(solution)

    # get all passed variables and params
    solution_tmp, *_ = self.get_solution(point=new_points[0], cont=solution)
    if variables is None:
        variables = self._get_all_var_keys(solution_tmp)
    variables = [self._map_var(v, mode="plot") for v in variables]
    if params is None:
        try:
            params = self._get_all_param_keys(solution_tmp)
        except KeyError:
            n_params = auto_kwargs['NPAR']
            params = [f"PAR({i})" for i in range(1, n_params + 1)]
    params = [self._map_var(p, mode="plot") for p in params]

    # store solution and extracted information in pycobi
    ####################################################

    # Decide whether this continuation extends an existing branch (merge
    # path) or starts a fresh one. Three cases:
    #   1. Same (branch, origin, icp) tuple seen before — auto extended
    #      the existing branch; merge into the previous result.
    #   2. The reverse-direction half of a bidirectional run — merge into
    #      the forward branch identified by `_last_cont`. This must not depend
    #      on the value of DS: when the forward half used DS='-' the reverse
    #      half runs with a numeric DS, and registering it under a fresh key
    #      would make the caller's `self.results[pyauto_key]` lookup fail.
    #   3. Otherwise — allocate a fresh pyauto key.
    if new_branch in self._branches and origin in self._branches[new_branch] \
            and new_icp in self._branches[new_branch][origin]:
        solution_old, *_ = self.get_solution(origin)
        pyauto_key = solution_old.pycobi_key
        solution, new_points = self.merge(pyauto_key, solution, new_icp)
    elif _reverse_direction:
        solution_old = self.auto_solutions[self._last_cont]
        pyauto_key = solution_old.pycobi_key
        solution, new_points = self.merge(pyauto_key, solution, new_icp)
    else:
        pyauto_key = self._cont_num + 1 if self._cont_num in self.auto_solutions else self._cont_num
        solution.pycobi_key = pyauto_key

    # The reverse-direction half of a bidirectional run doesn't register
    # a fresh name — it merges into the forward branch which is already
    # in _results_map under the user's name.
    registered_name = name if (name and not _reverse_direction) else None
    self._register_continuation(
        key=pyauto_key, name=registered_name, branch_id=new_branch,
        icp=new_icp, auto_solution=solution,
    )

    # if continuation should be bidirectional, call this method again with reversed continuation direction
    ######################################################################################################
    if bidirectional:
        # Perform the continuation in the opposite direction. The recursive
        # call's _reverse_direction=True tells it to merge into this branch
        # rather than register itself as a fresh continuation; it also builds
        # the summary for the merged branch, so every summary option has to
        # be forwarded.
        ds = auto_kwargs.pop('DS', None)
        if ds == '-':
            reverse_ds = 1e-3
        elif isinstance(ds, numbers.Real) and not isinstance(ds, bool):
            # an explicit step: negate it so the reverse half really goes the other way
            reverse_ds = -ds
        else:
            reverse_ds = '-'
        _, solution = self.run(origin, starting_point, variables=variables, params=params,
                               get_stability=get_stability, get_period=get_period, get_timeseries=get_timeseries,
                               get_eigenvals=get_eigenvals, get_lyapunov_exp=get_lyapunov_exp,
                               reduce_limit_cycle=reduce_limit_cycle, bidirectional=False,
                               _reverse_direction=True, DS=reverse_ds, **auto_kwargs)
    else:
        # store summary of continuation results; time integrations (ICP=14) carry no stability information
        if new_icp[0] == 14:
            get_stability = False
        summary = self._create_summary(solution=solution, points=new_points, variables=variables,
                                       params=params, timeseries=get_timeseries, stability=get_stability,
                                       period=get_period, eigenvals=get_eigenvals, lyapunov_exp=get_lyapunov_exp,
                                       reduce_limit_cycle=reduce_limit_cycle)
        self._record_summary(pyauto_key, summary)

    return self.results[pyauto_key], solution
def merge(self, key: int, cont, icp: tuple):
    """Combine the continuation stored under `key` with `cont` using auto-07p's ``merge``.

    The merged solution is tagged with ``pycobi_key = key`` and stored back under `key` in
    ``auto_solutions``. ``_last_cont`` is set to `key`, and the registered `Continuation` (if any) is updated
    to point at the merged solution. These updates let `merge` be used on its own, outside `run`'s bookkeeping.

    Parameters
    ----------
    key
        PyCoBi identifier of the first continuation; the merged result is stored under the same key.
    cont
        auto continuation object to be merged into the continuation stored under `key`.
    icp
        Continuation parameter shared by both continuations. Accepted for API symmetry; the merge itself does
        not use it.

    Returns
    -------
    tuple
        The merged solution object and the list of its point indices (keys of ``data[0].labels.by_index``).
    """
    combined = self.auto_solutions[key] + cont
    merged = self._auto.merge(combined)
    merged.pycobi_key = key

    # keep the legacy mirror and the "last continuation" pointer current
    self.auto_solutions[key] = merged
    self._last_cont = key
    if key in self.continuations:
        self.continuations[key].auto_solution = merged

    label_index = merged.data[0].labels.by_index
    return merged, list(label_index.keys())
# ------------------------------------------------------------------
# Centralised continuation bookkeeping (replaces the scattered writes
# to auto_solutions / results / _results_map / _branches that used to
# live inline in `run`).
# ------------------------------------------------------------------
def _register_continuation(self, key: int, name: Optional[str], branch_id: int,
                           icp: tuple, auto_solution: Any) -> "Continuation":
    """Create or update the `Continuation` for `key` and sync the legacy mirrors.

    If no entry exists for `key`, a new `Continuation` is created with ``icps=[icp]``. Otherwise (the
    typical merge / bidirectional-reverse path) the existing entry's ``auto_solution`` is replaced, `icp`
    is appended to its ``icps`` unless already present, and a truthy `name` is adopted only if the entry
    has none yet.

    Mirror updates: ``auto_solutions[key]`` is set, ``_last_cont`` becomes `key`, a truthy `name` is
    mapped to `key` in ``_results_map``, and `icp` is appended (duplicates kept, so the merge-detection
    check in `run` keeps working) to ``_branches[branch_id][key]``. ``_cont_num`` is then recomputed as
    the number of stored auto solutions. ``results`` is filled separately by `_record_summary`.

    Returns
    -------
    Continuation
        The created or updated entry.
    """
    cont = self.continuations.get(key)
    if cont is not None:
        cont.auto_solution = auto_solution
        if icp not in cont.icps:
            cont.icps.append(icp)
        if name and not cont.name:
            cont.name = name
    else:
        cont = Continuation(
            key=key, name=name, branch_id=branch_id,
            icps=[icp], auto_solution=auto_solution, summary=None,
        )
        self.continuations[key] = cont

    # ---- mirror sync ----
    self.auto_solutions[key] = auto_solution
    self._last_cont = key
    if name:
        self._results_map[name] = key
    self._branches.setdefault(branch_id, {}).setdefault(key, []).append(icp)
    self._cont_num = len(self.auto_solutions)
    return cont
def _record_summary(self, key: int, summary: DataFrame) -> None:
    """Store a parsed summary DataFrame under ``key``.

    The summary is written to the ``results`` mirror unconditionally and to
    the matching `Continuation` only if one is registered for ``key``.
    """
    target = self.continuations[key] if key in self.continuations else None
    if target is not None:
        target.summary = summary
    self.results[key] = summary
def _rebuild_continuations_from_mirrors(self) -> None:
    """Regenerate ``self.continuations`` from the legacy mirror dicts.

    Used by `from_file` once the mirrors have been restored from disk:
    ``continuations`` is excluded from pickling (its entries hold auto-07p
    objects), so it must be reconstructed here. Every key found in
    ``results``, ``auto_solutions`` or ``_branches`` gets an entry; keys
    without a branch record get ``branch_id=0``. Loaded entries carry
    whatever ``auto_solutions`` holds for them (``None`` if absent).
    """
    self.continuations.clear()
    names_by_key = {k: nm for nm, k in self._results_map.items()}
    branch_by_key: dict = {}
    icps_by_key: dict = {}
    for bid, per_key in self._branches.items():
        for k, recorded in per_key.items():
            branch_by_key[k] = bid
            if k not in icps_by_key:
                icps_by_key[k] = []
            icps_by_key[k] += recorded

    for k in set(self.results) | set(self.auto_solutions) | set(branch_by_key):
        # `_branches` may hold the same icp several times (it is append-only
        # for `run()`'s merge detection); the dataclass exposes each once,
        # keeping first-seen order.
        unique_icps: list = []
        for icp in icps_by_key.get(k, []):
            if icp not in unique_icps:
                unique_icps.append(icp)
        self.continuations[k] = Continuation(
            key=k,
            name=names_by_key.get(k),
            branch_id=branch_by_key.get(k, 0),
            icps=unique_icps,
            auto_solution=self.auto_solutions.get(k),
            summary=self.results.get(k),
        )
def get_continuation(self, key_or_name: Union[int, str]) -> "Continuation":
    """Look up a stored `Continuation` by user-supplied name or pycobi key.

    Parameters
    ----------
    key_or_name
        Either the name passed to ``run(..., name=...)`` (a ``str``) or the
        integer pycobi key.

    Returns
    -------
    Continuation
        The matching entry of ``self.continuations``.

    Raises
    ------
    KeyError
        If the name is unknown or no continuation is stored under the key.
        The message lists the known names / keys.

    Examples
    --------
    >>> sols, _ = ode.run(starting_point='EP2', name='eta_branch', ICP='eta', ...)
    >>> cont = ode.get_continuation('eta_branch')
    >>> cont.branch_id, cont.icps, len(cont.summary)
    (1, [(4,)], 30)
    """
    key = key_or_name
    if isinstance(key_or_name, str):
        names = self._results_map
        try:
            key = names[key_or_name]
        except KeyError:
            raise KeyError(
                f"No continuation named {key_or_name!r}; "
                f"known names: {sorted(names)}"
            )
    registry = self.continuations
    try:
        return registry[key]
    except KeyError:
        raise KeyError(
            f"No continuation with key {key!r}; "
            f"known keys: {sorted(registry)}"
        )
def get_summary(self, cont: Optional[Union[Any, str, int]] = None, point=None) -> DataFrame:
    """Extract summary of continuation from PyCoBi.

    Parameters
    ----------
    cont
        Key of the solution branch: the integer pycobi key, the name passed to
        ``run(..., name=...)``, ``None`` for the most recent continuation, or an
        object carrying a ``pycobi_key`` attribute or, failing that, a ``key``
        attribute (e.g. a `Continuation`).
    point
        Key of the solution on the branch. Either an index label of the summary
        (e.g. ``0``) or a bifurcation label such as ``'LP'`` / ``'LP2'``, which
        selects the n-th point of that type (counting from 1; no suffix means 1).
        ``None`` (or ``''``) returns the whole summary.

    Returns
    -------
    DataFrame
        All recorded state variables, parameters, etc. for the solution branch,
        or the single row for ``point``.

    Raises
    ------
    KeyError
        If a string ``point`` does not occur on the continuation.
    """
    # get continuation summary
    if isinstance(cont, str):
        summary = self.results[self._results_map[cont]]
    elif cont is None:
        summary = self.results[self._last_cont]
    elif isinstance(cont, (int, np.integer)):
        summary = self.results[cont]
    else:
        # Legacy objects expose `pycobi_key`; the `Continuation` dataclass
        # exposes `key`.
        key = getattr(cont, 'pycobi_key', None)
        if key is None:
            key = cont.key
        summary = self.results[key]

    # return continuation summary. `is None` (not truthiness) so that the
    # valid point index 0 selects a row instead of the whole summary.
    if point is None or (isinstance(point, str) and point == ''):
        return summary

    if isinstance(point, str):
        label = point[:2]
        n = int(point[2:]) if len(point) > 2 else 1
        # Read the bifurcation column once as a flat sequence. On the
        # MultiIndex summary built by `_create_summary`, the old per-row
        # `summary.loc[p, 'bifurcation']` yields a length-1 Series, and
        # comparing it inside `if` raised "truth value is ambiguous".
        if 'bifurcation' in summary.columns:
            bf_col = summary['bifurcation']
            if getattr(bf_col, 'ndim', 1) > 1:
                bf_col = bf_col.iloc[:, 0]
            bf_labels = list(bf_col)
        else:
            bf_labels = []
        matches = [idx for idx, bf in zip(summary.index, bf_labels) if bf == label]
        if 1 <= n <= len(matches):
            return summary.loc[matches[n - 1], :]
        raise KeyError(f'Invalid point: {point} was not found on continuation {cont}.')

    return summary.loc[point, :]
def get_solution(self, cont: Union[Any, str, int], point: Union[str, int] = None) -> Union[Any, tuple]:
    """Fetch the auto-07p object for a branch, or for one point on it.

    Parameters
    ----------
    cont
        Key of the solution branch: the integer pycobi key, the name passed to
        ``run(..., name=...)``, or an auto-07p branch object itself.
    point
        Key of the solution on the branch: a label string such as ``'LP1'``
        (type prefix plus optional number) or an integer point index.

    Returns
    -------
    Union[Any, tuple]
        Always a 3-tuple. Without ``point``: ``(branch, None, None)``. With
        ``point``: ``(solution, solution_type, solution_index)``; the type is
        ``'No Label'`` when no labelled solution could be extracted.
    """
    if type(cont) is int:
        cont = self.auto_solutions[cont]
    elif type(cont) is str:
        cont = self.auto_solutions[self._results_map[cont]]
    if point is None:
        return cont, None, None

    def _scan_branch_data():
        # Fallback lookup: walk the branch's data blocks and return the first
        # one that holds the requested point; otherwise report 'No Label'.
        for block_data in cont.data:
            try:
                if type(point) is int:
                    entry = block_data.labels.by_index[point]
                    label = list(entry.keys())[0]
                    hits = np.argwhere(
                        [p == point for p in block_data.labels.by_label[label]]
                    ).squeeze()
                    return entry, label, int(hits + 1)
                return block_data.labels.by_label[point], point[:2], point[2:]
            except (KeyError, IndexError):
                continue
        return None, 'No Label', 0

    try:
        # primary path: let the branch object resolve a string label itself
        sol = cont(point)
        label, suffix = point[:2], point[2:]
        number = int(suffix) if len(suffix) > 0 else 0
    except (AttributeError, KeyError, TypeError):
        sol, label, number = _scan_branch_data()

    # unwrap the actual solution; fall back to 'No Label' if it is missing
    if label != 'No Label':
        try:
            sol = sol[label]['solution']
        except KeyError:
            label = 'No Label'
    return sol, label, number
def _resolve_summary_key(self, key: str, columns: list) -> str:
    """Map a user-supplied key onto a column name present in ``columns``.

    Candidates are tried in this order; the first one found in ``columns``
    wins:

    1. ``key`` unchanged (e.g. ``'eta'``, ``'r'``).
    2. ``_var_map_inv[key]``: turns auto-07p-native names (``'PAR(4)'``,
       ``'U(1)'``) into the names registered via ``params=`` /
       ``state_vars=``.
    3. The result of step 2 with its namespace prefix stripped
       (``'p/qif_op/eta'`` -> ``'eta'``), for when PyRates emitted bare names.
    4. The auto-07p slot from ``_var_map[key]``, translated to the name
       PyRates declared for that slot in the c.* ``unames`` / ``parnames``
       (via `_slot_to_uname_or_parname`). This is the step that resolves
       multi-node name collisions that PyRates disambiguates with ``_v1`` /
       ``_v2`` suffixes.
    5. ``key`` with its namespace prefix stripped; the fallback when step 4
       had no runner state to consult.

    Raises
    ------
    KeyError
        Listing every candidate that was tried and the available columns.
    """
    if key in columns:
        return key

    inv = self._var_map_inv.get(key)
    if inv is not None and inv in columns:
        return inv

    inv_bare = None
    if isinstance(inv, str):
        inv_bare = inv.rsplit('/', 1)[-1]
        if inv_bare in columns:
            return inv_bare

    # slot-based lookup; only attempted once the cheaper steps have failed
    from_slot = None
    slot = self._var_map.get(key)
    if slot is not None:
        slot_kind, slot_idx = slot
        from_slot = self._slot_to_uname_or_parname(slot_kind, slot_idx)
        if from_slot is not None and from_slot in columns:
            return from_slot

    key_bare = None
    if isinstance(key, str) and '/' in key:
        key_bare = key.rsplit('/', 1)[-1]
        if key_bare in columns:
            return key_bare

    attempted = [key]
    if inv is not None:
        attempted.append(inv)
    for extra in (inv_bare, from_slot, key_bare):
        if extra is not None and extra not in attempted:
            attempted.append(extra)
    raise KeyError(
        f"{key!r} is not a recognised summary column. Tried: {attempted}. "
        f"Available columns: {columns}."
    )
def _slot_to_uname_or_parname(self, kind: str, idx: int):
    """Translate an auto-07p slot identifier ``(kind, idx)`` to the column
    name PyRates declared for that slot in the c.* file.

    Parameters
    ----------
    kind
        ``'U'`` for a state variable; anything else is treated as a
        parameter (``'P'``).
    idx
        Slot index.

    Returns
    -------
    The ``unames`` / ``parnames`` entry for that slot, or ``None`` when no
    auto-07p runner is available, it has no constants loaded, or the slot is
    not listed. In that case `_resolve_summary_key` falls through to its
    bare-strip fallback.

    Notes
    -----
    The runner usually stores these as a list of ``[idx, name]`` pairs (see
    ``parseC.parseC.__setitem__``). A mapping ``{idx: name}`` is accepted as
    well; previously it raised ``TypeError`` (``len()`` of an int key). The
    lookup table is rebuilt on every call rather than cached, because a later
    ``run()`` may load a c.* file with different names for the same model.
    """
    runner = self._get_auto_runner()
    if runner is None:
        return None
    options = runner.options
    constants = options.get('constants') if options else None
    if not constants:
        return None
    entries = constants.get('unames' if kind == 'U' else 'parnames')
    if not entries:
        return None
    if hasattr(entries, 'items'):
        pairs = entries.items()
    else:
        pairs = ((pair[0], pair[1]) for pair in entries if len(pair) >= 2)
    mapping = {int(slot): name for slot, name in pairs}
    return mapping.get(int(idx))
def plot_continuation(self, x: str, y: str, cont: Union[Any, str, int], ax: plt.Axes = None,
                      force_axis_lim_update: bool = False, bifurcation_legend: bool = True,
                      get_stability: bool = True, **kwargs) -> LineCollection:
    """Line plot of 1D/2D parameter continuations and the respective codimension 1/2 bifurcations.

    Parameters
    ----------
    x
        Key of the parameter/variable plotted on the x-axis. ``'t'`` or
        ``'PAR(14)'`` plots against time; all points are then drawn as stable
        and tagged ``'RG'``.
    y
        Key of the variable/parameter plotted on the y-axis.
    cont
        Key of the solution branch to be plotted.
    ax
        Axis in which to plot the data. If not provided, a new figure will be created.
    force_axis_lim_update
        If true, the axis limits of x and y axis will be updated after creating the line plots.
    bifurcation_legend
        If true, a legend will be plotted that lists the type of all special solutions on a continuation curve.
    get_stability
        If true, the stability of the solutions will be indicated via different line styles.
    kwargs
        Additional keyword arguments that allow to control the appearance of the line plot. These keys
        are consumed here and not forwarded to the line collection: ``labelpad`` (default 5),
        ``tickpad`` (default 5), ``axislimpad`` (default 0), ``title_var`` (variable whose first
        value is shown as title), and ``default_color``, ``default_marker``, ``default_size``,
        ``custom_bf_styles``, ``ignore`` (see :meth:`plot_bifurcation_points`).

    Returns
    -------
    LineCollection
        Line object that was created.
    """
    if ax is None:
        _, ax = plt.subplots()
    label_pad = kwargs.pop('labelpad', 5)
    tick_pad = kwargs.pop('tickpad', 5)
    axislim_pad = kwargs.pop('axislimpad', 0)

    # extract information from branch solutions
    if x in ["PAR(14)", "t"]:
        x = "t"
        results, vmap = self.extract([x, y], cont=cont)
        # Go through `vmap` like the other branches do: the column that
        # extract() returns for 't' need not be literally named 't'.
        n_points = len(results[vmap[x]])
        results['stability'] = np.asarray([True] * n_points)
        results['bifurcation'] = np.asarray(['RG'] * n_points)
    elif get_stability:
        results, vmap = self.extract([x, y, 'stability', 'bifurcation'], cont=cont)
    else:
        results, vmap = self.extract([x, y, 'bifurcation'], cont=cont)
        results['stability'] = np.asarray([True] * len(results[vmap[x]]))
    x, y = vmap[x], vmap[y]

    # plot bifurcation points
    bifurcation_point_kwargs = ['default_color', 'default_marker', 'default_size', 'custom_bf_styles',
                                'ignore']
    kwargs_tmp = {key: kwargs.pop(key) for key in bifurcation_point_kwargs if key in kwargs}
    self.plot_bifurcation_points(solution_types=results['bifurcation'], x_vals=results[x],
                                 y_vals=results[y], ax=ax, **kwargs_tmp)

    # set title variable if passed
    tvar = kwargs.pop('title_var', None)
    if tvar:
        tvar_results, tmap = self.extract([tvar], cont=cont)
        tcol = tvar_results[tmap[tvar]]
        # Positional access: `tcol[0]` would be a *label* lookup on a
        # pandas object and fail when no row is labelled 0 (presumably the
        # rows are indexed by point label, as in `_create_summary`).
        tval = tcol.iloc[0] if hasattr(tcol, 'iloc') else tcol[0]
        ax.set_title(f"{tvar} = {tval}")

    # plot main continuation
    x_data, y_data = results[x], results[y]
    line_col = self._get_line_collection(x=x_data.values, y=y_data.values, stability=results['stability'], **kwargs)
    ax.add_collection(line_col)
    ax.autoscale()

    # cosmetics
    ax.tick_params(axis='both', which='major', pad=tick_pad)
    ax.set_xlabel(x, labelpad=label_pad)
    ax.set_ylabel(y, labelpad=label_pad)
    self._update_axis_lims(ax, ax_data=[x_data, y_data], padding=axislim_pad, force_update=force_axis_lim_update)
    # Skip `ax.legend()` when no labelled artists exist, to avoid
    # matplotlib's "No artists with labels found to put in legend" warning.
    if bifurcation_legend and ax.get_legend_handles_labels()[1]:
        ax.legend()
    return line_col
def plot_trajectory(self, variables: Union[list, tuple], cont: Union[Any, str, int], point: Union[str, int] = None,
                    ax: plt.Axes = None, force_axis_lim_update: bool = False, cutoff: float = None,
                    colorbar: bool = False, colorbar_label: str = None, **kwargs
                    ) -> LineCollection:
    """Plot trajectory of state variables through phase space over time.

    Parameters
    ----------
    variables
        State variables for which to create the trajectory. If 2, a 2D plot will be created, if 3, a 3D plot.
    cont
        Key of the solution branch to be plotted.
    point
        Key of the solution on the solution branch for which to plot the trajectories.
    ax
        Axis in which to plot the data. If not provided, a new figure will be created.
    force_axis_lim_update
        If true, the axis limits will be set to the data range (plus ``axislimpad``) after
        creating the line plot. Applies to 2D and 3D plots.
    cutoff
        Initial time to be disregarded for plotting.
    colorbar
        For 3D plots only: if true, attach a colorbar to the figure showing the
        scalar that ``_get_3d_line_collection`` mapped onto the LineCollection's
        color (default: the projected x-axis variable). Ignored for 2D plots.
    colorbar_label
        Label for the colorbar. Defaults to the variable behind the array key
        (``'x'`` / ``'y'`` / ``'z'``) that the LineCollection's ``array=``
        kwarg points at.
    kwargs
        Additional keyword arguments that allow to control the appearance of the line plot.
        ``labelpad``, ``tickpad`` and ``axislimpad`` are consumed here and not forwarded to the
        line collection. For 3D plots they default to 30, 20 and 0.1; for 2D plots matplotlib's
        defaults are kept unless they are passed explicitly.

    Returns
    -------
    LineCollection
        Line object that was created.

    Raises
    ------
    ValueError
        If ``cutoff`` is given but no time variable is found on the solution, or if
        ``variables`` does not contain 2 or 3 entries.
    """
    # extract information from branch solutions
    try:
        results, vmap = self.extract(list(variables) + ['stability'], cont=cont, point=point)
    except KeyError:
        results, vmap = self.extract(list(variables), cont=cont, point=point)
        results['stability'] = None
    variables = [vmap[v] for v in variables]

    # apply cutoff, if passed
    if cutoff:
        time, found = None, False
        for time_key in ('t', 'time'):
            try:
                time_results, time_map = self.extract([time_key], cont=cont, point=point)
                time = time_results[time_map[time_key]]
                found = True
                break
            except KeyError:
                continue
        if not found:
            raise ValueError("Could not find time variable on solution to apply cutoff to. Please consider "
                             "adding the keyword argument `get_timeseries` to the `PyCoBi.run()` call for which "
                             "the phase space trajectory should be plotted.")
        # `np.where` returns a 1-tuple of arrays; unpack it to a flat index
        # array. Slice the whole container at once: per-column assignment
        # back into a DataFrame would be index-aligned and fill NaN.
        idx = np.where(np.asarray(time) > cutoff)[0]
        if hasattr(results, 'iloc'):
            results = results.iloc[idx]
        else:
            # dict-like results: slice the array-valued entries by hand
            for key, val in list(results.items()):
                if hasattr(val, 'shape') and val.shape:
                    results[key] = val[idx]

    if len(variables) == 2:
        # create 2D plot
        if ax is None:
            _, ax = plt.subplots()
        # Pop the cosmetic options so they are not forwarded to the line
        # collection helper; `None` keeps matplotlib's defaults.
        label_pad = kwargs.pop('labelpad', None)
        tick_pad = kwargs.pop('tickpad', None)
        axislim_pad = kwargs.pop('axislimpad', 0)
        x, y = results[variables[0]], results[variables[1]]
        line_col = self._get_line_collection(x=x, y=y, stability=results['stability'], **kwargs)
        ax.add_collection(line_col)
        ax.autoscale()
        # cosmetics
        if tick_pad is not None:
            ax.tick_params(axis='both', which='major', pad=tick_pad)
        ax.set_xlabel(variables[0], labelpad=label_pad)
        ax.set_ylabel(variables[1], labelpad=label_pad)
        # Only touch the limits when asked to, so the default keeps the
        # axes in autoscale mode.
        if force_axis_lim_update or axislim_pad:
            self._update_axis_lims(ax, [x, y], padding=axislim_pad, force_update=force_axis_lim_update)

    elif len(variables) == 3:
        # create 3D plot
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111, projection='3d')
        label_pad = kwargs.pop('labelpad', 30)
        tick_pad = kwargs.pop('tickpad', 20)
        axislim_pad = kwargs.pop('axislimpad', 0.1)
        x, y, z = results[variables[0]], results[variables[1]], results[variables[2]]
        # `array=` selects which projected coordinate the color map encodes;
        # read it before _get_3d_line_collection consumes it from kwargs.
        array_key = kwargs.get('array', 'x')
        line_col = self._get_3d_line_collection(x=x, y=y, z=z, stability=results['stability'], **kwargs)
        ax.add_collection3d(line_col)
        ax.autoscale()
        # cosmetics
        ax.tick_params(axis='both', which='major', pad=tick_pad)
        ax.set_xlabel(variables[0], labelpad=label_pad)
        ax.set_ylabel(variables[1], labelpad=label_pad)
        ax.set_zlabel(variables[2], labelpad=label_pad)
        self._update_axis_lims(ax, [x, y, z], padding=axislim_pad, force_update=force_axis_lim_update)
        if colorbar:
            array_to_label = {'x': variables[0], 'y': variables[1], 'z': variables[2]}
            label = colorbar_label or array_to_label.get(array_key, array_key)
            ax.figure.colorbar(line_col, ax=ax, label=label, shrink=0.7)

    else:
        raise ValueError('Invalid number of state variables to plot. First argument can only take 2 or 3 state '
                         'variable names as input.')

    return line_col
def plot_timeseries(self, var: str, cont: Union[Any, str, int], points: list = None, ax: plt.Axes = None,
                    linespecs: list = None, **kwargs) -> plt.Axes:
    """Plot state variable of a periodic solution over time.

    Parameters
    ----------
    var
        Key of the state variable.
    cont
        Key of the solution branch.
    points
        List with keys of the solutions for which to create time series plots.
        When ``None`` (default), every point stored in the continuation's
        summary (its index) is plotted as a separate trace.
    ax
        Axis in which to plot the data. If not provided, a new figure will be created.
    linespecs
        Per-trace keyword overrides; ``linespecs[i]`` is merged into ``kwargs``
        for the i-th point. Traces without an entry use ``kwargs`` alone.
    kwargs
        Additional keyword arguments that control the appearance of the plot.

    Returns
    -------
    plt.Axes
        Axis object that contains the plotted timeseries.

    Raises
    ------
    KeyError
        If ``points`` is not given and the continuation is not in ``self.results``.
    ValueError
        If ``points`` is not given and the continuation has no recorded points.
    """
    if not points:
        cont_key = self._results_map[cont] if isinstance(cont, str) else cont
        try:
            stored = self.results[cont_key]
        except KeyError as exc:
            raise KeyError(
                f"plot_timeseries: continuation {cont!r} not found "
                f"in self.results"
            ) from exc
        # The point keys of a summary DataFrame live in its *index*;
        # `DataFrame.keys()` would return the column labels instead.
        points = list(stored.keys()) if isinstance(stored, dict) else list(stored.index)
        if not points:
            raise ValueError(
                f"plot_timeseries: continuation {cont!r} has no recorded "
                f"points to plot. Pass `points=[...]` explicitly or rerun "
                f"the continuation with NPR set so points are labelled."
            )

    # extract information from branch solutions
    results = []
    vmap: dict = {}
    for p in points:
        r, vmap = self.extract([var, 'time'], cont=cont, point=p)
        results.append(r)
    var_col = vmap[var]
    time_col = vmap['time']

    # create plot
    if ax is None:
        _, ax = plt.subplots()
    if not linespecs:
        linespecs = []

    def _unwrap_per_period(value):
        """Return the per-period ndarray, unwrapping a length-1 pandas Series.

        The ``('time', '')`` column of a limit-cycle summary stores one
        ndarray per point (object dtype), so partial indexing of a row by
        ``'time'`` yields a length-1 Series holding the array."""
        if hasattr(value, 'iloc') and hasattr(value, '__len__') and len(value) == 1:
            value = value.iloc[0]
        return np.atleast_1d(value).squeeze()

    for i, r in enumerate(results):
        time = _unwrap_per_period(r[time_col])
        y = _unwrap_per_period(r[var_col])
        kwargs_tmp = dict(kwargs)
        if i < len(linespecs):
            kwargs_tmp.update(linespecs[i])
        line_col = self._get_line_collection(x=time, y=y, **kwargs_tmp)
        ax.add_collection(line_col)
    ax.autoscale()
    ax.legend([str(p) for p in points])
    return ax
def plot_bifurcation_points(self, solution_types: DataFrame, x_vals: DataFrame, y_vals: DataFrame, ax: plt.Axes,
                            default_color: str = 'k', default_marker: str = '*', default_size: float = 10,
                            ignore: list = None, custom_bf_styles: dict = None) -> tuple:
    """Plot markers for special solutions at coordinates in 2D space.

    Parameters
    ----------
    solution_types
        Type of each solution; entries should be strings. Non-string or empty
        entries and the regular-point types ``'EP'``, ``'MX'`` and ``'RG'``
        are not drawn.
    x_vals
        X-coordinates of each solution.
    y_vals
        Y-coordinates of each solution. An array-valued entry (limit cycle) is
        marked at both its maximum and its minimum.
    ax
        Axis in which to draw the markers.
    default_color
        Default color to be used if bifurcation style is not known.
    default_marker
        Default marker style to be used if bifurcation style is not known.
    default_size
        Default marker size.
    ignore
        List of solution types that should not be displayed.
    custom_bf_styles
        Dictionary containing adjustments to the default bifurcation markers and colors.
        Applied via :meth:`update_bifurcation_style`, i.e. persistently for this instance.

    Returns
    -------
    tuple
        A 2-entry tuple of (1) the legend handles of ``ax`` (existing ones plus one new
        handle per newly drawn bifurcation type), and (2) the corresponding labels.
    """
    if not ignore:
        ignore = []
    if custom_bf_styles:
        for key, args in custom_bf_styles.items():
            self.update_bifurcation_style(key, **args)
    bf_styles = self._bifurcation_styles.copy()

    # Exact label matching. The former `bf not in "EPMXRG"` was a substring
    # test (it also matched e.g. 'PM', 'XR' or single letters) and raised
    # TypeError for non-string entries.
    regular_types = ('EP', 'MX', 'RG')

    # Draw through `ax` directly rather than via pyplot's current axes.
    points, labels = ax.get_legend_handles_labels()
    for bf, x, y in zip(solution_types.values, x_vals.values, y_vals.values):
        if not isinstance(bf, str) or not bf or bf in regular_types or bf in ignore:
            continue
        if bf in bf_styles:
            m = bf_styles[bf]['marker']
            c = bf_styles[bf]['color']
        else:
            m = default_marker
            c = default_color
        # Limit-cycle case: y holds several samples -> mark max and min
        # envelopes. Equilibrium case: y is a scalar. `np.shape` also
        # works for plain Python floats, which have no `.shape`.
        y_shape = np.shape(y)
        if y_shape and np.sum(y_shape) > 1:
            y_marks = [np.max(y), np.min(y)]
        else:
            y_marks = [y]
        for j, y_mark in enumerate(y_marks):
            if j == 0 and bf not in labels:
                line = ax.plot(x, y_mark, markersize=default_size, marker=m, c=c, label=bf)
                points.append(line[0])
                labels.append(bf)
            else:
                ax.plot(x, y_mark, markersize=default_size, marker=m, c=c)
    return points, labels
def plot_continuation_grid(self, plots: list, ncols: int = 2, figsize: tuple = None,
                           sharex: bool = False, sharey: bool = False,
                           **shared_kwargs) -> tuple:
    """Lay out multiple 1D/2D continuations as a grid of subplots.

    Convenience helper to compare continuations side-by-side (e.g. a
    codim-1 scan in eta next to its codim-2 fold curve in (eta, Delta),
    or several "same x/y, different parameter setting" diagrams).

    Parameters
    ----------
    plots
        List of dicts, one per subplot. Each dict must contain
        ``'x'``, ``'y'``, and ``'cont'`` (forwarded to
        :meth:`plot_continuation`). Any additional keys override
        ``shared_kwargs`` for that specific subplot, except ``'title'``
        (set via ``ax.set_title``) and ``'panel_label'`` (drawn in the
        upper-left corner, e.g. for (a), (b), (c) annotations).
    ncols
        Number of columns in the grid (>= 1). Rows are derived from
        ``ceil(len(plots) / ncols)``. Default 2.
    figsize
        Forwarded to :func:`matplotlib.pyplot.subplots`. Defaults to
        ``(5 * ncols, 4 * nrows)``.
    sharex, sharey
        Forwarded to :func:`matplotlib.pyplot.subplots`.
    shared_kwargs
        Keyword arguments applied to every subplot (e.g.
        ``bifurcation_legend=False``). Per-plot keys override these.

    Returns
    -------
    tuple
        ``(fig, axes, line_cols)``. ``axes`` is the flat list of the
        ``len(plots)`` used ``Axes`` (unused grid positions are deleted);
        ``line_cols`` holds the LineCollection returned by each
        ``plot_continuation`` call, in the order of ``plots``.

    Raises
    ------
    ValueError
        If ``plots`` is empty or ``ncols < 1``.
    KeyError
        If a spec lacks ``'x'``, ``'y'`` or ``'cont'``.
    """
    if not plots:
        raise ValueError("plot_continuation_grid requires at least one plot spec")
    if ncols < 1:
        raise ValueError(f"plot_continuation_grid requires ncols >= 1, got {ncols!r}")
    # Validate all specs before creating the figure, so a bad spec does not
    # leave a half-filled figure behind.
    for i, spec in enumerate(plots):
        missing = [k for k in ('x', 'y', 'cont') if k not in spec]
        if missing:
            raise KeyError(f"plot spec #{i} is missing required key(s): {missing}")

    n = len(plots)
    nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)
    if figsize is None:
        figsize = (5 * ncols, 4 * nrows)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize,
                             sharex=sharex, sharey=sharey, squeeze=False)
    axes_flat = list(axes.flatten())

    line_cols = []
    for ax, spec in zip(axes_flat, plots):
        spec = dict(spec)
        title = spec.pop('title', None)
        panel_label = spec.pop('panel_label', None)
        x = spec.pop('x')
        y = spec.pop('y')
        cont = spec.pop('cont')
        # per-plot spec overrides the shared defaults
        kwargs_tmp = dict(shared_kwargs)
        kwargs_tmp.update(spec)
        line_cols.append(self.plot_continuation(x=x, y=y, cont=cont, ax=ax, **kwargs_tmp))
        if title is not None:
            ax.set_title(title)
        if panel_label is not None:
            # bold axes-fraction label in the upper-left corner
            ax.text(0.02, 0.96, panel_label, transform=ax.transAxes,
                    ha='left', va='top', fontweight='bold', fontsize='large')

    # Hide any trailing subplot positions we didn't use.
    for unused in axes_flat[n:]:
        fig.delaxes(unused)
    fig.tight_layout()
    return fig, axes_flat[:n], line_cols
def update_bifurcation_style(self, bf_type: str, marker: str = None, color: str = None) -> None:
    """Set the marker and/or color used for a special solution type.

    For a type that already has a style, only truthy ``marker`` / ``color``
    values overwrite the stored ones (in place). For a new type, a style is
    created, with ``None`` replaced by marker ``'o'`` and color ``'k'``.

    Parameters
    ----------
    bf_type
        Type of the special solution (e.g. ``'LP'``).
    marker
        New marker type.
    color
        New color.

    Returns
    -------
    None
    """
    styles = self._bifurcation_styles
    if bf_type not in styles:
        styles[bf_type] = {
            'marker': 'o' if marker is None else marker,
            'color': 'k' if color is None else color,
        }
        return None
    entry = styles[bf_type]
    if marker:
        entry['marker'] = marker
    if color:
        entry['color'] = color
    return None
def _create_summary(self, solution: Union[Any, dict], points: list, variables: list, params: list,
                    timeseries: bool, stability: bool, period: bool, eigenvals: bool, lyapunov_exp: bool,
                    reduce_limit_cycle: bool
                    ) -> DataFrame:
    """Create the summary DataFrame of an auto-07p continuation.

    Builds a single dict-of-lists keyed by ``(column_name, sub_index)`` tuples
    (scalar columns use ``''`` as sub-index) and converts it into one
    DataFrame at the end.

    Parameters
    ----------
    solution
        auto-07p solution branch, passed to :meth:`get_solution`.
    points
        Point keys on the branch. Points resolving to ``'No Label'`` or
        ``'MX'`` are skipped.
    variables
        State-variable keys to record.
    params
        Parameter keys to record.
    timeseries
        Forwarded to ``get_solution_variables``; an extra returned entry
        beyond ``variables`` is stored as ``('time', '')``.
    stability
        Record a boolean ``('stability', '')`` column.
    period
        Record ``PAR(11)`` as ``('period', '')``.
    eigenvals
        Record ``('eigenvalues', i)`` columns.
    lyapunov_exp
        Record ``('lyapunov_exponents', i)`` columns.
    reduce_limit_cycle
        For multi-sample state variables store only the min (sub-index 0)
        and max (sub-index 1).

    Returns
    -------
    DataFrame
        One row per kept point, indexed by the point key, with ``MultiIndex``
        columns; vector-valued quantities use integer sub-indices, scalar ones
        ``''``. A column that is not produced for some row (e.g. ``time``
        missing for some points, or differing sample counts per point) holds
        ``NaN`` in that row.
    """
    # Keys are the (name, sub_index) tuples that go straight into the final
    # MultiIndex; values are per-row lists. dict insertion order determines
    # the final column order.
    col_values: dict = {}
    indices: list = []

    def _put(col: tuple, value: Any) -> None:
        # Append `value` to `col` for the row being built (the last entry of
        # `indices`). Earlier rows that lacked this column are back-filled
        # with NaN, so the value cannot slide up into a previous row.
        column = col_values.setdefault(col, [])
        row = len(indices) - 1
        if len(column) < row:
            column.extend([np.nan] * (row - len(column)))
        column.append(value)

    for point in points:
        s, solution_type, solution_idx = self.get_solution(cont=solution, point=point)
        if solution_type == 'No Label' or solution_type == 'MX':
            continue
        indices.append(point)
        var_vals = get_solution_variables(s, variables, timeseries)
        param_vals = get_solution_params(s, params)
        # Parse the diagnostic block at most once per point.
        diag = parse_point_diagnostics(s) if (stability or eigenvals or lyapunov_exp) else None
        period_val = (get_solution_params(s, ['PAR(11)'])[0]
                      if (period or lyapunov_exp or eigenvals) else None)

        # Column insertion order matches the legacy output so positional
        # access in existing user scripts keeps working.
        # --- state-variable values (vector for limit cycles) ---
        for var, val in zip(variables, var_vals):
            if len(val) > 1 and reduce_limit_cycle:
                _put((var, 0), np.min(val))
                _put((var, 1), np.max(val))
            else:
                for i, v in enumerate(val):
                    _put((var, i), v)
        # --- eigenvalues / Floquet multipliers ---
        if eigenvals:
            for i, v in enumerate(diag['eigenvalues']):
                _put(('eigenvalues', i), v)
        # --- lyapunov exponents ---
        if lyapunov_exp:
            for i, lyap in enumerate(get_lyapunov_exponents(diag['eigenvalues'], period_val)):
                _put(('lyapunov_exponents', i), lyap)
        # --- bifurcation type / index ---
        _put(('bifurcation', ''), solution_type)
        _put(('bifurcation_index', ''), solution_idx)
        # --- parameter values ---
        for param, val in zip(params, param_vals):
            _put((param, ''), val)
        # --- time vector (when get_timeseries=True) ---
        if len(var_vals) > len(variables) and timeseries:
            _put(('time', ''), var_vals[-1])
        if stability:
            _put(('stability', ''), bool(diag['stable']))
        if period:
            _put(('period', ''), period_val)

    # ---- single DataFrame construction ----
    if not col_values:
        return DataFrame(index=indices)
    # Pad columns that were absent in the trailing rows.
    n_rows = len(indices)
    for column in col_values.values():
        if len(column) < n_rows:
            column.extend([np.nan] * (n_rows - len(column)))
    # Translate "PAR(i)" / "U(i)" names back to user-facing names where one
    # is registered; done once over the discovered columns.
    remapped = [
        (self._var_map_inv[name] if name in self._var_map_inv else name, sub)
        for (name, sub) in col_values
    ]
    columns = MultiIndex.from_tuples(remapped)
    return DataFrame(dict(zip(columns, col_values.values())), index=indices)
def _call_auto(self, starting_point: Union[str, int], origin: Union[Any, dict], **auto_kwargs) -> Any:
    """Run auto-07p, optionally seeded from a labelled point on ``origin``,
    and pass the result through `_start_from_solution`.

    Raises
    ------
    KeyError
        If ``starting_point`` is given but not found on ``origin``.
    """
    run_args = ()
    if starting_point:
        seed, label, _ = self.get_solution(point=starting_point, cont=origin)
        if label == "No Label":
            raise KeyError(f"Starting point {starting_point} could not be found on the provided origin branch.")
        run_args = (seed,)
    return self._start_from_solution(self._auto.run(*run_args, **auto_kwargs))
def _update_axis_lims(self, ax: Union[plt.Axes, Axes3D], ax_data: list, padding: float = 0.,
                      force_update: bool = False) -> None:
    """Adjust the x/y(/z) limits of ``ax`` to the given data.

    Parameters
    ----------
    ax
        2D or 3D axes to update.
    ax_data
        Up to three array-likes, one per axis in x, y, z order.
    padding
        Forwarded to ``self._get_axis_lims``.
    force_update
        If true, each axis is set to the data limits. Otherwise the current
        limits are only widened so they also cover the data limits.

    Raises
    ------
    ValueError
        If more than three data arrays are passed.
    """
    axis_names = ('x', 'y', 'z')
    if len(ax_data) > len(axis_names):
        raise ValueError(f"_update_axis_lims supports at most 3 axes, got {len(ax_data)}")
    for axis_name, data in zip(axis_names, ax_data):
        limits = self._get_axis_lims(np.asarray(data), padding=padding)
        if force_update:
            lower, upper = limits
        else:
            # getattr instead of eval() on a formatted string
            current_lower, current_upper = getattr(ax, f"get_{axis_name}lim")()
            lower = np.min([current_lower, limits[0]])
            upper = np.max([current_upper, limits[1]])
        getattr(ax, f"set_{axis_name}lim")(lower, upper)
def _map_auto_kwargs(self, kwargs: dict) -> dict:
    """Translate user-facing variable names in auto-07p keyword arguments.

    ``ICP`` (a single name or a list/tuple of them) and the keys of the
    PAR-keyed mappings ``UZR``, ``UZSTOP``, ``THL`` and ``THU`` are mapped
    through :meth:`_map_var`. Non-string entries are kept unchanged, and so
    are values of those four constants that are not mappings (previously
    they raised ``AttributeError`` on ``.items()``).

    Parameters
    ----------
    kwargs
        auto-07p keyword arguments; modified in place.

    Returns
    -------
    dict
        The same ``kwargs`` object with translated entries. Entries that were
        translated are re-inserted and therefore move to the end.
    """
    # handle the continuation parameter
    if "ICP" in kwargs:
        icp = kwargs.pop("ICP")
        if isinstance(icp, str):
            icp = self._map_var(icp)
        elif isinstance(icp, (list, tuple)):
            icp = [self._map_var(v) if isinstance(v, str) else v for v in icp]
        kwargs["ICP"] = icp
    # PAR-keyed constants (name -> integer index). On the PyRates-generated
    # path auto-07p resolves names itself via parnames, so unmapped strings
    # pass through harmlessly; on the hand-written path this translation is
    # what makes string keys work.
    for key in ("UZR", "UZSTOP", "THL", "THU"):
        if key not in kwargs:
            continue
        spec = kwargs.pop(key)
        if hasattr(spec, 'items'):
            spec = {self._map_var(k) if isinstance(k, str) else k: v for k, v in spec.items()}
        kwargs[key] = spec
    return kwargs
def _map_var(self, var: str, mode: str = "cont"):
"""Translate a user-facing var name to auto-07p's internal form.
With ``mode="cont"`` returns the integer index (PAR slot for
parameters, U slot for state variables) for use in ICP / UZR / UZSTOP
/ THL / THU. With ``mode="plot"`` returns the ``"PAR(i)"`` or
``"U(i)"`` string used as a DataFrame column key. Unknown names pass
through unchanged so non-PyCoBi-managed keys (raw ``PAR(i)`` strings,
bare ints, etc.) still work.
"""
entry = self._var_map.get(var)
if entry is None:
return var
kind, idx = entry
if mode == "cont":
return idx
# "plot" mode (the only other mode currently used)
return f"PAR({idx})" if kind == "P" else f"U({idx})"
@staticmethod
def _get_all_var_keys(solution):
# Prefer the solution's own coord names — when PyRates emits `unames`
# into the auto-07p c.* file (post-pyrates>=1.1), auto exposes state
# variables under their user-facing name rather than as ``U(i)``.
# Fall back to the historical ``U(i)`` form for hand-written fortran
# systems or older PyRates without unames.
coords = getattr(solution, 'coordnames', None)
if coords:
return list(coords)
return [f'U({i+1})' for i in range(solution['NDIM'])]
@staticmethod
def _get_all_param_keys(solution):
return solution.PAR.coordnames
def _start_from_solution(self, solution: Any) -> Any:
    """Re-run auto-07p when its first call only produced a starting direction.

    The retry fires only when all of the following hold:

    - the diagnostics of ``solution[0]`` contain
      ``'Starting direction of the free parameter(s)'``;
    - ``get_solution_keys`` reports exactly one solution key;
    - the first label name stored under that key contains ``"EP"``.

    In that case a ``UserWarning`` is emitted. The last entry of
    ``solution[0].labels.by_index`` is then popped, which mutates the passed
    solution. Finally, ``auto.run`` is called again on that entry's
    ``['EP']['solution']`` *without* any of the original keyword arguments
    (NMX, DS, DSMAX, ...). Its result is returned. In every other case
    ``solution`` is returned untouched.
    """
    diagnostics = str(solution[0].diagnostics)
    keys = get_solution_keys(solution)
    # Guard clauses keep the original short-circuit order: the label lookup
    # below is only reached once exactly one key is known to exist.
    if 'Starting direction of the free parameter(s)' not in diagnostics or len(keys) != 1:
        return solution
    first_label = list(solution[0].labels.by_index[keys[0]])[0]
    if "EP" not in first_label:
        return solution
    warnings.warn(
        "auto-07p's first run took no continuation steps (only a starting-direction "
        "diagnostic was produced); restarting auto.run from the single EP label without "
        "the original run's keyword arguments. If this is unexpected, check that your "
        "auto constants (DS, DSMAX, ICP, ...) are appropriate for the model.",
        UserWarning,
        stacklevel=3,
    )
    # auto's API quirk: popping the labels dict is how the seed solution is extracted.
    _, seed = solution[0].labels.by_index.popitem()
    return self._auto.run(seed['EP']['solution'])
@staticmethod
def _get_line_collection(x, y, stability=None, line_style_stable='solid', line_style_unstable='dotted',
                         line_color_stable='k', line_color_unstable='gray', **kwargs) -> LineCollection:
    """Build a 2D ``LineCollection`` of ``y`` over ``x``, optionally split by stability.

    Parameters
    ----------
    x
        Values for the horizontal axis, flattened into an ``(N, 1)`` column.
    y
        Values for the vertical axis, one entry per ``x`` value. A pandas
        Series is unwrapped to its positional ``.values`` first. If the first
        entry is itself an array (``sum(shape) > 1``), the per-entry maximum
        and minimum are drawn as two separate lines. Presumably these are the
        extrema of periodic solutions; confirm against callers.
    stability
        Optional per-point flags. When the array has more than one element and
        can be cast to ``int``, the line is cut wherever the value changes.
        A block uses the ``*_stable`` style/color if the flag at its start is
        truthy, otherwise the ``*_unstable`` ones. If ``stability`` is missing,
        has at most one element, or cannot be cast (e.g. contains ``None``),
        everything is drawn with the stable style/color.
    line_style_stable, line_style_unstable
        Matplotlib line styles for stable / unstable blocks.
    line_color_stable, line_color_unstable
        Matplotlib colors for stable / unstable blocks.
    kwargs
        Forwarded to ``LineCollection``. If present, ``colors`` and
        ``linestyles`` replace the computed per-block values.

    Returns
    -------
    LineCollection
        One segment per stability block, or two per block when max/min lines
        are drawn.
    """
    # Combine x and y into segment-friendly (N, 2) point arrays. Both inputs
    # are flattened with `reshape(-1, 1)`, which also copes with degenerate
    # length-1 inputs (a squeeze-based form would produce a 0-D array there).
    # The pre-1.0 code had this guard on `x` but not on `y`.
    #
    # Strip pandas Series wrappers so `y[0]` / `y[i]` below are
    # positional regardless of the Series' index. After a `cutoff`
    # slice in `plot_trajectory` the Series is reindexed and label-
    # based access via `y[0]` would raise `KeyError: 0`.
    if hasattr(y, 'values'):
        y = y.values
    x = np.asarray(x).reshape(-1, 1)
    if hasattr(y[0], "shape") and sum(y[0].shape) > 1:
        # Array-valued entries: stack them into a 2D array and draw the
        # row-wise maximum (`y`) and minimum (`y_min`) as separate lines.
        y = np.asarray([y[i] for i in range(y.shape[0])])
        y_max = np.reshape(y.max(axis=1), (y.shape[0], 1))
        y_min = np.reshape(y.min(axis=1), (y.shape[0], 1))
        y_min = np.append(x, y_min, axis=1)
        y = y_max
        add_min = True
    else:
        y = np.asarray(y).reshape(-1, 1)
        add_min = False
    # y becomes the (N, 2) array of (x, y) points.
    y = np.append(x, y, axis=1)
    # if stability was passed, collect indices for stable line segments
    ###################################################################
    # The size>1 guard skips the segmentation logic for degenerate
    # 1-element stability arrays (which would yield a single empty
    # segment via the diff). The pre-1.0 form used
    # `np.sum(stability.shape) > 1`, which is equivalent for 1-D arrays
    # but cryptic, so it was switched to the explicit `.size > 1`.
    # The `dtype='int'` coerce can raise TypeError if the stability
    # array contains None values (e.g. an IVP continuation that
    # didn't record stability per time-step); fall through to the
    # no-stability branch in that case.
    stab_arr = None
    if stability is not None and np.asarray(stability).size > 1:
        try:
            stab_arr = np.asarray(stability, dtype='int')
        except TypeError:
            stab_arr = None
    if stab_arr is not None:
        # collect indices: positions where the flag differs from the previous
        # point, plus the array length as the end of the last block
        stability = stab_arr
        stability_changes = np.concatenate([np.zeros((1,)), np.diff(stability)])
        idx_changes = np.sort(np.argwhere(stability_changes != 0))
        idx_changes = np.append(idx_changes, len(stability_changes))
        # create line segments
        lines, styles, colors = [], [], []
        # Each block spans [idx_old - 1, idx): it starts one point before the
        # change so consecutive blocks share their boundary point and the
        # drawn line stays connected. The style/color comes from the flag
        # at `idx_old`, i.e. the first point after the change.
        idx_old = 1
        for idx in idx_changes:
            lines.append(y[idx_old-1:idx, :])
            styles.append(line_style_stable if stability[idx_old] else line_style_unstable)
            colors.append(line_color_stable if stability[idx_old] else line_color_unstable)
            if add_min:
                lines.append(y_min[idx_old - 1:idx, :])
                styles.append(line_style_stable if stability[idx_old] else line_style_unstable)
                colors.append(line_color_stable if stability[idx_old] else line_color_unstable)
            idx_old = idx
    else:
        # no usable stability info: one stable-styled line (two with max/min)
        lines = [y, y_min] if add_min else [y]
        styles = [line_style_stable, line_style_stable] if add_min else [line_style_stable]
        colors = [line_color_stable, line_color_stable] if add_min else [line_color_stable]
    # Pop the two LineCollection kwargs we set explicitly so a user
    # passing `colors=` or `linestyles=` through one of the plot_* helpers
    # overrides the computed values rather than colliding with them
    # (passing both via `**kwargs` raises TypeError).
    colors = kwargs.pop('colors', colors)
    styles = kwargs.pop('linestyles', styles)
    return LineCollection(segments=lines, linestyles=styles, colors=colors, **kwargs)
@staticmethod
def _get_3d_line_collection(x, y, z, stability=None, line_style_stable='solid', line_style_unstable='dotted',
                            **kwargs) -> Line3DCollection:
    """Build a ``Line3DCollection`` through the points ``(x, y, z)``, optionally split by stability.

    Parameters
    ----------
    x, y, z
        Point coordinates, one value per point. Each is cast to ``float``
        and flattened.
    stability
        Optional per-point flags. When the array has more than one element and
        can be cast to ``int``, the line is cut wherever the value changes.
        A block uses ``line_style_stable`` if the flag at its start is truthy,
        otherwise ``line_style_unstable``. If ``stability`` is missing, has at
        most one element, or cannot be cast (e.g. contains ``None``), a single
        block with ``line_style_stable`` is drawn.
    line_style_stable, line_style_unstable
        Matplotlib line styles for stable / unstable blocks.
    kwargs
        Forwarded to ``Line3DCollection``, with two exceptions:

        ``array``
            Value handed to ``set_array`` (drives the colormap). The strings
            ``'x'``, ``'y'`` and ``'z'`` (default ``'x'``) select the
            corresponding coordinate. Any other value, including a NumPy
            array or pandas Series, is passed through unchanged.
        ``linestyles``
            Replaces the computed per-block styles.

    Returns
    -------
    Line3DCollection
    """
    # Coerce to float upfront. Pandas Series that came out of an LC
    # summary row (rows mix scalar and ndarray columns) have ``object``
    # dtype, which propagates through `np.reshape` and then trips
    # `set_array` later when matplotlib tries to apply the colormap.
    x = np.asarray(x, dtype=float).reshape(-1, 1)
    y = np.asarray(y, dtype=float).reshape(-1, 1)
    z = np.asarray(z, dtype=float).reshape(-1, 1)
    # (N, 3) array of points
    points = np.append(x, y, axis=1)
    points = np.append(points, z, axis=1)
    # if stability was passed, collect indices for stable line segments
    ###################################################################
    # Same coerce-tolerant pattern as `_get_line_collection`: skip
    # stability segmentation if the array contains None values.
    stab_arr = None
    if stability is not None and np.asarray(stability).size > 1:
        try:
            stab_arr = np.asarray(stability, dtype='int')
        except TypeError:
            stab_arr = None
    if stab_arr is not None:
        # change positions, plus the array length as the end of the last block
        stability_changes = np.concatenate([np.zeros((1,)), np.diff(stab_arr)])
        idx_changes = np.sort(np.argwhere(stability_changes != 0))
        idx_changes = np.append(idx_changes, len(stability_changes))
        # Each block spans [idx_old - 1, idx), so neighbouring blocks share
        # their boundary point. The style comes from the first point after
        # the change.
        lines, styles = [], []
        idx_old = 1
        for idx in idx_changes:
            lines.append(points[idx_old - 1:idx, :])
            styles.append(line_style_stable if stab_arr[idx_old] else line_style_unstable)
            idx_old = idx
    else:
        lines = [points]
        styles = [line_style_stable]
    # create line collection
    array = kwargs.pop('array', 'x')
    # Same as `_get_line_collection`: pop `linestyles` so a user-supplied
    # value overrides the per-stability-block styles rather than colliding
    # with the explicit kwarg below.
    styles = kwargs.pop('linestyles', styles)
    # NOTE: Line3DCollection takes `lines` as a positional argument (not
    # `segments=` like the 2D LineCollection). Passing `segments=lines`
    # used to silently fail on older matplotlib and now raises outright.
    line_col = Line3DCollection(lines, linestyles=styles, **kwargs)
    # Only strings select a coordinate. Comparing an array/Series with `==`
    # yields an element-wise result whose truth value is ambiguous, so the
    # former `if array == 'x'` chain raised ValueError for explicit arrays.
    if isinstance(array, str):
        named_columns = {'x': x, 'y': points[:, 1], 'z': z}
        if array in named_columns:
            array = named_columns[array].squeeze()
    line_col.set_array(array)
    return line_col
@staticmethod
def _get_axis_lims(x: np.array, padding: float = 0.) -> tuple:
x_min, x_max = x.min(), x.max()
x_pad = (x_max - x_min) * padding
return x_min - x_pad, x_max + x_pad
def _extract_merge_point(p: int, df: DataFrame) -> Series:
p_tmp = df.loc[p, :]
if len(p_tmp.shape) > 1 and p_tmp.shape[0] > 1:
return p_tmp.iloc[0, :]
return p_tmp