# Source code for BondGraphTools.view

"""Tools for visualising bond graphs.

This module contains temporary tools for producing visualizations of
bond graph network topologies.

"""

import logging

import numpy as np

from scipy.sparse import dok_matrix
from matplotlib.lines import Line2D
from matplotlib.pyplot import rcParams
import networkx as nx

from .exceptions import InvalidComponentException

logger = logging.getLogger(__name__)

# Font sizes used for component labels (FONT) and port labels (FONT_SM).
FONT = 14
FONT_SM = 10

__all__ = ["draw"]

# Whether labels are written as LaTeX. Matplotlib stores this setting under
# the "text.usetex" key; there is no bare "usetex" key, so looking that up
# always yielded None and disabled the LaTeX label formatting. The value is
# read once, at import time.
usetex = rcParams.get("text.usetex")


def draw(system):
    """Draw the network topology of ``system`` on a new pyplot figure.

    A 12x9 inch figure at 80 dpi is created and pyplot interactive mode is
    switched off. The figure's axes get an equal aspect ratio and are titled
    with ``system.name``, and the components and bonds are then laid out on
    them by ``_draw``.

    Args:
        system: The bond graph to visualise.

    Returns:
        The value returned by ``_draw`` (``None`` in this module). The
        drawing itself lives on the newly created figure.
    """
    import matplotlib.pyplot as plt

    figure_options = {"figsize": (12, 9), "dpi": 80}
    fig = plt.figure(**figure_options)
    plt.ioff()

    axes = fig.gca()
    axes.set_aspect("equal")
    axes.set_title(f"{system.name}")
    return _draw(system, axes)
def _build_graph(system):
    """Build the undirected adjacency matrix of ``system``.

    Row/column ``i`` corresponds to the ``i``-th entry of
    ``system.components``. Every bond sets a symmetric pair of ``1`` entries
    between the two components it joins.

    Args:
        system: Bond graph exposing ``components`` and ``bonds``.

    Returns:
        :obj:`scipy.sparse.csr_matrix` adjacency matrix.

    Raises:
        InvalidComponentException: if ``system`` lacks one of the required
            attributes.
    """
    try:
        index = {comp: i for i, comp in enumerate(system.components)}
        size = len(index)
        graph = dok_matrix((size, size), dtype=int)
        # Each bond end is unpacked as a pair whose first item is the
        # component; only the component matters for the topology.
        for (c1, _), (c2, _) in system.bonds:
            graph[index[c1], index[c2]] = 1
            graph[index[c2], index[c1]] = 1
    except AttributeError as ex:
        raise InvalidComponentException(
            "Invalid System: has no components"
        ) from ex
    return graph.tocsr(copy=False)


def _networkx_layout(graph):
    """Place the nodes of ``graph`` using the Kamada-Kawai algorithm.

    Args:
        graph: Square sparse adjacency matrix, as built by ``_build_graph``.

    Returns:
        list of ``(x, y)`` tuples; entry ``i`` is the position of node ``i``.
    """
    nx_graph = nx.Graph(graph)
    layout = nx.kamada_kawai_layout(nx_graph, scale=20)
    # Index explicitly by node id instead of relying on dict ordering, so the
    # i-th position always belongs to the i-th component.
    return [(layout[i][0], layout[i][1]) for i in range(len(nx_graph))]


class PortGlyph:
    """Text label marking where a bond attaches to a component glyph.

    The label is drawn as a matplotlib ``Annotation`` anchored at ``pos``.
    The glyph's own position is shifted half a glyph height/width towards
    ``dir`` ('top', 'bottom', 'right', anything else meaning 'left').
    """

    def __init__(self, ax, string, pos, dir, text_dict):
        from matplotlib.text import Annotation
        self.width = 0.1
        self.height = 0.1
        self.text = Annotation(string, pos, **text_dict)
        ax.add_artist(self.text)
        self.x, self.y = pos
        if dir == 'top':
            self.y += self.height / 2
        elif dir == 'bottom':
            self.y -= self.height / 2
        elif dir == 'right':
            self.x += self.width / 2
        else:
            self.x -= self.width / 2

    @property
    def pos(self):
        """The ``(x, y)`` point bonds are drawn from/to."""
        return self.x, self.y


