from matplotlib import pyplot as plt
from m4opt.utils.optimization import partition_graph, partition_graph_color
import networkx as nx

# Build a triangular lattice graph and split it into 20 parts.
# The fixed seed makes the partition reproducible.
graph = nx.triangular_lattice_graph(20, 40)
part = partition_graph(graph, 20, seed=42)

# Assign a color to each part. These keyword options are forwarded to
# partition_graph_color.
# NOTE(review): presumably they select a networkx-style greedy coloring
# strategy; confirm against m4opt.utils.optimization.
color = partition_graph_color(
    graph,
    part,
    strategy="connected_sequential",
    interchange=True,
)

# Draw on square-aspect axes so the lattice geometry is not distorted.
ax = plt.axes(aspect=1)

# Node coordinates are read from the "pos" node attribute.
node_positions = nx.get_node_attributes(graph, "pos")

# Map each node's part label to that part's color.
# Assumes `part` is an integer array aligned with graph node order and
# `color` is indexable by it (numpy fancy indexing) — TODO confirm.
node_colors = color[part]

nx.draw(
    graph,
    ax=ax,
    pos=node_positions,
    node_size=50,
    node_color=node_colors,
    cmap="cool",
)