import numpy as np
import binascii
import importlib
import inspect
import os
import re
import numpy as np
import textwrap
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Union
# Reward constants.
MATCH_REWARD: int = 100
COVERAGE_REWARD: int = 10
OBLIGATION_REWARD_FACTOR: int = 10

# Type aliases.
# Observations and actions are assumed to be already one-hot encoded
# (either an index or an array).
observation = action = Union[int, np.ndarray]
label = pattern_trajectory = str
trajectory = List[Tuple[observation, action]]
labeled_trajectory = Union[List[Tuple[label, label]], List[label], str]
trajectory_pattern_str = trajectory_str = str
@dataclass
class cardinality:
    """Pair of bounds rendered as ``(lower,upper)`` by both ``str`` and ``repr``.

    Each bound is an ``int`` or a ``str`` (a string presumably stands for a
    symbolic bound -- confirm with the callers).
    """

    lower_bound: Union[int, str]
    upper_bound: Union[int, str]

    def __str__(self) -> str:
        return f'({self.lower_bound},{self.upper_bound})'

    # repr renders the same text as str, replacing the dataclass-generated repr.
    __repr__ = __str__

    @staticmethod
    def from_dict(d: Dict[str, Any]) -> 'cardinality':
        """Build an instance from a mapping holding ``'lower_bound'`` and
        ``'upper_bound'``; other keys are ignored and a missing key raises
        ``KeyError``."""
        lower, upper = d['lower_bound'], d['upper_bound']
        return cardinality(lower_bound=lower, upper_bound=upper)
def replace_all_character(string, old_character, new_character):
    """Replace ``old_character`` by ``new_character`` until nothing changes.

    A single ``str.replace`` pass is not enough to collapse runs (e.g. turning
    ``"    "`` into ``" "`` with ``old="  "``, ``new=" "`` needs several
    passes), so the replacement is repeated until the string is stable.

    Args:
        string: text to transform.
        old_character: substring to replace (may be longer than one character).
        new_character: replacement substring.

    Returns:
        The transformed string.

    Raises:
        ValueError: if ``old_character`` occurs in ``string`` and is a proper
            substring of ``new_character`` (this includes an empty
            ``old_character`` with a non-empty ``new_character``). Every pass
            then re-creates ``old_character`` and grows the string, so no fixed
            point exists and the loop would never end.
    """
    if (old_character != new_character
            and old_character in new_character
            and old_character in string):
        raise ValueError(
            f"replacing {old_character!r} with {new_character!r} never terminates")
    while True:
        replaced = string.replace(old_character, new_character)
        if replaced == string:
            return string
        string = replaced
def generate_random_hash(num_bytes=16):
    """Return a random lowercase hexadecimal string.

    The string encodes ``num_bytes`` bytes drawn from ``os.urandom``, so it is
    ``2 * num_bytes`` characters long.
    """
    return os.urandom(num_bytes).hex()
def _lambda_body_end(text: str) -> int:
    """Return the index in ``text`` where a lambda body starting at 0 ends.

    The body ends at the first top-level ``,`` or at an unmatched closing
    bracket (i.e. the end of the enclosing call/literal); brackets and commas
    inside string literals are ignored. Returns ``len(text)`` otherwise.
    """
    depth = 0
    quote = None
    i = 0
    while i < len(text):
        char = text[i]
        if quote is not None:
            if char == "\\":
                i += 2  # skip the escaped character
                continue
            if char == quote:
                quote = None
        elif char in "'\"":
            quote = char
        elif char in "([{":
            depth += 1
        elif char in ")]}":
            if depth == 0:
                return i
            depth -= 1
        elif char == "," and depth == 0:
            return i
        i += 1
    return len(text)


def handle_lambda(source: str) -> str:
    """Rewrite the source of a lambda as an equivalent ``def`` statement.

    When ``source`` (after dedenting) contains the ``lambda`` keyword and no
    ``def`` keyword, the first lambda is turned into::

        def <name>(<args>):
         return <body>

    ``<name>`` is the assignment target when the lambda is directly assigned
    to a plain name (``f = lambda ...`` or ``f: T = lambda ...``); otherwise a
    random ``lamba_<hex>`` name is generated. Any other source, or a lambda
    whose parameters are not plain names (defaults, ``*args``...), is
    returned dedented but otherwise unchanged.

    The extraction is text based: the statement is flattened onto one line
    (whitespace runs collapsed to a single space), and the body runs up to the
    first top-level comma or unmatched closing bracket.

    Args:
        source: Python source of a function, e.g. from ``inspect.getsource``.

    Returns:
        The converted (or unchanged, dedented) source.
    """
    source = textwrap.dedent(source)
    # Word-boundary checks so that names such as ``default_fn`` or
    # ``lambda_value`` neither block nor trigger the conversion.
    if not re.search(r"\blambda\b", source) or re.search(r"\bdef\b", source):
        return source

    # Join a statement spread over several lines into a single line so that
    # multi-line lambda bodies are captured whole.
    flat_source = re.sub(r"\s+", " ", source).strip()

    keyword = re.search(r"\blambda\b", flat_source)
    after_keyword = flat_source[keyword.start():]
    header = re.match(r"lambda([\w\s,]*):", after_keyword)
    if header is None:
        return source
    function_args = re.sub(r"\s+", "", header.group(1))
    tail = after_keyword[header.end():]
    function_body = tail[:_lambda_body_end(tail)].strip()
    if not function_body:
        return source

    # Named only for "name = lambda" / "name: T = lambda": exactly one "="
    # directly before the keyword, and a plain identifier on its left.
    target, equal_sign, between = flat_source[:keyword.start()].partition("=")
    target = target.split(":")[0].strip()
    if equal_sign and not between.strip() and target.isidentifier():
        function_name = target
    else:
        function_name = f"lamba_{generate_random_hash()}"
    return f'def {function_name}({function_args}):\n return {function_body}'
