"""
Download all NIRCam IMAGE I2D mosaics from GO-4598 (all filters).
Resumable: skips files already on disk.
"""
import os, sys, urllib3, requests, time
# Every HTTP request issued through `requests` in this process is forced to
# skip TLS certificate verification, and the urllib3 warning that would
# otherwise be emitted for each such request is silenced.
# NOTE(review): presumably a workaround for an intercepting proxy or a broken
# local CA store -- confirm it is still required, since it disables
# server-identity checks for all downloads.
old_request = requests.Session.request
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


def _patched(self, *args, **kwargs):
    """Call the original ``Session.request`` with ``verify=False`` enforced."""
    return old_request(self, *args, **{**kwargs, 'verify': False})


requests.Session.request = _patched

from astroquery.mast import Observations
# Best-effort: also turn off certificate verification on the session object
# held by astroquery's MAST portal connection.
# NOTE(review): this reaches into private astroquery attributes that may not
# exist under these names in every release -- confirm against the installed
# version. A missing attribute is tolerated (the Session.request patch above
# already forces verify=False); any other, unexpected error is no longer
# hidden.
try:
    Observations._portal_api_connection.session.verify = False
except AttributeError:
    pass

# Destination folder for every downloaded file; created up front (no error if
# it already exists) so the download loop can write into it directly.
OUT_DIR = 'JWST_GO4598_NIRCAM'
os.makedirs(OUT_DIR, exist_ok=True)

# NIRCam imaging filters to fetch; the archive is queried once per entry.
filters = [
    'F090W', 'F115W', 'F150W', 'F200W',
    'F277W', 'F356W', 'F410M', 'F444W',
]

# Run-wide bookkeeping, updated by the download loop.
total_downloaded = total_skipped = total_bytes = 0
manifest = []  # one (filter, filename, bytes, seconds) tuple per download

# For each filter: query the GO-4598 NIRCam imaging observations, narrow their
# products to SCIENCE / I2D FITS files, and download every file that is not
# already present in OUT_DIR.
for filt in filters:
    print(f"\n=== Filter {filt} ===")
    obs = Observations.query_criteria(
        proposal_id='4598',
        instrument_name='NIRCAM/IMAGE',
        filters=filt
    )
    if len(obs) == 0:
        print(f"  No observations found for {filt}")
        continue
    print(f"  {len(obs)} observations")
    products = Observations.get_product_list(obs)
    i2d = Observations.filter_products(
        products,
        productType=['SCIENCE'],
        productSubGroupDescription=['I2D'],
        extension='fits'
    )
    print(f"  {len(i2d)} I2D products")
    for row in i2d:
        fn = row['productFilename']
        # A missing/falsy catalogue size is treated as 0, i.e. "size unknown".
        size = row['size'] if row['size'] else 0
        target = os.path.join(OUT_DIR, fn)
        # Resume support: a file already on disk counts as complete when it is
        # larger than 90% of the catalogue size (any non-empty file when the
        # size is unknown).
        # NOTE(review): the 0.9 tolerance also accepts a download truncated
        # within the last 10% -- confirm that slack is intended.
        if os.path.exists(target) and os.path.getsize(target) > 0.9 * size:
            total_skipped += 1
            continue
        uri = row['dataURI']
        print(f"  -> {fn} ({size/1e6:.1f} MB)", end=" ", flush=True)
        t0 = time.time()
        try:
            result = Observations.download_file(uri, local_path=target)
        except Exception as e:
            # Per-file boundary: report the failure and move on to the next file.
            print(f"FAILED: {type(e).__name__}: {e}")
            continue
        elapsed = time.time() - t0
        # download_file reports HTTP failures through its returned
        # (status, message, url) tuple instead of raising, so the status must
        # be checked before the file is counted as downloaded.
        status, msg = result[0], result[1]
        if status == 'ERROR':
            print(f"FAILED: {msg}")
            continue
        actual = os.path.getsize(target) if os.path.exists(target) else 0
        total_downloaded += 1
        total_bytes += actual
        manifest.append((filt, fn, actual, elapsed))
        print(f"[{elapsed:.1f}s, {actual/1e6:.1f} MB]")

# Final report: number of files fetched and skipped in this run, plus the
# total downloaded volume in gigabytes (10**9 bytes).
summary_lines = (
    "\n=== Summary ===",
    f"Downloaded: {total_downloaded} files",
    f"Skipped (already on disk): {total_skipped}",
    f"Total bytes: {total_bytes/1e9:.2f} GB",
)
print("\n".join(summary_lines))

# Write manifest: a CSV with a header line followed by one row per entry in
# `manifest`. The file is opened in 'w' mode, so it is replaced on every run.
manifest_path = os.path.join(OUT_DIR, '_download_manifest.csv')
with open(manifest_path, 'w', encoding='utf-8') as f:
    f.write('filter,filename,bytes,download_time_s\n')
    f.writelines(','.join(str(x) for x in row) + '\n' for row in manifest)
# Report the path that was actually opened, rather than a hard-coded
# backslash-separated one that is wrong outside Windows.
print(f"\nManifest written to {manifest_path}")
