Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 41 additions & 11 deletions socat/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,30 +43,56 @@ async def init_interp(self, session: AsyncSession) -> None:
if type(self.source) is RegisteredFixedSource:
self.ra_unit = self.source.position.ra.unit
self.dec_unit = self.source.position.dec.unit
self.flux_unit = self.source.flux.unit
self.do_flux = self.source.flux is not None
self.flux_unit = self.source.flux.unit if self.do_flux else None
self.interp = lambda _: (
self.source.position.ra.value,
self.source.position.dec.value,
self.source.flux.value,
(
self.source.position.ra.value,
self.source.position.dec.value,
self.source.flux.value,
)
if self.do_flux
else (
self.source.position.ra.value,
self.source.position.dec.value,
)
)

elif type(self.source) is SolarSystemObject:
ephems = await get_ephem_points(
self.source, t_min=self.t_min, t_max=self.t_max, session=session
)
x = np.zeros(len(ephems))
y = np.zeros((len(ephems), 3))

self.do_flux = True
for ephem in ephems: # TODO: please someone have a better way to do this than looping through the ephems twice
if ephem.flux is None:
self.do_flux = False
break

if self.do_flux:
y = np.zeros((len(ephems), 3))
else:
y = np.zeros((len(ephems), 2))

for i, ephem in enumerate(ephems):
x[i] = ephem.time.unix
y[i] = (
ephem.position.ra.value,
ephem.position.dec.value,
ephem.flux.value,
(
ephem.position.ra.value,
ephem.position.dec.value,
ephem.flux.value,
)
if self.do_flux
else (
ephem.position.ra.value,
ephem.position.dec.value,
)
)

self.ra_unit = ephem.position.ra.unit # This assumes all ephem points have same units but this should probably be enforced upstream anyway.
self.dec_unit = ephem.position.dec.unit
self.flux_unit = ephem.flux.unit
self.flux_unit = ephem.flux.unit if self.do_flux else None
self.interp = make_interp_spline(x, y, k=1)

# @lru_cache(maxsize=128) # This can cause memory leaks so we might not want it
Expand Down Expand Up @@ -100,8 +126,12 @@ def at_time(self, t: Time) -> tuple[ICRS, Quantity]:
f"Error, requested t={t} outside initialized bounds {self.t_min}-{self.t_max}"
)

ra, dec, flux = self.interp(t.unix)
if self.do_flux:
ra, dec, flux = self.interp(t.unix)
flux = flux * self.flux_unit if flux != 0 else None
else:
ra, dec = self.interp(t.unix)
flux = None
position = ICRS(ra=ra * self.ra_unit, dec=dec * self.dec_unit)
flux = flux * self.flux_unit

return (position, flux)
56 changes: 56 additions & 0 deletions tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,59 @@ async def test_get_box(database_async_sessionmaker):
await core.delete_sso(source.sso_id, session=session)
for source in fixed_sources:
await core.delete_source(source.source_id, session=session)


@pytest.mark.asyncio
async def test_none_behavior(database_async_sessionmaker):
# Check that if we have a source with no flux, then the interp returns just ra/dec and not ra/dec/flux
async with database_async_sessionmaker() as session:
position = ICRS(1 * u.deg, 1 * u.deg)
source = await core.create_source(
position,
session=session,
name="mySrc",
flux=None,
)
gen = generator.SourceGenerator(
source, Time("2025-01-01T00:00:00.00"), Time("2026-01-01T00:00:00.00")
)
await gen.init_interp(session=session)

position, flux = gen.at_time(t=Time("2025-06-01T00:00:00.000000"))

assert position.ra.value == 1
assert position.dec.value == 1
assert flux is None
assert gen.flux_unit is None
assert gen.do_flux is False

async with database_async_sessionmaker() as session:
await core.delete_source(source.source_id, session=session)

async with database_async_sessionmaker() as session:
sso = await core.create_sso(name="Davida", MPC_id=511, session=session)
for i in range(10):
position = ICRS(i * u.deg, 1.5 * i * u.deg)
time = Time("2025-02-01T00:00:00.00") + (100 * i) * u.s
await core.create_ephem(
sso_id=sso.sso_id,
MPC_id=511,
name="Davida",
time=time,
position=position,
flux=None,
session=session,
)
gen = generator.SourceGenerator(
sso, Time("2025-01-01T00:00:00.00"), Time("2026-01-01T00:00:00.00")
)
await gen.init_interp(session=session)

position, flux = gen.at_time(Time("2025-02-01 00:04:10.000000"))

assert flux is None
assert gen.flux_unit is None
assert gen.do_flux is False

async with database_async_sessionmaker() as session:
await core.delete_sso(sso.sso_id, session=session)
Loading