1use std::path::{Path, PathBuf};
8
9use anyhow::{Context, Result};
10use chrono::Utc;
11use nixfleet_proto::{RolloutId, RolloutManifest, TrustConfig};
12use nixfleet_reconciler::{
13 VerifiedFleet, VerifiedRolloutManifest, canonical_hash_from_bytes, verify_artifact,
14};
15
/// Default location of the agent's trust file; this is the path
/// `ManifestCache::new_default` hands to `ManifestCache::new`.
pub const DEFAULT_TRUST_PATH: &str = "/etc/nixfleet/agent/trust.json";
19
/// Failure categories reported while loading, fetching, or checking a signed
/// artifact. Every variant carries a human-readable reason string.
#[derive(Debug)]
pub enum ManifestError {
    /// The artifact could not be obtained: transport failure, a non-success
    /// HTTP status, or an unreadable response body.
    Missing(String),
    /// The trust roots could not be loaded, or the verifier rejected the
    /// artifact/signature pair.
    VerifyFailed(String),
    /// The data disagrees with what was requested or expected (rollout id,
    /// host membership, target closure, or canonical hash computation).
    Mismatch(String),
}

impl ManifestError {
    /// Returns the reason string carried by the variant, without the
    /// category prefix that `Display` adds.
    pub fn reason(&self) -> &str {
        match self {
            ManifestError::Missing(reason)
            | ManifestError::VerifyFailed(reason)
            | ManifestError::Mismatch(reason) => reason,
        }
    }
}

impl std::fmt::Display for ManifestError {
    /// Renders `<category>: <reason>` so the failure class stays visible when
    /// the error is logged or wrapped by a caller.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let category = match self {
            ManifestError::Missing(_) => "manifest missing",
            ManifestError::VerifyFailed(_) => "manifest verification failed",
            ManifestError::Mismatch(_) => "manifest mismatch",
        };
        write!(f, "{category}: {}", self.reason())
    }
}