def load_function(function_data: Any) -> Tuple[Callable, bool, Union[str, None], Union[str, None]]:
    """Rebuild a callable from a mapping such as the one ``dump_function`` returns.

    Args:
        function_data: mapping with a ``'module_name'`` key plus either a
            ``'source'`` key (code defining a function; lambdas are first
            converted by ``handle_lambda``) or a ``'function_name'`` key naming
            an attribute of that module. ``'source'`` takes precedence; if it
            yields no ``def``, ``'function_name'`` is used as a fallback.

    Returns:
        ``(function, from_source, function_name, source)``. When rebuilt from
        source, ``from_source`` is True and the defined name and the
        (converted) source are returned; otherwise the module attribute is
        returned together with ``False, None, None``.

    Raises:
        ValueError: if ``'module_name'`` is missing, or if neither a usable
            ``'source'`` nor a ``'function_name'`` is provided.
        ImportError: if the module cannot be imported.
        AttributeError: if ``'function_name'`` is not an attribute of the module.
    """
    if 'module_name' not in function_data:
        raise ValueError("Module should be given")
    module = importlib.import_module(function_data['module_name'])

    if 'source' in function_data:
        source = handle_lambda(function_data['source'])
        match = re.search(
            r"^\s*def\s+([a-zA-Z_]\w*)\s*\(", source, re.MULTILINE)
        if match:
            function_name = match.group(1)
            lcs = {}
            # SECURITY: exec runs arbitrary code taken from function_data, with
            # the module's namespace as globals (so the function can use the
            # module's names). Only load data coming from a trusted source.
            exec(source, module.__dict__, lcs)
            return lcs.get(function_name), True, function_name, source

    if 'function_name' in function_data:
        return getattr(module, function_data['function_name']), False, None, None

    raise ValueError(
        "function_data should contain a 'source' defining a function "
        "or a 'function_name'")
def dump_function(function: Callable, save_source: bool = False, function_name=None, function_source=None):
    """Serialize ``function`` into a mapping readable by ``load_function``.

    Args:
        function: callable to serialize; its ``__name__`` and ``__module__``
            are recorded.
        save_source: when true, also store its source under ``'source'``
            (lambdas are rewritten into a ``def`` by ``handle_lambda``).
        function_name: accepted for interface compatibility; not used.
        function_source: fallback source used when ``inspect.getsource``
            cannot retrieve it (presumably for functions that were themselves
            rebuilt from source with ``exec`` -- confirm with callers).

    Returns:
        ``{'function_name', 'module_name'}`` plus ``'source'`` when
        ``save_source`` is true.

    Raises:
        Whatever ``inspect.getsource`` raised (e.g. ``OSError``/``TypeError``)
        when the source is unavailable and no ``function_source`` was given.
    """
    record = {'function_name': function.__name__, 'module_name': function.__module__}
    if not save_source:
        return record
    try:
        source = textwrap.dedent(inspect.getsource(function))
    except Exception:
        # Deliberate best-effort: fall back on the caller-supplied source,
        # but surface the real error instead of failing later on ``None``.
        if function_source is None:
            raise
        source = function_source
    record['source'] = handle_lambda(source)
    return record
