from matplotlib import pyplot as plt
from m4opt.utils.optimization import partition_graph
import networkx as nx
import numpy as np

# Build a triangular lattice. networkx stores each node's (x, y) layout
# coordinates in its "pos" attribute; the node label itself is an integer
# (i, j) lattice index, which is NOT the same coordinate system.
graph = nx.triangular_lattice_graph(30, 50)
positions = nx.get_node_attributes(graph, "pos")
center = np.mean(list(positions.values()), axis=0)

# Weight every node by the cube of its Euclidean distance from the lattice
# center, rounded up to an integer. The offset is taken from the node's "pos"
# coordinates so that it lives in the same space as `center` (subtracting the
# center from the raw (i, j) label would skew and mis-center the weights).
for node, data in graph.nodes(data=True):
    offset = np.asarray(data["pos"]) - center
    data["distance"] = np.ceil(np.linalg.norm(offset) ** 3).astype(np.intp)

# Partition the lattice into 50 parts, using each node's "distance" attribute
# as its weight; the fixed seed presumably makes the result reproducible.
part = partition_graph(graph, 50, seed=42, node_weight="distance")

# Draw the lattice at its stored layout coordinates, coloring nodes by `part`
# (presumably one part label per node in graph.nodes order — confirm against
# partition_graph's return contract).
ax = plt.axes(aspect=1)
layout = nx.get_node_attributes(graph, "pos")
draw_options = dict(
    pos=layout,
    node_size=50,
    node_color=part,
    cmap="prism",
)
nx.draw(graph, ax=ax, **draw_options)