// Lets callers propagate the error with `?` into `anyhow::Error` or
// `Box<dyn std::error::Error>`.
impl std::error::Error for ManifestError {}
36
/// Freshness window, in seconds, that `ManifestCache::new` converts to a
/// `Duration` and uses when no explicit window is supplied.
pub const DEFAULT_FRESHNESS_WINDOW_SECS: u64 = 3600;
40
/// On-disk cache of signed rollout manifests and the signed fleet artifact,
/// together with the settings used to verify them.
pub struct ManifestCache {
    /// Directory holding `<rollout_id>.json` / `<rollout_id>.json.sig` pairs.
    rollouts_dir: PathBuf,
    /// Directory holding `fleet.resolved.json` and `fleet.resolved.json.sig`.
    fleet_dir: PathBuf,
    /// Path of the JSON trust file, re-read on every verification.
    trust_path: PathBuf,
    /// Freshness window passed to the verifier.
    freshness_window: std::time::Duration,
}
47
48impl ManifestCache {
49 pub fn new(state_dir: &Path, trust_path: &Path) -> Self {
50 Self::new_with_freshness(
51 state_dir,
52 trust_path,
53 std::time::Duration::from_secs(DEFAULT_FRESHNESS_WINDOW_SECS),
54 )
55 }
56
57 pub fn new_with_freshness(
61 state_dir: &Path,
62 trust_path: &Path,
63 freshness_window: std::time::Duration,
64 ) -> Self {
65 Self {
66 rollouts_dir: state_dir.join("rollouts"),
67 fleet_dir: state_dir.join("fleet"),
68 trust_path: trust_path.to_path_buf(),
69 freshness_window,
70 }
71 }
72
73 pub fn new_default(state_dir: &Path) -> Self {
78 Self::new(state_dir, Path::new(DEFAULT_TRUST_PATH))
79 }
80
81 fn manifest_path(&self, rollout_id: &str) -> PathBuf {
82 self.rollouts_dir.join(format!("{rollout_id}.json"))
83 }
84
85 fn signature_path(&self, rollout_id: &str) -> PathBuf {
86 self.rollouts_dir.join(format!("{rollout_id}.json.sig"))
87 }
88
89 fn fleet_path(&self) -> PathBuf {
90 self.fleet_dir.join("fleet.resolved.json")
91 }
92
93 fn fleet_sig_path(&self) -> PathBuf {
94 self.fleet_dir.join("fleet.resolved.json.sig")
95 }
96
97 pub fn read_cached_bytes(&self, rollout_id: &str) -> Option<(Vec<u8>, Vec<u8>)> {
99 let manifest = std::fs::read(self.manifest_path(rollout_id)).ok()?;
100 let sig = std::fs::read(self.signature_path(rollout_id)).ok()?;
101 Some((manifest, sig))
102 }
103
104 fn load_trust_roots(
105 &self,
106 now: chrono::DateTime<Utc>,
107 ) -> Result<(
108 Vec<nixfleet_proto::TrustedPubkey>,
109 Option<chrono::DateTime<Utc>>,
110 )> {
111 let raw = std::fs::read_to_string(&self.trust_path)
112 .with_context(|| format!("read trust file {}", self.trust_path.display()))?;
113 let trust: TrustConfig = serde_json::from_str(&raw).context("parse trust file")?;
114 Ok((
115 trust.ci_release_key.active_keys_at(now),
116 trust.ci_release_key.reject_before,
117 ))
118 }
119
120 fn validate_rollout_id_for_path(rollout_id: &str) -> Result<(), ManifestError> {
126 if rollout_id.contains('/') || rollout_id.contains("..") {
127 return Err(ManifestError::Mismatch(format!(
128 "rollout_id {rollout_id:?} contains path-traversal characters"
129 )));
130 }
131 Ok(())
132 }
133
134 fn verify_bytes(
135 &self,
136 manifest_bytes: &[u8],
137 signature_bytes: &[u8],
138 advertised_rollout_id: &str,
139 ) -> Result<VerifiedRolloutManifest, ManifestError> {
140 let now = Utc::now();
145 let (trusted_keys, reject_before) = self
146 .load_trust_roots(now)
147 .map_err(|err| ManifestError::VerifyFailed(format!("load trust roots: {err:#}")))?;
148 let window = self.freshness_window;
149 let verified = nixfleet_reconciler::verify_rollout_manifest(
150 manifest_bytes,
151 signature_bytes,
152 &trusted_keys,
153 now,
154 window,
155 reject_before,
156 )
157 .map_err(|err| ManifestError::VerifyFailed(format!("{err:?}")))?;
158
159 Self::assert_rollout_id_matches(verified.inner(), advertised_rollout_id)?;
160 Ok(verified)
161 }
162
163 fn assert_rollout_id_matches(
172 manifest: &RolloutManifest,
173 advertised_rollout_id: &str,
174 ) -> Result<(), ManifestError> {
175 let parsed = RolloutId::new(&manifest.channel, &manifest.channel_ref);
176 if parsed.as_str() != advertised_rollout_id {
177 return Err(ManifestError::Mismatch(format!(
178 "advertised rolloutId {advertised} != parsed RolloutId {parsed}",
179 advertised = advertised_rollout_id,
180 parsed = parsed.as_str(),
181 )));
182 }
183 Ok(())
184 }
185
186 fn assert_membership(
187 manifest: &RolloutManifest,
188 hostname: &str,
189 wave_index: u32,
190 ) -> Result<(), ManifestError> {
191 let in_set = manifest
192 .host_set
193 .iter()
194 .any(|h| h.hostname == hostname && h.wave_index == wave_index);
195 if !in_set {
196 return Err(ManifestError::Mismatch(format!(
197 "(hostname={hostname}, wave_index={wave_index}) not in manifest.host_set"
198 )));
199 }
200 Ok(())
201 }
202
203 fn assert_target_closure(
209 manifest: &RolloutManifest,
210 hostname: &str,
211 expected_target_closure: &str,
212 ) -> Result<(), ManifestError> {
213 let host = manifest
214 .host_set
215 .iter()
216 .find(|h| h.hostname == hostname)
217 .ok_or_else(|| {
218 ManifestError::Mismatch(format!("hostname {hostname:?} not in manifest.host_set"))
219 })?;
220 if host.target_closure != expected_target_closure {
221 return Err(ManifestError::Mismatch(format!(
222 "dispatch target_closure {dispatched:?} != manifest target_closure {manifest_value:?} for hostname {hostname:?}",
223 dispatched = expected_target_closure,
224 manifest_value = host.target_closure,
225 )));
226 }
227 Ok(())
228 }
229
230 fn write_cache(&self, rollout_id: &str, manifest_bytes: &[u8], sig_bytes: &[u8]) -> Result<()> {
231 std::fs::create_dir_all(&self.rollouts_dir).with_context(|| {
232 format!("create rollouts cache dir {}", self.rollouts_dir.display())
233 })?;
234 std::fs::write(self.manifest_path(rollout_id), manifest_bytes)
235 .with_context(|| format!("write {}", self.manifest_path(rollout_id).display()))?;
236 std::fs::write(self.signature_path(rollout_id), sig_bytes)
237 .with_context(|| format!("write {}", self.signature_path(rollout_id).display()))?;
238 Ok(())
239 }
240
241 pub async fn fetch_or_load(
251 &self,
252 client: &reqwest::Client,
253 cp_url: &str,
254 rollout_id: &str,
255 ) -> Result<VerifiedRolloutManifest, ManifestError> {
256 Self::validate_rollout_id_for_path(rollout_id)?;
257
258 if let Some((manifest_bytes, sig_bytes)) = self.read_cached_bytes(rollout_id) {
259 match self.verify_bytes(&manifest_bytes, &sig_bytes, rollout_id) {
260 Ok(verified) => return Ok(verified),
261 Err(err) => {
262 tracing::info!(
263 target: "agent_manifest_cache",
264 rollout_id = %rollout_id,
265 error = %err.reason(),
266 "cached rollout manifest failed verification; falling through to fetch",
267 );
268 }
270 }
271 }
272
273 let base = cp_url.trim_end_matches('/');
274 let manifest_url = format!("{base}/v1/rollouts/{rollout_id}");
275 let sig_url = format!("{base}/v1/rollouts/{rollout_id}/sig");
276
277 let manifest_bytes = fetch(client, &manifest_url).await?;
278 let sig_bytes = fetch(client, &sig_url).await?;
279
280 let verified = self.verify_bytes(&manifest_bytes, &sig_bytes, rollout_id)?;
281
282 if let Err(err) = self.write_cache(rollout_id, &manifest_bytes, &sig_bytes) {
283 tracing::warn!(
284 rollout_id = %rollout_id,
285 error = %err,
286 "manifest cache: write-through failed (will refetch next checkin)",
287 );
288 }
289
290 Ok(verified)
291 }
292
293 pub async fn ensure(
298 &self,
299 client: &reqwest::Client,
300 cp_url: &str,
301 rollout_id: &str,
302 hostname: &str,
303 wave_index: u32,
304 ) -> Result<VerifiedRolloutManifest, ManifestError> {
305 let verified = self.fetch_or_load(client, cp_url, rollout_id).await?;
306 Self::assert_membership(verified.inner(), hostname, wave_index)?;
307 Ok(verified)
308 }
309
310 pub async fn ensure_for_dispatch(
315 &self,
316 client: &reqwest::Client,
317 cp_url: &str,
318 rollout_id: &str,
319 hostname: &str,
320 expected_target_closure: &str,
321 ) -> Result<VerifiedRolloutManifest, ManifestError> {
322 let verified = self.fetch_or_load(client, cp_url, rollout_id).await?;
323 Self::assert_target_closure(verified.inner(), hostname, expected_target_closure)?;
324 Ok(verified)
325 }
326
327 fn read_cached_fleet_bytes(&self) -> Option<(Vec<u8>, Vec<u8>)> {
328 let artifact = std::fs::read(self.fleet_path()).ok()?;
329 let sig = std::fs::read(self.fleet_sig_path()).ok()?;
330 Some((artifact, sig))
331 }
332
333 fn write_fleet_cache(&self, artifact_bytes: &[u8], sig_bytes: &[u8]) -> Result<()> {
334 std::fs::create_dir_all(&self.fleet_dir)
335 .with_context(|| format!("create fleet cache dir {}", self.fleet_dir.display()))?;
336 std::fs::write(self.fleet_path(), artifact_bytes)
337 .with_context(|| format!("write {}", self.fleet_path().display()))?;
338 std::fs::write(self.fleet_sig_path(), sig_bytes)
339 .with_context(|| format!("write {}", self.fleet_sig_path().display()))?;
340 Ok(())
341 }
342
343 fn verify_fleet_bytes(
344 &self,
345 artifact_bytes: &[u8],
346 signature_bytes: &[u8],
347 ) -> Result<VerifiedFleet, ManifestError> {
348 let now = Utc::now();
349 let (trusted_keys, reject_before) = self
350 .load_trust_roots(now)
351 .map_err(|err| ManifestError::VerifyFailed(format!("load trust roots: {err:#}")))?;
352 let window = self.freshness_window;
353 verify_artifact(
354 artifact_bytes,
355 signature_bytes,
356 &trusted_keys,
357 now,
358 window,
359 reject_before,
360 )
361 .map_err(|err| ManifestError::VerifyFailed(format!("{err:?}")))
362 }
363
364 pub async fn fetch_or_load_fleet(
382 &self,
383 client: &reqwest::Client,
384 cp_url: &str,
385 ) -> Result<(VerifiedFleet, String), ManifestError> {
386 if let Some((artifact_bytes, sig_bytes)) = self.read_cached_fleet_bytes() {
387 match self.verify_fleet_bytes(&artifact_bytes, &sig_bytes) {
388 Ok(verified) => {
389 let hash = canonical_hash_from_bytes(&artifact_bytes).map_err(|err| {
390 ManifestError::Mismatch(format!("hash cached fleet: {err:?}"))
391 })?;
392 return Ok((verified, hash));
393 }
394 Err(err) => {
395 tracing::info!(
396 target: "agent_manifest_cache",
397 error = %err.reason(),
398 "cached fleet manifest failed verification; falling through to fetch",
399 );
400 }
402 }
403 }
404
405 let base = cp_url.trim_end_matches('/');
406 let artifact_url = format!("{base}/v1/fleet.resolved");
407 let sig_url = format!("{base}/v1/fleet.resolved/sig");
408
409 let artifact_bytes = fetch(client, &artifact_url).await?;
410 let sig_bytes = fetch(client, &sig_url).await?;
411
412 let verified = self.verify_fleet_bytes(&artifact_bytes, &sig_bytes)?;
413 let hash = canonical_hash_from_bytes(&artifact_bytes)
414 .map_err(|err| ManifestError::Mismatch(format!("hash fetched fleet: {err:?}")))?;
415
416 if let Err(err) = self.write_fleet_cache(&artifact_bytes, &sig_bytes) {
417 tracing::warn!(
418 error = %err,
419 "fleet cache: write-through failed (will refetch next tick)",
420 );
421 }
422
423 Ok((verified, hash))
424 }
425}
426
427async fn fetch(client: &reqwest::Client, url: &str) -> Result<Vec<u8>, ManifestError> {
428 let resp = client
429 .get(url)
430 .send()
431 .await
432 .map_err(|err| ManifestError::Missing(format!("GET {url}: {err}")))?;
433 let status = resp.status();
434 if status == reqwest::StatusCode::NOT_FOUND {
435 return Err(ManifestError::Missing(format!("404 from {url}")));
436 }
437 if !status.is_success() {
438 return Err(ManifestError::Missing(format!("{url}: {status}")));
439 }
440 let bytes = resp
441 .bytes()
442 .await
443 .map_err(|err| ManifestError::Missing(format!("read body {url}: {err}")))?;
444 Ok(bytes.to_vec())
445}
446
#[cfg(test)]
mod tests {
    use super::*;
    use nixfleet_proto::fleet_resolved::{HealthGate, Meta};
    use nixfleet_proto::rollout_manifest::HostWave;

