nixfleet_control_plane/server/
middleware.rs

1//! Auth + protocol middleware for the v1 router.
2
3use std::sync::Arc;
4
5use axum::Json;
6use axum::body::Body;
7use axum::http::{Request as HttpRequest, StatusCode, header};
8use axum::middleware::Next;
9use axum::response::{IntoResponse, Response};
10use nixfleet_proto::agent_wire::{PROTOCOL_MAJOR_VERSION, PROTOCOL_VERSION_HEADER};
11use serde_json::json;
12
13use crate::auth::auth_cn::PeerCertificates;
14
15use super::state::AppState;
16
/// Number of seconds sent in the `Retry-After` header of 503 not-ready
/// responses.
///
/// Chosen to follow `channel_refs_poll::POLL_INTERVAL` (60 s) only
/// loosely: agents spread their retries across this hint, which gives
/// the next poll cycle time to finish before they all reconnect.
const NOT_READY_RETRY_AFTER_SECS: u32 = 30;
22
23/// 401 on missing/revoked cert; re-enrolled certs (notBefore > revoked_before) pass.
24///
25/// LOADBEARING: revocation DB rows store the **short** hostname (the
26/// operator-declared form from fleet.nix), while the cert's CN is the
27/// **canonical** `agent-<machineId>.<suffix>` form. Look up by the
28/// canonicalized-down short hostname so the two sides match.
29pub(super) async fn require_cn(
30    state: &AppState,
31    peer_certs: &PeerCertificates,
32) -> Result<String, StatusCode> {
33    if !peer_certs.is_present() {
34        return Err(StatusCode::UNAUTHORIZED);
35    }
36    let cn = peer_certs.leaf_cn().ok_or(StatusCode::UNAUTHORIZED)?;
37
38    if let Some(db) = &state.db {
39        let machine_id = crate::auth::issuance::extract_machine_id(&cn, &state.agent_cn_suffix);
40        match db.revocations().cert_revoked_before(&machine_id) {
41            Ok(Some(revoked_before)) => {
42                let cert_nbf = peer_certs
43                    .leaf_not_before()
44                    .ok_or(StatusCode::UNAUTHORIZED)?;
45                if cert_nbf < revoked_before {
46                    tracing::warn!(
47                        cn = %cn,
48                        machine_id = %machine_id,
49                        cert_not_before = %cert_nbf.to_rfc3339(),
50                        revoked_before = %revoked_before.to_rfc3339(),
51                        "rejecting revoked cert"
52                    );
53                    return Err(StatusCode::UNAUTHORIZED);
54                }
55            }
56            Ok(None) => {}
57            Err(err) => {
58                tracing::error!(error = %err, "db cert_revoked_before failed");
59                return Err(StatusCode::INTERNAL_SERVER_ERROR);
60            }
61        }
62    }
63
64    Ok(cn)
65}
66
/// Proof, carried in the type system, that CN authentication ran.
///
/// The wrapped string is a private field, so code outside this module
/// cannot construct the witness itself and handlers cannot forge one.
#[derive(Debug, Clone)]
pub(crate) struct AuthenticatedCn(String);

impl AuthenticatedCn {
    /// Borrows the authenticated CN as a string slice.
    pub(crate) fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Consumes the witness and hands back the owned CN.
    pub(crate) fn into_string(self) -> String {
        let Self(cn) = self;
        cn
    }
}
80
81pub(super) async fn require_cn_layer(
82    state: Arc<AppState>,
83    mut req: HttpRequest<Body>,
84    next: Next,
85) -> Result<axum::response::Response, StatusCode> {
86    let peer_certs = req
87        .extensions()
88        .get::<PeerCertificates>()
89        .cloned()
90        .unwrap_or_default();
91    let cn = require_cn(&state, &peer_certs).await?;
92    req.extensions_mut().insert(AuthenticatedCn(cn));
93    Ok(next.run(req).await)
94}
95
96/// 503 with `Retry-After: 30` until `AppState::is_ready()` returns true.
97/// Applied to every `/v1/*` route so agents see a deterministic "come
98/// back later" signal instead of partial behaviour driven by stale or
99/// missing trust state. Health/version/metrics are routed outside
100/// `/v1/*` and stay unguarded so operators can scrape them while the
101/// daemon is still priming.
102pub(super) async fn require_ready_layer(
103    state: Arc<AppState>,
104    req: HttpRequest<Body>,
105    next: Next,
106) -> Response {
107    if state.is_ready() {
108        return next.run(req).await;
109    }
110
111    let body = Json(json!({
112        "error": "control plane not ready",
113        "reason": "awaiting first signed artifact",
114    }));
115    let mut response = (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
116    if let Ok(value) = NOT_READY_RETRY_AFTER_SECS.to_string().parse() {
117        response.headers_mut().insert(header::RETRY_AFTER, value);
118    }
119    response
120}
121
122/// Forward-compat: missing header accepted; mismatched major -> 426. Strict mode rejects missing.
123pub(super) async fn protocol_version_middleware(
124    strict: bool,
125    req: HttpRequest<Body>,
126    next: Next,
127) -> Result<axum::response::Response, StatusCode> {
128    if let Some(value) = req.headers().get(PROTOCOL_VERSION_HEADER) {
129        match value.to_str().ok().and_then(|s| s.parse::<u32>().ok()) {
130            Some(v) if v == PROTOCOL_MAJOR_VERSION => Ok(next.run(req).await),
131            Some(v) => {
132                tracing::warn!(
133                    sent = v,
134                    expected = PROTOCOL_MAJOR_VERSION,
135                    "rejecting request with mismatched protocol major version"
136                );
137                Err(StatusCode::UPGRADE_REQUIRED)
138            }
139            None => {
140                tracing::warn!(
141                    raw = ?value,
142                    "X-Nixfleet-Protocol header malformed"
143                );
144                Err(StatusCode::BAD_REQUEST)
145            }
146        }
147    } else if strict {
148        tracing::warn!("rejecting request without X-Nixfleet-Protocol (strict mode)");
149        Err(StatusCode::BAD_REQUEST)
150    } else {
151        tracing::debug!("request without X-Nixfleet-Protocol - accepting (forward-compat)");
152        Ok(next.run(req).await)
153    }
154}