def draw_networkx_edge_labels(
    G,
    pos,
    edge_labels=None,
    label_pos=0.5,
    font_size=10,
    font_color="k",
    font_family="sans-serif",
    font_weight="normal",
    alpha=None,
    bbox=None,
    horizontalalignment="center",
    verticalalignment="center",
    ax=None,
    rotate=True,
    clip_on=True,
    rad=0
):
    """Draw edge labels, following edges that may be curved.

    Variant of ``networkx.draw_networkx_edge_labels``: each edge is modelled
    as a quadratic Bezier curve whose control point is offset from the
    chord's midpoint by ``rad`` times the chord rotated by -90 degrees, in
    display coordinates (the construction matplotlib's ``arc3`` connection
    style uses). ``rad=0`` gives straight edges.

    Parameters
    ----------
    G : graph
       A networkx graph
    pos : dictionary
       A dictionary with nodes as keys and positions as values.
       Positions should be sequences of length 2.
    edge_labels : dictionary (default=None)
       Edge labels in a dictionary of labels keyed by edge two-tuple.
       Only labels for the keys in the dictionary are drawn. When None, each
       edge's attribute dictionary is used as its label.
    label_pos : float (default=0.5)
       Position of edge label along the (possibly curved) edge
       (0=head, 0.5=center, 1=tail)
    font_size : int (default=10)
       Font size for text labels
    font_color : string (default='k' black)
       Font color string
    font_weight : string (default='normal')
       Font weight
    font_family : string (default='sans-serif')
       Font family
    alpha : float or None (default=None)
        The text transparency
    bbox : Matplotlib bbox, optional
        Specify text box properties (e.g. shape, color etc.) for edge labels.
        Default is {boxstyle='round', ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0)}.
    horizontalalignment : string (default='center')
        Horizontal alignment {'center', 'right', 'left'}
    verticalalignment : string (default='center')
        Vertical alignment {'center', 'top', 'bottom', 'baseline', 'center_baseline'}
    ax : Matplotlib Axes object, optional
        Draw the graph in the specified Matplotlib axes (default: current axes).
    rotate : bool (default=True)
        Rotate edge labels to lie parallel to the chord between the two nodes
    clip_on : bool (default=True)
        Turn on clipping of edge labels at axis boundaries
    rad : float (default=0)
        Curvature of the edges; the same value as the ``rad`` of an ``arc3``
        connection style used to draw them.

    Returns
    -------
    dict
        `dict` of labels keyed by edge

    Examples
    --------
    >>> G = nx.dodecahedral_graph()
    >>> edge_labels = draw_networkx_edge_labels(G, pos=nx.spring_layout(G), rad=0.2)

    See Also
    --------
    draw
    draw_networkx
    draw_networkx_nodes
    draw_networkx_edges
    draw_networkx_labels
    """
    if ax is None:
        # pyplot is only needed to fetch the current Axes.
        import matplotlib.pyplot as plt
        ax = plt.gca()

    if edge_labels is None:
        labels = {(u, v): d for u, v, d in G.edges(data=True)}
    else:
        labels = edge_labels

    # Default box: white with a white border.
    if bbox is None:
        bbox = dict(boxstyle="round", ec=(1.0, 1.0, 1.0), fc=(1.0, 1.0, 1.0))

    # Maps a display-space vector (dx, dy) to (dy, -dx).
    rotation_matrix = np.array([(0, 1), (-1, 0)])
    # Bezier parameter measured from n1: label_pos=1 -> n1 (tail), 0 -> n2 (head).
    t = 1.0 - label_pos

    text_items = {}
    for (n1, n2), edge_label in labels.items():
        (x1, y1) = pos[n1]
        (x2, y2) = pos[n2]

        # Control point of the quadratic Bezier, in display coordinates.
        pos_1 = ax.transData.transform(np.array(pos[n1]))
        pos_2 = ax.transData.transform(np.array(pos[n2]))
        linear_mid = 0.5 * pos_1 + 0.5 * pos_2
        d_pos = pos_2 - pos_1
        ctrl_1 = linear_mid + rad * rotation_matrix @ d_pos

        # de Casteljau evaluation at t (identical arithmetic to a plain
        # midpoint computation when t == 0.5).
        ctrl_mid_1 = (1.0 - t) * pos_1 + t * ctrl_1
        ctrl_mid_2 = (1.0 - t) * ctrl_1 + t * pos_2
        bezier_point = (1.0 - t) * ctrl_mid_1 + t * ctrl_mid_2
        (x, y) = ax.transData.inverted().transform(bezier_point)

        if rotate:
            # in degrees
            angle = np.arctan2(y2 - y1, x2 - x1) / (2.0 * np.pi) * 360
            # make label orientation "right-side-up"
            if angle > 90:
                angle -= 180
            if angle < -90:
                angle += 180
            # transform data coordinate angle to screen coordinate angle
            xy = np.array((x, y))
            trans_angle = ax.transData.transform_angles(
                np.array((angle,)), xy.reshape((1, 2))
            )[0]
        else:
            trans_angle = 0.0

        if not isinstance(edge_label, str):
            edge_label = str(edge_label)  # this makes "1" and 1 labeled the same

        t_item = ax.text(
            x,
            y,
            edge_label,
            size=font_size,
            color=font_color,
            family=font_family,
            weight=font_weight,
            alpha=alpha,
            horizontalalignment=horizontalalignment,
            verticalalignment=verticalalignment,
            rotation=trans_angle,
            transform=ax.transData,
            bbox=bbox,
            zorder=1,
            clip_on=clip_on,
        )
        text_items[(n1, n2)] = t_item

    ax.tick_params(
        axis="both",
        which="both",
        bottom=False,
        left=False,
        labelbottom=False,
        labelleft=False,
    )
    return text_items