"""
Direct urllib downloader for NIRCam COMBINED visit-level I2D mosaics.
Uses watchdog: aborts and retries if no data flows for 60s.
"""
import os, sys, urllib.request, urllib3, ssl, socket, time, threading, json
# Certificate verification is switched off on purpose (here and in the
# requests patch below), so silence urllib3's per-request warning about it.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)
# Process-wide default timeout for sockets created without an explicit one.
socket.setdefaulttimeout(60)

# TLS context handed to urlopen() in watchdog_download(): no hostname check and
# no certificate verification. check_hostname has to be cleared BEFORE
# verify_mode can be set to CERT_NONE, so the order of these two lines matters.
ctx = ssl.create_default_context()
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE

# Patch requests for astroquery
import requests
old_request = requests.Session.request


def _patched(self, *args, **kwargs):
    """Call the original ``requests.Session.request`` with TLS checks disabled.

    ``verify`` is always forced to ``False``. ``timeout`` defaults to 60 seconds
    unless the caller passed one explicitly (an explicit ``None`` is kept).
    """
    # Later entries win: caller kwargs override the timeout default, and
    # verify=False overrides whatever the caller passed.
    forwarded = {'timeout': 60, **kwargs, 'verify': False}
    return old_request(self, *args, **forwarded)


requests.Session.request = _patched

from astroquery.mast import Observations
# Best effort: also turn off TLS verification directly on astroquery's portal
# session. This touches a private attribute that may be absent depending on the
# astroquery version, so any failure here is deliberately ignored.
try:
    Observations._portal_api_connection.session.verify = False
except Exception:
    ...

# Destination folder for downloaded mosaics (relative to the working directory).
OUT_DIR = 'JWST_GO4598_NIRCAM'

# Filters queried in this order: the Cha-2025 long-wavelength bands first,
# then F200W to complete that band.
PRIORITY = ['F277W', 'F356W', 'F410M', 'F444W', 'F200W']

os.makedirs(OUT_DIR, exist_ok=True)

def watchdog_download(url, target, expected_size, stall_seconds=60, chunk_size=262144):
    """Stream ``url`` into ``target``, aborting if data stops arriving.

    The response body is written in ``chunk_size`` pieces to ``target + '.tmp'``.
    Only a complete transfer is moved onto ``target``, so an interrupted
    download never leaves a partial file under the final name.

    Stall detection: after ``urlopen`` returns and before every read, the time
    since the last received chunk is compared with ``stall_seconds``. If it is
    exceeded, ``TimeoutError`` is raised. A read that is already blocked cannot
    be interrupted from here. Each underlying socket receive is limited by the
    30 s timeout passed to ``urlopen``.

    Args:
        url: URL to fetch. It is opened with the module-level SSL context ``ctx``.
        target: Final destination path.
        expected_size: Archive-reported size. It is not checked here; callers
            compare it with the returned size.
        stall_seconds: Maximum allowed gap between received chunks.
        chunk_size: Bytes requested per ``read`` call (default 256 KB).

    Returns:
        ``(size_on_disk, bytes_read)`` for the finished file.

    Raises:
        TimeoutError: No chunk arrived within ``stall_seconds``.
        Exception: Any network or filesystem error is re-raised unchanged after
            the partial ``.tmp`` file has been removed.
    """
    tmp_path = target + '.tmp'
    req = urllib.request.Request(url, headers={'User-Agent': 'BCR-FUM-UM/1.0'})
    bytes_read = 0
    # time.monotonic() cannot jump with wall-clock changes (NTP sync, manual
    # adjustment), which could otherwise fake or hide a stall.
    last_progress = time.monotonic()
    try:
        with urllib.request.urlopen(req, context=ctx, timeout=30) as r:
            with open(tmp_path, 'wb') as f:
                while True:
                    if time.monotonic() - last_progress > stall_seconds:
                        raise TimeoutError(f"Watchdog: no data for {stall_seconds}s")
                    chunk = r.read(chunk_size)
                    if not chunk:
                        break
                    f.write(chunk)
                    bytes_read += len(chunk)
                    last_progress = time.monotonic()
        # os.replace swaps the finished file in with one rename and overwrites
        # any existing target (also on Windows). Target is never briefly missing.
        os.replace(tmp_path, target)
        return os.path.getsize(target), bytes_read
    except BaseException:
        # Never leave a partial download behind, including on Ctrl-C.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Find combined mosaics for each priority filter
