from astropy import units as u
from astropy.coordinates import EarthLocation, ICRS
from astropy_healpix import HEALPix
from astropy.time import Time
from matplotlib import pyplot as plt
from m4opt.missions import uvex
from m4opt.synphot import observing
import numpy as np
from synphot import ConstFlux1D, SourceSpectrum

# Exposure bookkeeping: one "dwell" is a 900 s exposure, and stacks of
# 1 through 10 dwells are evaluated.
dwell = u.def_unit("dwell", 900 * u.s)
exptime = np.arange(1, 11) * dwell

# 50 evenly spaced epochs covering one year starting 2024-01-01.
obstime = Time("2024-01-01") + np.linspace(0, 1, 50) * u.year

# Every pixel of an nside=128 HEALPix grid in ICRS, i.e. an all-sky target list.
hpx = HEALPix(nside=128, frame=ICRS())
target_coords = hpx.healpix_to_skycoord(np.arange(hpx.npix))

# Observer at geocentric x = y = z = 0 (Earth's center).
# NOTE(review): presumably a stand-in for the spacecraft position; confirm
# whether an orbit model is intended here.
observer_location = EarthLocation.from_geocentric(0, 0, 0, unit=u.m)

# Broadcast layout passed to get_limmag:
#   axis 0 = number of stacked dwells (exptime),
#   axis 1 = HEALPix target pixel,
#   axis 2 = observation epoch.
#
# A 5-sigma detection in a stack of N dwells is evaluated as the single-dwell
# limit at SNR 5/sqrt(N) (this assumes independent per-dwell noise, so SNR
# grows as sqrt(N)). Neither this SNR grid nor the flat 0-ABmag reference
# spectrum depends on the filter, so both are built once.
snr = 5 * np.sqrt(dwell / exptime[:, np.newaxis, np.newaxis])
flat_spectrum = SourceSpectrum(ConstFlux1D, amplitude=0 * u.ABmag)

# median_limmags[i, j] is the median (over all targets and epochs) limiting
# magnitude for the i-th bandpass and the j-th stack depth.
median_limmags = []
# The observing context depends only on the observer, targets and epochs,
# not on the filter, so it is entered once for all bandpasses.
with observing(
    observer_location,
    target_coords[np.newaxis, :, np.newaxis],
    obstime[np.newaxis, np.newaxis, :],
):
    for filt in uvex.detector.bandpasses.keys():
        # Collapse over targets and epochs immediately. Otherwise one
        # (n_exptime, n_pixels, n_epochs) array per filter would be kept
        # alive and then stacked again by np.median.
        median_limmags.append(
            np.median(
                uvex.detector.get_limmag(
                    snr, 1 * dwell, flat_spectrum, filt
                ).to_value(u.mag),
                axis=(1, 2),
            )
        )
median_limmags = np.asarray(median_limmags)

# Plot median limiting magnitude versus stack depth, one curve per bandpass.
# Strip the custom "dwell" unit explicitly (values are the dwell counts)
# instead of relying on how matplotlib treats an astropy Quantity.
n_dwells = exptime.to_value(dwell)

ax = plt.axes()
# Derive the x-range from the stack depths actually computed rather than
# hard-coding it, so it tracks any change to `exptime`.
ax.set_xlim(n_dwells.min(), n_dwells.max())
ax.set_ylim(24.5, 26.5)
# Swap the y-limits so brighter (smaller) magnitudes sit at the top.
ax.invert_yaxis()
for filt, limmag in zip(uvex.detector.bandpasses.keys(), median_limmags):
    ax.plot(n_dwells, limmag, "-o", label=filt)
ax.legend()
# Take the dwell length from the unit definition so the label cannot drift
# out of sync with it (renders as "900 s" for the current definition).
ax.set_xlabel(f"Number of stacked {(1 * dwell).to_value(u.s):g} s dwells")
ax.set_ylabel(r"5-$\sigma$ Limiting magnitude (AB)")
plt.savefig("test.png")