class Glyph:
    """Text glyph representing one component in the drawing.

    ``x``/``y`` hold the centre of the label and ``string`` holds the text to
    display. ``ports`` collects, per side, the :class:`PortGlyph` objects
    attached via :meth:`add_port`.
    """

    def __init__(self, node):
        self._node = node
        self._axes = None
        self.x = 0
        self.y = 0
        self.string = None
        self.width = 0.1
        self.height = 0.1
        self._text = None
        self.ports = {'top': [], 'right': [], 'bottom': [], 'left': []}

    @property
    def pos(self):
        """The ``(x, y)`` centre of the glyph."""
        return self.x, self.y

    @pos.setter
    def pos(self, value):
        self.x, self.y = value

    @property
    def axes(self):
        """The matplotlib axes the glyph is drawn on (``None`` until set)."""
        return self._axes

    @axes.setter
    def axes(self, ax):
        # Setting the axes draws the label immediately, so ``string`` and
        # ``pos`` must already be assigned.
        self._axes = ax
        from matplotlib.text import Text
        string = f"${self.string}$" if usetex else self.string
        self._text = Text(
            self.x, self.y, string,
            horizontalalignment='center',
            verticalalignment='center',
            size=FONT,
            usetex=usetex,
        )
        ax.add_artist(self._text)

    def add_port(self, string, dir):
        """Attach a port label on the side of the glyph facing ``dir``.

        Args:
            string: Label text for the port.
            dir: ``(dx, dy)`` vector pointing from this glyph towards the
                glyph at the other end of the bond. The dominant component
                selects the side; ties between ``|dx|`` and ``|dy|`` go
                right or left.

        Returns:
            The new :class:`PortGlyph`, which is also appended to
            ``self.ports[side]``.
        """
        dx, dy = dir
        if dy > abs(dx):
            side, anchor = 'top', (self.x, self.y + self.height / 2)
            halign, valign = 'center', 'bottom'
        elif -dy > abs(dx):
            side, anchor = 'bottom', (self.x, self.y - self.height / 2)
            halign, valign = 'center', 'top'
        elif dx >= abs(dy):
            side, anchor = 'right', (self.x + self.width / 2, self.y)
            halign, valign = 'left', 'center'
        else:
            side, anchor = 'left', (self.x - self.width / 2, self.y)
            halign, valign = 'right', 'center'
        text_dict = {
            'size': FONT_SM,
            'xytext': anchor,
            'horizontalalignment': halign,
            'verticalalignment': valign,
        }
        port = PortGlyph(self.axes, string, self.pos, side, text_dict)
        # A side can host several ports; keep them all instead of
        # overwriting the list initialised in __init__.
        self.ports[side].append(port)
        return port


class BondView(Line2D):
    """Matplotlib line drawing one bond between two :class:`PortGlyph`.

    The line runs from ``port_1`` to ``port_2`` and ends with a single barb
    at ``port_2``, rotated by ``th`` from the bond direction. Call
    :meth:`calc_lines` once the ports are positioned.
    """

    th = 3 * np.pi / 4
    # Rotation by ``th``; applied to the unit bond direction to get the barb.
    R = np.array(((np.cos(th), -np.sin(th)), (np.sin(th), np.cos(th))))
    # NOTE(review): calc_lines assigns this on the instance, so each bond only
    # tracks its own length and this class value stays None. Confirm whether
    # a minimum shared across all bonds was intended.
    shortest_bond = None

    def __init__(self, port_1, port_2, *args, **kwargs):
        self.port_1 = port_1
        self.port_2 = port_2
        super().__init__([], [], *args, **kwargs)

    def calc_lines(self):
        """Compute the line vertices from the current port positions.

        Each end is pulled in along the bond by a fraction (the port's
        larger dimension) of the bond vector, so the line does not touch the
        labels. If the shortened bond has zero or non-finite length, only
        the two end points are set and no barb is drawn.
        """
        x1, y1 = self.port_1.pos
        x2, y2 = self.port_2.pos
        r1 = max(self.port_1.height, self.port_1.width)
        r2 = max(self.port_2.height, self.port_2.width)
        dx = x2 - x1
        dy = y2 - y1
        x1 += r1 * dx
        y1 += r1 * dy
        x2 -= r2 * dx
        y2 -= r2 * dy
        lx, ly = x2 - x1, y2 - y1
        length = np.hypot(lx, ly)
        if not np.isfinite(length) or length == 0:
            # Coincident ports: there is no direction for the barb, and
            # normalising the bond vector would divide by zero.
            self.set_xdata([x1, x2])
            self.set_ydata([y1, y2])
            return
        if not self.shortest_bond or self.shortest_bond > length:
            self.shortest_bond = length
        headlength = self.shortest_bond / 5
        direction = np.array((lx, ly)) / length
        x3, y3 = headlength * self.R.dot(direction) + (x2, y2)
        self.set_xdata([x1, x2, x3])
        self.set_ydata([y1, y2, y3])


