nixfleet_control_plane/server/
mod.rs

//! Long-running TLS server: router + listener + RFC-0006 runtime + background polls.
2
3mod middleware;
4mod route_error;
5mod routes;
6mod state;
7
8pub use state::{AppState, ClosureUpstream, IssuancePaths, ServeArgs, VerifiedFleetSnapshot};
9
10use std::sync::Arc;
11use std::time::Duration;
12
13use axum::Router;
14use axum::body::Body;
15use axum::http::Request as HttpRequest;
16use axum::middleware::Next;
17use axum::routing::{get, post};
18use tokio_util::sync::CancellationToken;
19
20use crate::auth::auth_cn::MtlsAcceptor;
21
/// Budget for `drain_background_tasks`: how long shutdown waits for cancelled
/// background tasks before giving up on them.
///
/// Sized to fit inside the service manager's stop timeout so the process is
/// not SIGKILLed mid-drain. NOTE(review): systemd's stock
/// `DefaultTimeoutStopSec=` is 90s, not 30s — confirm the unit's
/// `TimeoutStopSec=` matches the assumption this value was sized against.
const TASK_SHUTDOWN_DEADLINE: Duration = Duration::from_secs(30);

/// Budget handed to `Handle::graceful_shutdown` for in-flight HTTP requests.
///
/// NOTE(review): background tasks are drained only after the listener future
/// returns, under their own `TASK_SHUTDOWN_DEADLINE` budget, so the worst-case
/// stop time is the sum of both (55s) — keep the unit's stop timeout above it.
const HTTP_DRAIN_DEADLINE: Duration = Duration::from_secs(25);
27
28/// `/healthz` outside `/v1`; `/v1/enroll` is anonymous; all other `/v1/*` require mTLS.
29/// `/v1/*` is gated by `require_ready_layer` (#95) so agents get 503 + Retry-After
30/// until the first signed artifact is verified - no stale-state serving.
31fn build_router(state: Arc<AppState>) -> Router {
32    let strict = state.strict;
33    let auth_state = state.clone();
34    let ready_state = state.clone();
35
36    let anonymous_v1 = Router::new().route("/v1/enroll", post(routes::enrollment::enroll));
37
38    let authenticated_v1 = Router::new()
39        .route("/v1/whoami", get(routes::status::whoami))
40        // RFC-0005 §4 wire protocol (replaces the v0.1 checkin / confirm
41        // pull pipeline — agents now POST events as they happen and
42        // long-poll for dispatches).
43        .route("/v1/agent/events", post(routes::events::events))
44        .route("/v1/agent/heartbeat", post(routes::heartbeat::heartbeat))
45        .route("/v1/agent/dispatch", get(routes::dispatch::dispatch))
46        .route(
47            "/v1/agent/closure/{hash}",
48            get(routes::status::closure_proxy),
49        )
50        .route("/v1/agent/renew", post(routes::enrollment::renew))
51        .route("/v1/channels/{name}", get(routes::status::channel_status))
52        .route("/v1/hosts", get(routes::status::hosts_status))
53        .route("/metrics", get(routes::metrics::metrics_handler))
54        .route("/v1/rollouts", get(routes::rollouts::list_active))
55        .route("/v1/deferrals", get(routes::deferrals::list))
56        .route("/v1/fleet.resolved", get(routes::fleet::artifact))
57        .route("/v1/fleet.resolved/sig", get(routes::fleet::signature))
58        .route("/v1/rollouts/{rolloutId}", get(routes::rollouts::manifest))
59        .route(
60            "/v1/rollouts/{rolloutId}/sig",
61            get(routes::rollouts::signature),
62        )
63        .route(
64            "/v1/rollouts/{rolloutId}/lifecycle",
65            get(routes::rollouts::lifecycle),
66        )
67        .route(
68            "/v1/rollouts/{rolloutId}/hosts",
69            get(routes::rollouts::hosts),
70        )
71        .route(
72            "/v1/rollouts/{rolloutId}/events",
73            get(routes::rollouts::events),
74        )
75        .layer(axum::middleware::from_fn(move |req, next| {
76            let s = auth_state.clone();
77            async move { middleware::require_cn_layer(s, req, next).await }
78        }));
79
80    let v1_routes = anonymous_v1
81        .merge(authenticated_v1)
82        .layer(axum::middleware::from_fn(move |req, next| {
83            version_layer(strict, req, next)
84        }))
85        .layer(axum::middleware::from_fn(move |req, next| {
86            let s = ready_state.clone();
87            async move { middleware::require_ready_layer(s, req, next).await }
88        }));
89
90    Router::new()
91        .route("/healthz", get(routes::health::healthz))
92        .merge(v1_routes)
93        .with_state(state)
94}
95
96async fn version_layer(
97    strict: bool,
98    req: HttpRequest<Body>,
99    next: Next,
100) -> Result<axum::response::Response, axum::http::StatusCode> {
101    middleware::protocol_version_middleware(strict, req, next).await
102}
103
104pub async fn serve(args: ServeArgs) -> anyhow::Result<()> {
105    if args.strict {
106        let mut missing: Vec<&str> = Vec::new();
107        if args.client_ca.is_none() {
108            missing.push("--client-ca (mTLS verification disabled - TLS-only mode)");
109        }
110        if args.revocations.is_none() {
111            missing.push("--revocations-{artifact,signature}-url (revocations polling disabled - previously-revoked certs become valid again after CP rebuild)");
112        }
113        if args.bootstrap_nonces.is_none() {
114            missing.push("--bootstrap-nonces-{artifact,signature}-url (bootstrap-nonces polling disabled - replay-after-DB-wipe protection absent, nixfleet#96)");
115        }
116        // RFC-0010 §1.5.1: file-backed CA leaves signing material on disk;
117        // production deployments MUST use the TPM backend or opt-in explicitly.
118        let tpm_configured = args.tpm_ca_pubkey_raw.is_some() && args.tpm_ca_sign_wrapper.is_some();
119        let file_only = args.fleet_ca_key.is_some() && !tpm_configured;
120        if file_only && !args.allow_file_ca_key {
121            missing.push("--tpm-ca-pubkey-raw + --tpm-ca-sign-wrapper (file-backed --fleet-ca-key keeps signing material on disk; pass --allow-file-ca-key to opt out, RFC-0010 §1.5.1)");
122        }
123        if !missing.is_empty() {
124            anyhow::bail!(
125                "--strict refuses to start: the following security flags are unset:\n  - {}\n\
126                 Either set the missing flags or drop --strict for development.",
127                missing.join("\n  - "),
128            );
129        }
130    }
131
132    let db = if let Some(path) = &args.db_path {
133        let db = crate::db::Db::open(path)?;
134        db.migrate()?;
135        tracing::info!(path = %path.display(), "sqlite opened + migrated");
136        Some(Arc::new(db))
137    } else {
138        None
139    };
140
141    let closure_upstream = if let Some(base_url) = &args.closure_upstream {
142        let client = reqwest::Client::builder()
143            .timeout(Duration::from_secs(30))
144            .build()
145            .map_err(|e| anyhow::anyhow!("build closure proxy client: {e}"))?;
146        Some(ClosureUpstream {
147            base_url: base_url.clone(),
148            client,
149        })
150    } else {
151        None
152    };
153    let revocations_required = args.revocations.is_some();
154    let bootstrap_nonces_required = args.bootstrap_nonces.is_some();
155    let app_state = AppState {
156        db: db.clone(),
157        confirm_deadline_secs: args.confirm_deadline_secs,
158        closure_upstream,
159        rollouts_dir: args.rollouts_dir.clone(),
160        rollouts_source: args.rollouts_source.clone(),
161        channel_refs_source: args.channel_refs.clone(),
162        strict: args.strict,
163        agent_cn_suffix: args.agent_cn_suffix.clone(),
164        agent_cert_validity: args.agent_cert_validity,
165        revocations_required,
166        bootstrap_nonces_required,
167        ..Default::default()
168    };
169    if args.mark_ready_at_startup {
170        // Test-only escape hatch - see ServeArgs::mark_ready_at_startup.
171        app_state
172            .artifact_primed
173            .store(true, std::sync::atomic::Ordering::Release);
174        app_state
175            .revocations_primed
176            .store(true, std::sync::atomic::Ordering::Release);
177        app_state
178            .bootstrap_nonces_primed
179            .store(true, std::sync::atomic::Ordering::Release);
180    }
181    if let Some(nonces) = args.initial_nonces {
182        // Test-only escape hatch - see ServeArgs::initial_nonces.
183        *app_state.allowed_nonces.write().await = nonces;
184    }
185    let state = Arc::new(app_state);
186
187    let cancel = CancellationToken::new();
188    let mut bg_handles: Vec<tokio::task::JoinHandle<()>> = Vec::new();
189
190    if let Some(db_arc) = db.clone() {
191        bg_handles.push(crate::timers::prune_timer::spawn(
192            cancel.clone(),
193            db_arc,
194            args.db_path.clone(),
195        ));
196    }
197
198    *state.issuance_paths.write().await = IssuancePaths {
199        fleet_ca_cert: args.fleet_ca_cert.clone(),
200        fleet_ca_key: args.fleet_ca_key.clone(),
201        audit_log: args.audit_log_path.clone(),
202        trust_path: args.trust_path.clone(),
203    };
204
205    // CA signer built once; TPM (pubkey + wrapper) wins over file (fleet_ca_key).
206    if let Some(cert_path) = args.fleet_ca_cert.as_ref() {
207        let signer = crate::auth::issuance::build_signer_from_args(
208            cert_path,
209            args.tpm_ca_pubkey_raw.as_deref(),
210            args.tpm_ca_sign_wrapper.as_deref(),
211            args.fleet_ca_key.as_deref(),
212        );
213        *state.ca_signer.write().await = signer;
214    }
215
216    // Pre-listener prime via manifest_poll. Synchronous one-shot fetch +
217    // verify so the routes that read `state.verified_fleet` (channel_status,
218    // enrollment) see a populated snapshot on the first request. After
219    // this, the manifest_poll worker (spawned by runtime::spawn below)
220    // keeps the snapshot fresh on its 30 s tick.
221    {
222        let clock: nixfleet_proto::clock::ClockHandle =
223            std::sync::Arc::new(nixfleet_proto::clock::SystemClock::new());
224        match tokio::time::timeout(
225            Duration::from_secs(20),
226            crate::runtime::workers::manifest_poll::prime_blocking(&state, &clock),
227        )
228        .await
229        {
230            Ok(Ok(true)) => {
231                tracing::info!(
232                    target: "cp_boot",
233                    "primed verified-fleet from channel-refs source before opening listener",
234                );
235            }
236            Ok(Ok(false)) => {
237                // No channel_refs_source configured — fall through to file
238                // fallback below.
239            }
240            Ok(Err(err)) => {
241                tracing::warn!(
242                    error = %err,
243                    "manifest_poll prime failed; daemon will keep retrying via the worker loop",
244                );
245            }
246            Err(_) => {
247                tracing::warn!(
248                    "manifest_poll prime timed out; daemon will keep retrying via the worker loop",
249                );
250            }
251        }
252    }
253
254    // File-fallback prime: when no upstream channel-refs source is
255    // configured (or its prime failed), verify the bundled --artifact /
256    // --signature pair and populate verified_fleet. Keeps the offline-boot
257    // + test fixture paths alive.
258    if state.verified_fleet.read().await.is_none()
259        && let Some((fleet, fleet_hash, artifact_bytes, signature_bytes)) =
260            prime_from_artifact_files(
261                &args.artifact_path,
262                &args.signature_path,
263                &args.trust_path,
264                args.freshness_window,
265            )
266    {
267        *state.verified_fleet.write().await = Some(VerifiedFleetSnapshot {
268            fleet: Arc::new(fleet),
269            fleet_resolved_hash: fleet_hash,
270            artifact_bytes,
271            signature_bytes,
272        });
273        state
274            .artifact_primed
275            .store(true, std::sync::atomic::Ordering::Release);
276        tracing::info!(
277            target: "cp_boot",
278            "primed verified-fleet from --artifact / --signature files",
279        );
280    }
281
282    if let (Some(revocations_source), Some(db)) = (args.revocations.clone(), state.db.clone()) {
283        bg_handles.push(crate::polling::revocations_poll::spawn(
284            cancel.clone(),
285            db,
286            revocations_source,
287            state.revocations_primed.clone(),
288        ));
289    }
290
291    if let Some(bootstrap_nonces_source) = args.bootstrap_nonces.clone() {
292        bg_handles.push(crate::polling::bootstrap_nonces_poll::spawn(
293            cancel.clone(),
294            state.allowed_nonces.clone(),
295            bootstrap_nonces_source,
296            state.bootstrap_nonces_primed.clone(),
297        ));
298    }
299
300    // RFC-0006 §7.2 runtime: MPSC reducer + applier + workers + event_log
301    // writer. Sole CP-side processing path; the v0.1 reconcile loop is
302    // gone.
303    {
304        let clock: nixfleet_proto::clock::ClockHandle =
305            std::sync::Arc::new(nixfleet_proto::clock::SystemClock::new());
306        let rt = crate::runtime::spawn(cancel.clone(), state.clone(), clock);
307        // Publish channels for the new /v1/agent/{events,heartbeat,dispatch}
308        // route handlers. `set` returns Err if already set; that's
309        // impossible during normal startup (called exactly once) but if a
310        // future test harness double-initialises, we just drop the second
311        // call.
312        let _ = state.runtime_input_tx.set(rt.input_tx.clone());
313        let _ = state.runtime_event_log_tx.set(rt.event_log_tx.clone());
314        for handle in rt.into_join_handles() {
315            bg_handles.push(handle);
316        }
317    }
318
319    // Process-global Prometheus recorder. Installs once; counter macros
320    // (record_compliance_event etc) silently no-op until then.
321    #[cfg(feature = "metrics")]
322    crate::metrics::install_recorder();
323
324    let app = build_router(state.clone());
325
326    let tls_config =
327        crate::tls::build_server_config(&args.tls_cert, &args.tls_key, args.client_ca.as_deref())?;
328    let rustls_config = axum_server::tls_rustls::RustlsConfig::from_config(Arc::new(tls_config));
329
330    let rustls_acceptor = axum_server::tls_rustls::RustlsAcceptor::new(rustls_config);
331    let mtls_acceptor = MtlsAcceptor::new(rustls_acceptor);
332
333    let mode = if args.client_ca.is_some() {
334        "TLS+mTLS"
335    } else {
336        tracing::warn!(
337            "control plane started without --client-ca: /v1/* endpoints will reject all clients with 401. \
338             Pass --client-ca to enable mTLS - recommended for any production deployment."
339        );
340        "TLS-only"
341    };
342    let ready = state.is_ready();
343    tracing::info!(
344        listen = %args.listen,
345        %mode,
346        ready,
347        "control plane listening{}",
348        if ready { "" } else { "; /v1/* will return 503 until first artifact verified" },
349    );
350
351    let server_handle = axum_server::Handle::new();
352    let signal_handle = server_handle.clone();
353    let signal_cancel = cancel.clone();
354    // LOADBEARING: graceful_shutdown FIRST (drains in-flight HTTP), THEN
355    // cancel the token (signals background tasks). Reversing causes timers
356    // to abort mid-write while requests are still completing.
357    tokio::spawn(async move {
358        if let Err(err) = tokio::signal::ctrl_c().await {
359            tracing::warn!(error = %err, "ctrl_c handler install failed; relying on hard shutdown");
360            return;
361        }
362        tracing::info!(target: "shutdown", "graceful shutdown initiated");
363        signal_handle.graceful_shutdown(Some(HTTP_DRAIN_DEADLINE));
364        signal_cancel.cancel();
365    });
366
367    axum_server::bind(args.listen)
368        .acceptor(mtls_acceptor)
369        .handle(server_handle)
370        .serve(app.into_make_service())
371        .await?;
372
373    cancel.cancel();
374    if let Err(err) = drain_background_tasks(bg_handles).await {
375        tracing::warn!(error = %err, "background task drain incomplete");
376    }
377    Ok(())
378}
379
380/// File-fallback verify+parse for the --artifact / --signature CLI args.
381/// Synchronous; called once at startup before the listener opens. Hash is
382/// computed against the received bytes (not a re-serialised parsed struct)
383/// so additive schema changes the CP's proto doesn't yet know about don't
384/// shift the rolloutId anchor — same load-bearing reason as in
385/// `manifest_poll::poll_once`.
386fn prime_from_artifact_files(
387    artifact_path: &std::path::Path,
388    signature_path: &std::path::Path,
389    trust_path: &std::path::Path,
390    freshness_window: Duration,
391) -> Option<(nixfleet_proto::FleetResolved, String, Vec<u8>, Vec<u8>)> {
392    let artifact = std::fs::read(artifact_path).ok()?;
393    let signature = std::fs::read(signature_path).ok()?;
394    let trust_raw = std::fs::read_to_string(trust_path).ok()?;
395    let trust: nixfleet_proto::TrustConfig = serde_json::from_str(&trust_raw).ok()?;
396    let now = chrono::Utc::now();
397    let trusted_keys = trust.ci_release_key.active_keys_at(now);
398    let reject_before = trust.ci_release_key.reject_before;
399    let verified = nixfleet_reconciler::verify_artifact(
400        &artifact,
401        &signature,
402        &trusted_keys,
403        now,
404        freshness_window,
405        reject_before,
406    )
407    .ok()?;
408    let fleet_hash = nixfleet_reconciler::canonical_hash_from_bytes(&artifact).ok()?;
409    Some((verified.into_inner(), fleet_hash, artifact, signature))
410}
411
412/// Tasks past `TASK_SHUTDOWN_DEADLINE` are abandoned (handles dropped -> abort).
413async fn drain_background_tasks(handles: Vec<tokio::task::JoinHandle<()>>) -> anyhow::Result<()> {
414    let total = handles.len();
415    let drain_fut = async move {
416        for handle in handles {
417            if let Err(err) = handle.await
418                && !err.is_cancelled()
419            {
420                tracing::warn!(error = %err, "background task panicked during shutdown");
421            }
422        }
423    };
424    match tokio::time::timeout(TASK_SHUTDOWN_DEADLINE, drain_fut).await {
425        Ok(()) => {
426            tracing::info!(target: "shutdown", tasks = total, "all background tasks shut down");
427            Ok(())
428        }
429        Err(_) => {
430            anyhow::bail!(
431                "background task drain exceeded {TASK_SHUTDOWN_DEADLINE:?}; forcing exit"
432            );
433        }
434    }
435}
436
#[cfg(test)]
mod strict_mode_tests {
    use super::*;
    use std::path::PathBuf;