    /// The three error variants must stay distinguishable in `Debug` output
    /// even when they carry the same reason string.
    #[test]
    fn manifest_error_variants_distinct_on_debug() {
        let rendered = [
            format!("{:?}", ManifestError::Missing("x".into())),
            format!("{:?}", ManifestError::VerifyFailed("x".into())),
            format!("{:?}", ManifestError::Mismatch("x".into())),
        ];
        let distinct: std::collections::HashSet<&String> = rendered.iter().collect();
        assert_eq!(distinct.len(), rendered.len());
    }

    /// Builds a manifest on channel `stable` at ref `abc1234deadbeef` with the
    /// given host set.
    fn manifest_with(host_set: Vec<HostWave>) -> RolloutManifest {
        let meta = Meta {
            schema_version: 1,
            signed_at: None,
            ci_commit: None,
            signature_algorithm: None,
        };
        RolloutManifest {
            schema_version: 1,
            display_name: "stable@abc1234".into(),
            channel: "stable".into(),
            channel_ref: "abc1234deadbeef".into(),
            fleet_resolved_hash: "1111111111111111111111111111111111111111111111111111111111111111"
                .into(),
            host_set,
            health_gate: HealthGate::default(),
            disruption_budgets: Vec::new(),
            meta,
        }
    }

