nixfleet_control_plane/auth/
auth_cn.rs1use axum::extract::{Path, Request};
9use axum::http::StatusCode;
10use axum::middleware::Next;
11use axum::response::Response;
12use axum_server::accept::Accept;
13use axum_server::tls_rustls::RustlsAcceptor;
14use rustls_pki_types::CertificateDer;
15use std::collections::HashMap;
16use std::future::Future;
17use std::io;
18use std::pin::Pin;
19use tokio::io::{AsyncRead, AsyncWrite};
20use tokio_rustls::server::TlsStream;
21use x509_parser::prelude::*;
22
23#[derive(Clone, Debug, Default)]
25pub struct PeerCertificates {
26 chain: Vec<CertificateDer<'static>>,
28}
29
30impl PeerCertificates {
31 pub fn new(chain: Vec<CertificateDer<'static>>) -> Self {
32 Self { chain }
33 }
34
35 pub fn empty() -> Self {
36 Self { chain: Vec::new() }
37 }
38
39 pub fn is_present(&self) -> bool {
40 !self.chain.is_empty()
41 }
42
43 pub fn leaf(&self) -> Option<&CertificateDer<'static>> {
44 self.chain.first()
45 }
46
47 pub fn leaf_cn(&self) -> Option<String> {
48 let leaf = self.leaf()?;
49 let (_, cert) = X509Certificate::from_der(leaf.as_ref()).ok()?;
50 cert.subject()
53 .iter_common_name()
54 .next()
55 .and_then(|attr| attr.as_str().ok().map(String::from))
56 }
57
58 pub fn leaf_not_before(&self) -> Option<chrono::DateTime<chrono::Utc>> {
61 let leaf = self.leaf()?;
62 let (_, cert) = X509Certificate::from_der(leaf.as_ref()).ok()?;
63 let secs = cert.validity().not_before.timestamp();
64 chrono::DateTime::<chrono::Utc>::from_timestamp(secs, 0)
65 }
66}
67
68#[derive(Clone, Debug)]
69pub struct MtlsAcceptor<A = axum_server::accept::DefaultAcceptor> {
70 inner: RustlsAcceptor<A>,
71}
72
73impl MtlsAcceptor {
74 pub fn new(inner: RustlsAcceptor) -> Self {
75 Self { inner }
76 }
77}
78
79impl<I, S, A> Accept<I, S> for MtlsAcceptor<A>
80where
81 A: Accept<I, S> + Clone + Send + 'static,
82 A::Stream: AsyncRead + AsyncWrite + Unpin + Send,
83 A::Service: Send,
84 A::Future: Send,
85 I: Send + 'static,
86 S: Send + 'static,
87{
88 type Stream = TlsStream<A::Stream>;
89 type Service = PeerCertService<A::Service>;
90 type Future = Pin<Box<dyn Future<Output = io::Result<(Self::Stream, Self::Service)>> + Send>>;
91
92 fn accept(&self, stream: I, service: S) -> Self::Future {
93 let inner = self.inner.clone();
94 Box::pin(async move {
95 let (tls_stream, inner_service) = inner.accept(stream, service).await?;
96
97 let (_, server_conn) = tls_stream.get_ref();
98 let peer_certs = match server_conn.peer_certificates() {
99 Some(certs) if !certs.is_empty() => {
100 let owned: Vec<CertificateDer<'static>> =
101 certs.iter().map(|c| c.clone().into_owned()).collect();
102 PeerCertificates::new(owned)
103 }
104 _ => PeerCertificates::empty(),
105 };
106
107 Ok((tls_stream, PeerCertService::new(inner_service, peer_certs)))
108 })
109 }
110}
111
112#[derive(Clone, Debug)]
113pub struct PeerCertService<S> {
114 inner: S,
115 peer_certs: PeerCertificates,
116}
117
118impl<S> PeerCertService<S> {
119 fn new(inner: S, peer_certs: PeerCertificates) -> Self {
120 Self { inner, peer_certs }
121 }
122}
123
124impl<S, B> tower_service::Service<http::Request<B>> for PeerCertService<S>
125where
126 S: tower_service::Service<http::Request<B>>,
127{
128 type Response = S::Response;
129 type Error = S::Error;
130 type Future = S::Future;
131
132 fn poll_ready(
133 &mut self,
134 cx: &mut std::task::Context<'_>,
135 ) -> std::task::Poll<Result<(), Self::Error>> {
136 self.inner.poll_ready(cx)
137 }
138
139 fn call(&mut self, mut req: http::Request<B>) -> Self::Future {
140 req.extensions_mut().insert(self.peer_certs.clone());
141 self.inner.call(req)
142 }
143}
144
145pub async fn cn_matches_path_machine_id(
147 Path(params): Path<HashMap<String, String>>,
148 request: Request,
149 next: Next,
150) -> Result<Response, StatusCode> {
151 let Some(id) = params.get("id") else {
152 return Ok(next.run(request).await);
153 };
154
155 let Some(certs) = request.extensions().get::<PeerCertificates>() else {
156 return Ok(next.run(request).await);
157 };
158
159 if !certs.is_present() {
160 return Ok(next.run(request).await);
161 }
162
163 let cn = certs.leaf_cn().ok_or_else(|| {
164 tracing::warn!(
165 path_id = %id,
166 "Rejecting agent request: peer certificate has no CN"
167 );
168 StatusCode::FORBIDDEN
169 })?;
170
171 if cn != id.as_str() {
172 tracing::warn!(
173 cert_cn = %cn,
174 path_id = %id,
175 "Rejecting agent request: cert CN does not match path machine_id"
176 );
177 return Err(StatusCode::FORBIDDEN);
178 }
179
180 Ok(next.run(request).await)
181}