from matplotlib import pyplot as plt
from m4opt.utils.optimization import partition_graph
import networkx as nx

# Build a 10 x 20 triangular lattice and split it into 5 parts with a fixed
# seed, so the rendered figure is reproducible.
graph = nx.triangular_lattice_graph(10, 20)
part = partition_graph(graph, 5, seed=42)

# Equal aspect ratio keeps the lattice triangles from being stretched.
ax = plt.axes(aspect=1)

# Draw each node at its "pos" node attribute, colored by its partition
# result. NOTE(review): assumes `part` holds one label per node, in the same
# order as graph.nodes — confirm against partition_graph's docs.
_draw_kwargs = dict(
    pos=nx.get_node_attributes(graph, "pos"),
    node_size=50,
    node_color=part,
    cmap="prism",
)
nx.draw(graph, ax=ax, **_draw_kwargs)