def _component_label(component):
    """Return the label text for ``component`` (without ``$`` delimiters)."""
    metamodel = component.metamodel
    if metamodel in {'0', '1'}:
        # Components with metamodel '0' or '1' (bond graph junctions) are
        # labelled by their type alone.
        if usetex:
            return r"\mathbf{{{t}}}".format(t=metamodel)
        return "{t}".format(t=metamodel)
    if usetex:
        return r"\mathbf{{{t}}}: {n}".format(t=metamodel, n=component.name)
    return "{t}: {n}".format(t=metamodel, n=component.name)


def _port_label(port):
    """Return ``"[name]"`` for a bond end with a ``name``, else ``""``."""
    try:
        return f"[{port.name}]"
    except AttributeError:
        return ""


def _fit_axes(ax, positions, tweak=0.1):
    """Set the limits of ``ax`` to the padded bounding box of ``positions``.

    The box is padded by ``tweak`` times its width/height on each side. A
    zero-size extent (no or a single position, or collinear positions) is
    padded as if it had size 1, so the limits never collapse to one value.
    """
    if positions:
        xs, ys = zip(*positions)
        x_min, x_max = min(xs), max(xs)
        y_min, y_max = min(ys), max(ys)
    else:
        x_min = x_max = y_min = y_max = 0
    width = (x_max - x_min) or 1
    height = (y_max - y_min) or 1
    ax.axis([x_min - tweak * width, x_max + tweak * width,
             y_min - tweak * height, y_max + tweak * height])


def _draw(system, ax, layout=_networkx_layout):
    """Render the components and bonds of ``system`` onto ``ax``.

    Args:
        system: Bond graph exposing ``components`` and ``bonds``.
        ax: Matplotlib axes to draw on. Its x/y axes are hidden and its
            limits are fitted to the laid-out components.
        layout: Callable mapping the adjacency matrix from ``_build_graph``
            to one ``(x, y)`` position per component, in
            ``system.components`` order.

    Returns:
        None; the artists are added to ``ax`` in place.
    """
    graph = _build_graph(system)
    points = layout(graph)

    ax.get_yaxis().set_visible(False)
    ax.get_xaxis().set_visible(False)

    views = {}
    positions = []
    for component, (x, y) in zip(system.components, points):
        positions.append((x, y))
        view = Glyph(component)
        view.pos = (x, y)
        view.string = _component_label(component)
        view.axes = ax
        views[component] = view

    bonds = []
    for tail, head in system.bonds:
        tail_glyph = views[tail.component]
        head_glyph = views[head.component]
        dx = head_glyph.x - tail_glyph.x
        dy = head_glyph.y - tail_glyph.y
        p1 = tail_glyph.add_port(_port_label(tail), (dx, dy))
        p2 = head_glyph.add_port(_port_label(head), (-dx, -dy))
        bond = BondView(p1, p2)
        ax.add_artist(bond)
        bonds.append(bond)

    # Lines are computed only after every port exists, so each bond sees
    # its final port positions.
    for bond in bonds:
        bond.calc_lines()

    _fit_axes(ax, positions)


def find_renderer(fig):
    """Return a renderer for ``fig`` (e.g. to measure text extents)."""
    if hasattr(fig.canvas, "get_renderer"):
        # Some backends, such as TkAgg, have the get_renderer method, which
        # makes this easy.
        return fig.canvas.get_renderer()
    if hasattr(fig, "_get_renderer"):
        # Matplotlib >= 3.6 ships a (private) helper doing the print trick
        # below; Figure._cachedRenderer no longer exists there.
        return fig._get_renderer()
    # Older Matplotlib: print the figure to an in-memory file object and grab
    # the renderer that was used (trick from backend_bases.print_figure()).
    import io
    fig.canvas.print_pdf(io.BytesIO())
    return fig._cachedRenderer