    /// Builds a single host-set entry.
    fn host_wave(hostname: &str, wave_index: u32, target_closure: &str) -> HostWave {
        HostWave {
            hostname: hostname.into(),
            wave_index,
            target_closure: target_closure.into(),
        }
    }

    #[test]
    fn assert_target_closure_passes_on_match() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        ManifestCache::assert_target_closure(&manifest, "h1", "closure-A").expect("match");
    }

    /// The mismatch message must name both the manifest's and the dispatched
    /// closure.
    #[test]
    fn assert_target_closure_fails_on_target_mismatch() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        let failure = ManifestCache::assert_target_closure(&manifest, "h1", "closure-B")
            .expect_err("target mismatch");
        let msg = failure.reason();
        assert!(
            msg.contains("closure-A"),
            "expected manifest target in error: {msg}"
        );
        assert!(
            msg.contains("closure-B"),
            "expected dispatched target in error: {msg}"
        );
    }

    #[test]
    fn assert_target_closure_fails_when_hostname_not_in_set() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        let failure = ManifestCache::assert_target_closure(&manifest, "h2", "closure-A")
            .expect_err("hostname not in set");
        let msg = failure.reason();
        assert!(msg.contains("h2"), "expected hostname in error: {msg}");
    }

    #[test]
    fn validate_rollout_id_for_path_refuses_traversal() {
        let accepted = ManifestCache::validate_rollout_id_for_path("stable@abc1234");
        assert!(accepted.is_ok());

        let with_slash = ManifestCache::validate_rollout_id_for_path("stable@abc/123");
        assert!(with_slash.is_err());

        let climbing = ManifestCache::validate_rollout_id_for_path("../../../etc/passwd");
        assert!(climbing.is_err());

        let embedded_dots = ManifestCache::validate_rollout_id_for_path("a..b");
        assert!(embedded_dots.is_err());
    }

    #[test]
    fn assert_rollout_id_matches_accepts_canonical_format() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        ManifestCache::assert_rollout_id_matches(&manifest, "stable@abc1234deadbeef")
            .expect("canonical id matches");
    }

    #[test]
    fn assert_rollout_id_matches_rejects_channel_only() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        let failure = ManifestCache::assert_rollout_id_matches(&manifest, "stable")
            .expect_err("channel-only rejected");
        assert!(matches!(failure, ManifestError::Mismatch(_)));
    }

    #[test]
    fn assert_rollout_id_matches_rejects_channel_ref_only() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        let failure = ManifestCache::assert_rollout_id_matches(&manifest, "abc1234deadbeef")
            .expect_err("channel_ref-only rejected");
        assert!(matches!(failure, ManifestError::Mismatch(_)));
    }

    /// The fleet artifact and its signature live under `<state_dir>/fleet`.
    #[test]
    fn fleet_path_and_sig_path_under_fleet_subdir() {
        let state_dir = std::path::PathBuf::from("/tmp/nixfleet-agent-test");
        let cache = ManifestCache::new(&state_dir, std::path::Path::new("/dev/null"));

        let fleet_dir = state_dir.join("fleet");
        assert_eq!(cache.fleet_path(), fleet_dir.join("fleet.resolved.json"));
        assert_eq!(
            cache.fleet_sig_path(),
            fleet_dir.join("fleet.resolved.json.sig")
        );
    }

    /// A 64-character hex string is not a valid rollout id for this manifest.
    #[test]
    fn assert_rollout_id_matches_rejects_sha256_hash_format() {
        let manifest = manifest_with(vec![host_wave("h1", 0, "closure-A")]);
        let hash_like_id = "a".repeat(64);
        let failure = ManifestCache::assert_rollout_id_matches(&manifest, &hash_like_id)
            .expect_err("sha256 hex rejected");
        assert!(matches!(failure, ManifestError::Mismatch(_)));
    }

    /// Trust file JSON with a single `current` ed25519 key whose public part
    /// is the base64 string below.
    fn minimal_trust_json() -> String {
        let zero_pub_b64 = "AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=";
        format!(
            r#"{{
 "schemaVersion": 1,
 "ciReleaseKey": {{
 "current": {{
 "algorithm": "ed25519",
 "public": "{zero_pub_b64}"
 }}
 }}
 }}"#
        )
    }

    /// Single-threaded tokio runtime for driving the async methods from
    /// synchronous tests.
    fn rt() -> tokio::runtime::Runtime {
        let mut builder = tokio::runtime::Builder::new_current_thread();
        builder.enable_all();
        builder.build().expect("tokio runtime")
    }

    /// With an unverifiable cached fleet artifact and a control plane URL the
    /// test treats as unreachable, the error must be `Missing` (the fetch
    /// path), not `VerifyFailed` (the cache path).
    #[test]
    fn fleet_cache_verify_failure_falls_through_to_fetch() {
        let dir = tempfile::tempdir().expect("tempdir");
        let trust_path = dir.path().join("trust.json");
        std::fs::write(&trust_path, minimal_trust_json()).expect("write trust");

        let fleet_dir = dir.path().join("fleet");
        std::fs::create_dir_all(&fleet_dir).expect("mkdir fleet");
        std::fs::write(
            fleet_dir.join("fleet.resolved.json"),
            br#"{"schemaVersion":1,"signedAt":"2020-01-01T00:00:00Z"}"#,
        )
        .expect("write cached artifact");
        std::fs::write(fleet_dir.join("fleet.resolved.json.sig"), b"sig")
            .expect("write cached sig");

        let cache = ManifestCache::new(dir.path(), &trust_path);
        let unreachable_cp = "http://127.0.0.1:1";
        let client = reqwest::Client::new();

        let runtime = rt();
        let err = runtime
            .block_on(cache.fetch_or_load_fleet(&client, unreachable_cp))
            .expect_err("fetch_or_load_fleet must error when both cache and CP fail");
        assert!(
            matches!(err, ManifestError::Missing(_)),
            "post-fix MUST fall through to fetch when cache verify fails; \
             error variant indicates which path returned. Got: {err:?}",
        );
    }

    /// Same fall-through expectation for a cached rollout manifest.
    #[test]
    fn rollout_manifest_cache_verify_failure_falls_through_to_fetch() {
        let dir = tempfile::tempdir().expect("tempdir");
        let trust_path = dir.path().join("trust.json");
        std::fs::write(&trust_path, minimal_trust_json()).expect("write trust");

        let rollout_id = "stable@abc1234deadbeef";
        let rollouts_dir = dir.path().join("rollouts");
        std::fs::create_dir_all(&rollouts_dir).expect("mkdir rollouts");
        std::fs::write(
            rollouts_dir.join(format!("{rollout_id}.json")),
            br#"{"schemaVersion":1,"channel":"stable","channelRef":"abc1234deadbeef"}"#,
        )
        .expect("write cached manifest");
        std::fs::write(rollouts_dir.join(format!("{rollout_id}.json.sig")), b"sig")
            .expect("write cached sig");

        let cache = ManifestCache::new(dir.path(), &trust_path);
        let unreachable_cp = "http://127.0.0.1:1";
        let client = reqwest::Client::new();

        let runtime = rt();
        let err = runtime
            .block_on(cache.fetch_or_load(&client, unreachable_cp, rollout_id))
            .expect_err("fetch_or_load must error when both cache and CP fail");
        assert!(
            matches!(err, ManifestError::Missing(_)),
            "post-fix MUST fall through to fetch when cache verify fails; \
             error variant indicates which path returned. Got: {err:?}",
        );
    }
}