Source code for BondGraphTools.fileio

"""The file save/load interface and file format data model.

This module provides the basic IO functionality such as saving and loading
to file.
"""
# todo:
#    As the file format matures, we probably want to move to a object
#    oriented loader that tracks the (as yet undefined) schema.


import logging
import pathlib
import yaml

from .compound import BondGraph
from .actions import connect, new, expose
from .exceptions import InvalidComponentException

logger = logging.getLogger(__name__)

FILE_VERSION = "0.1"


[docs]def save(model, filename): """Save the model to file. Args: model: The model to be saved filename: The file to save to """ model_directory = _build_model_directory(model) models = {} templates = {} for uri, sub_model in model_directory.items(): models.update({ uri: _build_model_data(sub_model, templates) }) data = { "version": FILE_VERSION, "root": model.name, "models": models } with open(filename, 'w') as filestream: yaml.dump(data, filestream, default_flow_style=False)
def _build_model_directory(model): try: if not model.parent: uri = "/" else: _, uri = model.uri.split(":") directory = {uri: model} for c in model.components: directory.update(_build_model_directory(c)) return directory except AttributeError: return {} def _build_model_data(model, templates): components = [] out = {} for c in model.components: if isinstance(c, BondGraph): _, uri = c.uri.split(":") components.append( f"{c.name} {uri}" ) else: components.append( _build_component_string(c) ) out.update({"components": components}) netlist = [] for tail, head in model.bonds: netlist.append(f"{tail} {head}") if netlist: out.update({"netlist": netlist}) ports = [] for port in model.ports: (c1, e), (c2, f) = model._port_map[port] # noqa if c1 == c2: ports.append(f"{c1.name} {port.name}") else: raise NotImplementedError if ports: out.update({"ports": ports}) return out def _build_component_string(component): out_str = f"{component.name} {component.template}" logger.debug("Trying to serialise: %s", out_str) try: for param, value in component.params.items(): logger.debug("Param: %s, %s", param, value) if isinstance(value, (int, float)): out_str += f" {param}={value}" elif isinstance(value, dict): try: v = value["value"] if isinstance(v, (float, int)): out_str += f" {param}={v}" except KeyError as ex: logger.debug("Skipping: %s ", str(ex)) pass except AttributeError: pass logger.debug("Saving component string: %s", out_str) return out_str
[docs]def load(file_name, model=None, as_name=None): """Load a model from file. Args: file_name (str or Path): The file to load. Returns: An instance of `BondGraph` Raises: `NotImplementedError` for incorrect file version. """ if isinstance(file_name, pathlib.Path): file_name = str(file_name) with open(file_name, 'r') as f: data = yaml.load(f, Loader=yaml.SafeLoader) version = str(data['version']) if version == FILE_VERSION: return _builder(data, model, as_name) else: raise NotImplementedError
def _builder(data, model=None, as_name=None): if not model: root = "/" else: root = model models = data['models'] def _build(model_name, template_name): model_data = models[template_name] netlist = model_data["netlist"] logger.debug("%s: Trying to build", model_name) model = new(name=model_name) for comp_string in model_data["components"]: logger.debug("%s: building", comp_string) try: comp = _base_component_from_str(comp_string) except (ValueError, KeyError): name, sub_model = comp_string.split(" ") comp = _build(name, sub_model) model.add(comp) logger.debug("%s components complete", model_name) _wire(model, netlist) try: io_ports = model_data["ports"] _expose(model, io_ports) except KeyError: logger.debug("No ports on model ") return model def _wire(model, netlist): logger.debug("%s: trying to wire", model.name) def get_port(port_string): tokens = iter(port_string.split('.')) c = next(tokens) try: comp, = (comp for comp in model.components if comp.name == c) except ValueError: raise InvalidComponentException( f"Could not find component {c} in model {model.name}") try: t2 = next(tokens) except StopIteration: return comp try: _ = next(tokens) except StopIteration: logger.debug("Tyring to get port %s, %s", str(comp), str(t2)) try: t2 = int(t2) except ValueError: pass port = comp.get_port(t2) logger.debug("Got %s", str(port)) return port else: raise NotImplementedError for bond_string in netlist: logger.debug("%s: bond %s", model.name, bond_string) tail_str, head_str = bond_string.split() tail = get_port(tail_str) head = get_port(head_str) connect(tail, head) def _expose(model, _io_ports): for port_string in _io_ports: component_name, port_label = port_string.split(" ") comp, = { c for c in model.components if c.name == component_name } expose(comp, port_label) def _parse_build_args(in_args): if not in_args: return [], {} arg = in_args[0] args, kwargs = _parse_build_args(in_args[1:]) try: k, v = arg.split("=") except ValueError: k = None v = arg if v.isnumeric(): try: v = int(v) except ValueError: v = float(v) if not k: args.append(v) else: kwargs.update({k: v}) return args, kwargs def _base_component_from_str(string): label, tempate, *build_args = string.split() args, kwargs = _parse_build_args(build_args) library, component = tempate.split("/") comp = new(name=label, library=library, component=component, value=args) for k, v in kwargs.items(): comp.set_param(k, v) return comp out = _build(root, root) if as_name: out.name = as_name elif not model: out.name = data["root"] return out