    fn minimal_serve_args(strict: bool, client_ca: Option<PathBuf>) -> ServeArgs {
        ServeArgs {
            tls_cert: PathBuf::from("/dev/null"),
            tls_key: PathBuf::from("/dev/null"),
            client_ca,
            artifact_path: PathBuf::from("/dev/null"),
            signature_path: PathBuf::from("/dev/null"),
            trust_path: PathBuf::from("/dev/null"),
            observed_path: PathBuf::from("/dev/null"),
            strict,
            ..Default::default()
        }
    }

    /// Revocations source pointing at localhost; only its presence matters
    /// to the strict gate.
    fn test_revocations_source() -> crate::polling::revocations_poll::RevocationsSource {
        crate::polling::revocations_poll::RevocationsSource {
            artifact_url: "http://localhost/revocations.json".into(),
            signature_url: "http://localhost/revocations.json.sig".into(),
            token_file: None,
            trust_path: PathBuf::from("/dev/null"),
            freshness_window: std::time::Duration::from_secs(3600),
        }
    }

    /// Bootstrap-nonces source pointing at localhost; only its presence
    /// matters to the strict gate.
    fn test_bootstrap_nonces_source() -> crate::polling::bootstrap_nonces_poll::BootstrapNoncesSource
    {
        crate::polling::bootstrap_nonces_poll::BootstrapNoncesSource {
            artifact_url: "http://localhost/bootstrap-nonces.json".into(),
            signature_url: "http://localhost/bootstrap-nonces.json.sig".into(),
            token_file: None,
            trust_path: PathBuf::from("/dev/null"),
            freshness_window: std::time::Duration::from_secs(3600),
        }
    }