print("=== Finding combined visit-level mosaics ===", flush=True)
all_targets = []
# A MAST product list can contain the same file more than once (one row per
# parent observation). Keep only the first occurrence of each filename;
# otherwise it is queued twice and double-counted in total_expected.
seen_filenames = set()
for filt in PRIORITY:
    print(f"\n--- {filt} ---", flush=True)
    try:
        obs = Observations.query_criteria(
            proposal_id='4598',
            instrument_name='NIRCAM/IMAGE',
            filters=filt
        )
        products = Observations.get_product_list(obs)
        i2d = Observations.filter_products(
            products,
            productType=['SCIENCE'],
            productSubGroupDescription=['I2D'],
            extension='fits'
        )
        for r in i2d:
            fname = str(r['productFilename'])
            # Keep combined mosaics only: pattern jw04598-oNNN_tNNN_*
            if 'jw04598-' not in fname or '_t' not in fname:
                continue
            if fname in seen_filenames:
                continue
            seen_filenames.add(fname)
            # Missing/empty size is recorded as 0 ("unknown").
            size = int(r['size']) if r['size'] else 0
            all_targets.append({
                'filter': filt,
                'filename': fname,
                'size': size,
                'uri': str(r['dataURI'])
            })
            print(f"  Combined: {fname}  {size/1e6:.1f} MB", flush=True)
    except Exception as e:
        # A failed query for one filter should not stop the remaining filters.
        print(f"  Query failed for {filt}: {e}", flush=True)

total_expected = sum(t['size'] for t in all_targets)
print(f"\n=== Total combined mosaics: {len(all_targets)}, {total_expected/1e9:.2f} GB ===", flush=True)

# Save target list
with open('combined_mosaic_targets.json', 'w') as f:
    json.dump(all_targets, f, indent=2)

# Download with watchdog + retry
print("\n=== Downloading ===", flush=True)
MAX_ATTEMPTS = 3
ok_count = 0
fail_count = 0
total_bytes = 0
for t in all_targets:
    target = os.path.join(OUT_DIR, t['filename'])
    expected = t['size']
    # Resume support: a file above 90% of the archive-reported size counts as done.
    if os.path.exists(target) and os.path.getsize(target) > 0.9 * expected:
        print(f"  SKIP {t['filename']} (already on disk)", flush=True)
        continue
    url = f"https://mast.stsci.edu/api/v0.1/Download/file?uri={t['uri']}"

    for attempt in range(MAX_ATTEMPTS):
        if attempt:
            # Linear backoff so a transient server-side problem is not hit
            # again immediately.
            time.sleep(5 * attempt)
        t0 = time.time()
        print(f"  -> {t['filter']} {t['filename']} ({expected/1e6:.1f} MB) attempt {attempt+1}", flush=True)
        try:
            actual, _ = watchdog_download(url, target, expected, stall_seconds=60)
        except TimeoutError:
            elapsed = time.time() - t0
            retry_note = " — retrying" if attempt + 1 < MAX_ATTEMPTS else ""
            print(f"     STALLED after {elapsed:.0f}s{retry_note}", flush=True)
            continue
        except Exception as e:
            print(f"     ERROR {type(e).__name__}: {e}", flush=True)
            continue
        elapsed = time.time() - t0
        if actual > 0.9 * expected:
            ok_count += 1
            total_bytes += actual
            rate = (actual / 1e6) / max(elapsed, 0.01)
            print(f"     OK [{elapsed:.1f}s, {actual/1e6:.1f} MB, {rate:.1f} MB/s]", flush=True)
            break
        print(f"     SHORT ({actual/1e6:.1f} MB)", flush=True)
    else:
        fail_count += 1
        # A SHORT attempt has already moved its file onto `target`. Remove any
        # undersized file so it is not mistaken for a finished mosaic.
        if os.path.exists(target) and os.path.getsize(target) <= 0.9 * expected:
            os.remove(target)
        print(f"     FAILED after {MAX_ATTEMPTS} attempts", flush=True)

print(f"\n=== Summary: OK={ok_count} FAIL={fail_count} Bytes={total_bytes/1e9:.2f} GB ===", flush=True)
