from astropy import units as u
from astropy.coordinates import EarthLocation, ICRS
from astropy.time import Time
from astropy_healpix import HEALPix
from matplotlib import pyplot as plt
from m4opt.missions import ultrasat
from m4opt.synphot import observing
import numpy as np
from synphot import SourceSpectrum
from synphot.models import BlackBody1D

# Exposure time: three "dwells", where a dwell is a custom 300 s astropy unit.
dwell = u.def_unit("dwell", 300 * u.s)
exptime = 3 * dwell

# Signal-to-noise ratios to evaluate: 50 log-spaced values from 10 to 10,000.
snr = np.geomspace(1e1, 1e4, num=50)

# Sky grid: the center of every pixel of an nside=8 HEALPix map in ICRS.
hpx = HEALPix(8, frame=ICRS())
target_coords = hpx.healpix_to_skycoord(np.arange(hpx.npix))

# Observing epochs: 50 instants spread evenly across one year from 2024-01-01.
obstime = Time("2024-01-01") + np.linspace(0, 1, num=50) * u.year

# Observer placed at geocentric (x, y, z) = (0, 0, 0) m, i.e. Earth's center.
observer_location = EarthLocation.from_geocentric(0 * u.m, 0 * u.m, 0 * u.m)

def _median_limmag(temperature):
    """Return the median NUV limiting magnitude, one value per entry of ``snr``.

    A blackbody spectrum at ``temperature`` kelvin is renormalized to 0 AB mag
    in ULTRASAT's NUV bandpass and handed to ``get_limmag`` together with the
    module-level ``snr`` and ``exptime``. Given the broadcasting set up by the
    ``observing`` call below, the result is expected to have axes
    (sky position, observation time, SNR); the median collapses the first two.

    NOTE(review): must be called while the ``observing`` context is active —
    presumably ``get_limmag`` reads the observer/targets/times from it; confirm.
    """
    blackbody = SourceSpectrum(BlackBody1D, temperature=temperature * u.K)
    normalized = blackbody.normalize(
        renorm_val=0 * u.ABmag,
        band=ultrasat.detector.bandpasses["NUV"],
        vegaspec=None,
    )
    limmag = ultrasat.detector.get_limmag(
        snr=snr,
        exptime=exptime,
        source_spectrum=normalized,
        bandpass="NUV",
    )
    return np.median(limmag.to_value(u.mag), axis=(0, 1))


# Broadcast layout: axis 0 = sky pixel, axis 1 = epoch, last axis = SNR.
with observing(
    observer_location,
    target_coords[:, np.newaxis, np.newaxis],
    obstime[np.newaxis, :, np.newaxis],
):
    # NOTE(review): 5000 K is cooler than a typical G dwarf (the Sun is
    # roughly 5800 K); confirm the intended temperature for the "G Dwarf" curve.
    # The map is fully consumed by the unpacking, so both evaluations happen
    # inside the context, G (5000 K) first, then M (3000 K).
    limmag_g, limmag_m = map(_median_limmag, [5000, 3000])

# Plot SNR (log scale) against the median limiting AB magnitude for each star.
fig, ax = plt.subplots()
ax.plot(limmag_g, snr, label="G Dwarf", color="g")
ax.plot(limmag_m, snr, label="M Dwarf", color="r")
ax.set_xlim(10, 22)
# Derive the y-range from the SNR grid rather than repeating its bounds, so the
# axis stays in sync if the grid is changed.
ax.set_ylim(snr.min(), snr.max())
ax.set_yscale("log")
ax.set_xlabel("AB Magnitude")
ax.set_ylabel("SNR")
ax.grid(True)
ax.legend()