    #[tokio::test]
    async fn strict_bails_when_client_ca_unset() {
        let err = serve(minimal_serve_args(true, None)).await.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("--client-ca"),
            "expected client-ca hint in strict bail; got: {msg}",
        );
        assert!(
            msg.contains("--strict refuses to start"),
            "expected strict-prefixed message; got: {msg}",
        );
    }

    #[tokio::test]
    async fn strict_bails_when_revocations_unset() {
        let err = serve(minimal_serve_args(true, Some(PathBuf::from("/dev/null"))))
            .await
            .unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("--revocations"),
            "expected revocations hint in strict bail; got: {msg}",
        );
        assert!(
            msg.contains("--strict refuses to start"),
            "expected strict-prefixed message; got: {msg}",
        );
    }

    #[tokio::test]
    async fn strict_bails_when_bootstrap_nonces_unset() {
        let mut args = minimal_serve_args(true, Some(PathBuf::from("/dev/null")));
        // Provide revocations so that only the bootstrap-nonces check fires.
        args.revocations = Some(test_revocations_source());
        let err = serve(args).await.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("--bootstrap-nonces"),
            "expected bootstrap-nonces hint in strict bail; got: {msg}",
        );
        assert!(
            msg.contains("--strict refuses to start"),
            "expected strict-prefixed message; got: {msg}",
        );
    }

    #[tokio::test]
    async fn non_strict_does_not_bail_at_startup() {
        let err = serve(minimal_serve_args(false, None)).await.unwrap_err();
        let msg = format!("{err}");
        assert!(
            !msg.contains("--strict refuses to start"),
            "non-strict mode should not emit the strict-mode error; got: {msg}",
        );
    }

    /// Fixture: minimal args that satisfy every existing strict gate
    /// EXCEPT the CA backend gate. Lets each CA test isolate its assertion.
    fn args_satisfying_existing_strict_gates(strict: bool) -> ServeArgs {
        let mut args = minimal_serve_args(strict, Some(PathBuf::from("/dev/null")));
        args.revocations = Some(test_revocations_source());
        args.bootstrap_nonces = Some(test_bootstrap_nonces_source());
        args
    }

    /// RFC-0010 §1.5.1: `--strict` + file-only CA + no opt-in must refuse to start.
    #[tokio::test]
    async fn strict_bails_when_file_ca_only_without_opt_in() {
        let mut args = args_satisfying_existing_strict_gates(true);
        args.fleet_ca_key = Some(PathBuf::from("/dev/null"));
        let err = serve(args).await.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("--tpm-ca-pubkey-raw"),
            "expected TPM hint in strict bail; got: {msg}",
        );
        assert!(
            msg.contains("RFC-0010 §1.5.1"),
            "expected RFC pointer in strict bail; got: {msg}",
        );
        assert!(
            msg.contains("--strict refuses to start"),
            "expected strict-prefixed message; got: {msg}",
        );
    }

    /// RFC-0010 §1.5.1: explicit opt-in lets file-only CA pass `--strict`.
    /// Other startup paths may still fail; this test only asserts the CA
    /// gate does NOT contribute to the bail message.
    #[tokio::test]
    async fn strict_passes_ca_gate_when_file_ca_with_opt_in() {
        let mut args = args_satisfying_existing_strict_gates(true);
        args.fleet_ca_key = Some(PathBuf::from("/dev/null"));
        args.allow_file_ca_key = true;
        let err = serve(args)
            .await
            .err()
            .map(|e| format!("{e}"))
            .unwrap_or_default();
        assert!(
            !err.contains("--tpm-ca-pubkey-raw"),
            "opt-in should bypass the CA gate; got: {err}",
        );
    }

    /// RFC-0010 §1.5.1: TPM backend satisfies the CA gate without opt-in.
    #[tokio::test]
    async fn strict_passes_ca_gate_when_tpm_configured() {
        let mut args = args_satisfying_existing_strict_gates(true);
        args.tpm_ca_pubkey_raw = Some(PathBuf::from("/dev/null"));
        args.tpm_ca_sign_wrapper = Some(PathBuf::from("/dev/null"));
        // fleet_ca_key intentionally None: TPM is the only backend.
        let err = serve(args)
            .await
            .err()
            .map(|e| format!("{e}"))
            .unwrap_or_default();
        assert!(
            !err.contains("--tpm-ca-pubkey-raw"),
            "TPM backend should satisfy the gate; got: {err}",
        );
    }
}
596
#[cfg(test)]
mod shutdown_tests {
    use super::*;
    use std::time::Duration;

