from astropy.coordinates import EarthLocation, Galactic, SkyCoord
from astropy.time import Time
from astropy import units as u
from matplotlib import pyplot as plt
from m4opt.synphot.background import GalacticBackground
from m4opt.synphot.background._core import BACKGROUND_SOLID_ANGLE
from m4opt.synphot import observing
import numpy as np
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
from matplotlib.lines import Line2D


# Wavelengths marked as the GALEX FUV and NUV bands on the plot's top axis.
galex_wavelengths = [1539, 2316] * u.angstrom
# Galactic latitudes -90°, -80°, ..., +90° (np.arange excludes the stop value).
lats = np.arange(-90, 100, 10) * u.deg
# Wavelength grid: np.linspace's default of 50 samples from 1000 to 3000 Å.
wavelengths = np.linspace(1000, 3000) * u.angstrom
background_model = GalacticBackground()

# Sightlines along Galactic longitude l = 0°, one per latitude.
sightlines = SkyCoord(0 * u.deg, lats, frame=Galactic())
# The observer's location and the observation time are arbitrary here.
site = EarthLocation(0 * u.m, 0 * u.m, 0 * u.m)
epoch = Time("2024-01-01")

# Target unit: photons per cm² per second per steradian per ångström.
radiance_unit = u.photon * u.cm**-2 * u.s**-1 * u.sr**-1 * u.angstrom**-1

# The trailing axis on the sightlines broadcasts against the wavelength grid.
# The first axis of the result therefore indexes latitude, the second wavelength.
with observing(site, sightlines[:, np.newaxis], epoch):
    background_per_sr = background_model(wavelengths) / BACKGROUND_SOLID_ANGLE
    flux_density = background_per_sr.to(radiance_unit)

# Each curve is colored by |b|, the absolute Galactic latitude, on a 0°-90° scale.
colormap = ScalarMappable(Normalize(vmin=0, vmax=90), plt.get_cmap("cividis"))
# Two data panels (southern and northern latitudes) plus a narrow third axes
# that exists only to hold the shared legend.
fig, axs = plt.subplots(
    1, 3, figsize=(8, 3), width_ratios=(3, 3, 1), sharex=True, sharey=True
)

# Take the plot range and axis unit from the sampled wavelength grid instead of
# repeating literal bounds, so the limits stay in sync if the grid changes.
wavelength_unit = wavelengths.unit
xlim = (wavelengths.min().value, wavelengths.max().value)

axs[0].set_title("Galactic latitude $b$ < 0°")
axs[1].set_title("Galactic latitude $b$ > 0°")

for ax, sign in zip(axs, [-1, 1]):
    # Keep latitudes strictly on this side of the plane. The b = 0° row of
    # flux_density is therefore drawn in neither panel.
    keep = np.sign(lats) == sign
    for y, lat in zip(flux_density[keep], lats[keep]):
        # sign * b equals |b| for every latitude kept in this panel.
        ax.plot(wavelengths, y, color=colormap.to_rgba(sign * lat.value))
    ax.set_xlabel(f"Wavelength ({wavelength_unit:latex_inline})")

    # Secondary x-axis on top, with ticks and grid lines at the GALEX bands.
    ax_twin = ax.twiny()
    ax_twin.set_xlim(*xlim)
    ax_twin.set_xticks(galex_wavelengths.to_value(wavelength_unit))
    ax_twin.set_xticklabels(["FUV", "NUV"])
    ax_twin.grid()

# The axes share x and y, so these limits apply to every panel.
axs[0].set_ylim(0)
axs[0].set_xlim(*xlim)
axs[0].set_ylabel(f"Background ({flux_density.unit:latex_inline})")

# The third axes holds only the legend: hide its frame and all of its ticks.
axs[-1].set_frame_on(False)
plt.setp(
    axs[-1].xaxis.get_major_ticks() + axs[-1].yaxis.get_major_ticks(), visible=False
)
# Proxy artists, one per |b| value drawn in the panels.
legend_lats = range(10, 100, 10)
axs[-1].legend(
    [Line2D([], [], color=colormap.to_rgba(lat)) for lat in legend_lats],
    [f"{lat}°" for lat in legend_lats],
    mode="expand",
    title="$|b|$",
    loc="upper left",
    borderaxespad=0,
)

fig.tight_layout()