"""Module for visualizing the disjunctive graph of a job shop instance."""
import functools
from typing import Optional, Callable
import warnings
import copy
import matplotlib
import matplotlib.pyplot as plt
import networkx as nx
from networkx.drawing.nx_agraph import graphviz_layout
from job_shop_lib import JobShopInstance
from job_shop_lib.graphs import (
JobShopGraph,
EdgeType,
NodeType,
Node,
build_disjunctive_graph,
)
Layout = Callable[[nx.Graph], dict[str, tuple[float, float]]]
# This function could be improved by a function extraction refactoring
# (see `plot_gantt_chart`
# function as a reference in how to do it). That would solve the
# "too many locals" warning. However, this refactoring is not a priority at
# the moment. To compensate, sections are separated by comments.
# For the "too many arguments" warning no satisfactory solution was
# found. I believe is still better than using `**kwargs` and losing the
# function signature or adding a dataclass for configuration (it would add
# unnecessary complexity).
# pylint: disable=too-many-arguments, too-many-locals
[docs]
def plot_disjunctive_graph(
job_shop: JobShopGraph | JobShopInstance,
figsize: tuple[float, float] = (6, 4),
node_size: int = 1600,
title: Optional[str] = None,
layout: Optional[Layout] = None,
edge_width: int = 2,
font_size: int = 10,
arrow_size: int = 35,
alpha=0.95,
node_font_color: str = "white",
color_map: str = "Dark2_r",
draw_disjunctive_edges: bool = True,
) -> plt.Figure:
"""Returns a plot of the disjunctive graph of the instance."""
if isinstance(job_shop, JobShopInstance):
job_shop_graph = build_disjunctive_graph(job_shop)
else:
job_shop_graph = job_shop
# Set up the plot
# ----------------
plt.figure(figsize=figsize)
if title is None:
title = (
f"Disjunctive Graph Visualization: {job_shop_graph.instance.name}"
)
plt.title(title)
# Set up the layout
# -----------------
if layout is None:
layout = functools.partial(
graphviz_layout, prog="dot", args="-Grankdir=LR"
)
temp_graph = copy.deepcopy(job_shop_graph.graph)
# Remove disjunctive edges to get a better layout
temp_graph.remove_edges_from(
[
(u, v)
for u, v, d in job_shop_graph.graph.edges(data=True)
if d["type"] == EdgeType.DISJUNCTIVE
]
)
try:
pos = layout(temp_graph)
except ImportError:
warnings.warn(
"Default layout requires pygraphviz http://pygraphviz.github.io/. "
"Using spring layout instead.",
)
pos = nx.spring_layout(temp_graph)
# Draw nodes
# ----------
node_colors = [
_get_node_color(node)
for node in job_shop_graph.nodes
if not job_shop_graph.is_removed(node.node_id)
]
nx.draw_networkx_nodes(
job_shop_graph.graph,
pos,
node_size=node_size,
node_color=node_colors,
alpha=alpha,
cmap=matplotlib.colormaps.get_cmap(color_map),
)
# Draw edges
# ----------
conjunctive_edges = [
(u, v)
for u, v, d in job_shop_graph.graph.edges(data=True)
if d["type"] == EdgeType.CONJUNCTIVE
]
disjunctive_edges = [
(u, v)
for u, v, d in job_shop_graph.graph.edges(data=True)
if d["type"] == EdgeType.DISJUNCTIVE
]
nx.draw_networkx_edges(
job_shop_graph.graph,
pos,
edgelist=conjunctive_edges,
width=edge_width,
edge_color="black",
arrowsize=arrow_size,
)
if draw_disjunctive_edges:
nx.draw_networkx_edges(
job_shop_graph.graph,
pos,
edgelist=disjunctive_edges,
width=edge_width,
edge_color="red",
arrowsize=arrow_size,
)
# Draw node labels
# ----------------
operation_nodes = job_shop_graph.nodes_by_type[NodeType.OPERATION]
labels = {}
source_node = job_shop_graph.nodes_by_type[NodeType.SOURCE][0]
labels[source_node] = "S"
sink_node = job_shop_graph.nodes_by_type[NodeType.SINK][0]
labels[sink_node] = "T"
for operation_node in operation_nodes:
if job_shop_graph.is_removed(operation_node.node_id):
continue
labels[operation_node] = (
f"m={operation_node.operation.machine_id}\n"
f"d={operation_node.operation.duration}"
)
nx.draw_networkx_labels(
job_shop_graph.graph,
pos,
labels=labels,
font_color=node_font_color,
font_size=font_size,
font_family="sans-serif",
)
# Final touches
# -------------
plt.axis("off")
plt.tight_layout()
# Create a legend to indicate the meaning of the edge colors
conjunctive_patch = matplotlib.patches.Patch(
color="black", label="conjunctive edges"
)
disjunctive_patch = matplotlib.patches.Patch(
color="red", label="disjunctive edges"
)
# Add to the legend the meaning of m and d
text = "m = machine_id\nd = duration"
extra = matplotlib.patches.Rectangle(
(0, 0),
1,
1,
fc="w",
fill=False,
edgecolor="none",
linewidth=0,
label=text,
)
plt.legend(
handles=[conjunctive_patch, disjunctive_patch, extra],
loc="upper left",
bbox_to_anchor=(1.05, 1),
borderaxespad=0.0,
)
return plt.gcf()
def _get_node_color(node: Node) -> int:
"""Returns the color of the node."""
if node.node_type == NodeType.SOURCE:
return -1
if node.node_type == NodeType.SINK:
return -1
if node.node_type == NodeType.OPERATION:
return node.operation.machine_id
raise ValueError("Invalid node type.")