    /// Tasks that exit as soon as the token fires drain without error.
    #[tokio::test]
    async fn drain_returns_ok_when_tasks_exit_promptly() {
        let token = CancellationToken::new();
        let mut workers = Vec::with_capacity(3);
        for _ in 0..3 {
            let worker_token = token.clone();
            workers.push(tokio::spawn(async move {
                worker_token.cancelled().await;
            }));
        }

        token.cancel();
        drain_background_tasks(workers)
            .await
            .expect("tasks should drain cleanly");
    }

    /// A task that never observes cancellation makes the drain give up with
    /// the force-exit error once the deadline passes (virtual time).
    #[tokio::test(start_paused = true)]
    async fn drain_bails_when_task_ignores_cancel() {
        let token = CancellationToken::new();
        let oversleep = TASK_SHUTDOWN_DEADLINE + Duration::from_secs(60);
        let sleeper = tokio::spawn(async move {
            tokio::time::sleep(oversleep).await;
        });

        token.cancel();
        let drainer = tokio::spawn(async move { drain_background_tasks(vec![sleeper]).await });
        tokio::time::advance(TASK_SHUTDOWN_DEADLINE + Duration::from_secs(1)).await;
        let outcome = drainer.await.unwrap();
        let err = outcome.unwrap_err();
        assert!(
            err.to_string().contains("forcing exit"),
            "expected force-exit message; got: {err}",
        );
    }

    /// Worker-pattern sanity check: a `select!` loop on an hourly ticker
    /// still exits promptly once its token is cancelled.
    #[tokio::test(start_paused = true)]
    async fn cancel_token_unblocks_select_loop() {
        let token = CancellationToken::new();
        let loop_token = token.clone();
        let worker = tokio::spawn(async move {
            let mut hourly = tokio::time::interval(Duration::from_secs(3600));
            hourly.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                tokio::select! {
                    _ = loop_token.cancelled() => return,
                    _ = hourly.tick() => {}
                }
            }
        });

        tokio::task::yield_now().await;
        token.cancel();
        tokio::time::timeout(Duration::from_secs(5), worker)
            .await
            .expect("task should exit on cancel within 5s")
            .expect("task should not panic");
    }
}
660}