1use std::collections::HashMap;
13use std::sync::Mutex;
14
15use anyhow::{Context, Result};
16use chrono::{DateTime, Utc};
17use nixfleet_proto::OnHealthFailure;
18use nixfleet_state_machine::{HostRolloutState, HostState, ProbeRecord};
19use rusqlite::{Connection, OptionalExtension, params};
20
21pub struct HostRolloutRecords<'a> {
22 pub(super) conn: &'a Mutex<Connection>,
23}
24
25impl<'a> HostRolloutRecords<'a> {
26 pub fn upsert(&self, state: &HostRolloutState) -> Result<()> {
29 let conn = super::lock_conn(self.conn)?;
30 upsert_inner(&conn, state)
31 }
32
33 pub fn load(&self, rollout_id: &str, hostname: &str) -> Result<Option<HostRolloutState>> {
36 let conn = super::lock_conn(self.conn)?;
37 load_inner(&conn, rollout_id, hostname)
38 }
39
40 pub fn all_for_rollout(&self, rollout_id: &str) -> Result<Vec<HostRolloutState>> {
43 let conn = super::lock_conn(self.conn)?;
44 let mut stmt = conn.prepare(
45 "SELECT rollout_id, hostname, channel, state, target_closure,
46 current_closure_at_dispatch, current_closure, reverted_to,
47 dispatched_at, dispatch_acked_at, activation_started_at,
48 activation_completed_at, activation_failed_at,
49 probe_observed_first_at, probe_failure_first_at,
50 soak_due_at, converged_at, failed_at, policy_applied,
51 reverted_at, probes_json, last_event_seq
52 FROM host_rollout_records
53 WHERE rollout_id = ?1",
54 )?;
55 let rows = stmt.query_map(params![rollout_id], row_to_state)?;
56 let mut out = Vec::new();
57 for r in rows {
58 out.push(r?);
59 }
60 Ok(out)
61 }
62
63 pub fn active_for_host(&self, hostname: &str) -> Result<Vec<HostRolloutState>> {
81 let conn = super::lock_conn(self.conn)?;
82 let mut stmt = conn.prepare(
83 "SELECT rollout_id, hostname, channel, state, target_closure,
84 current_closure_at_dispatch, current_closure, reverted_to,
85 dispatched_at, dispatch_acked_at, activation_started_at,
86 activation_completed_at, activation_failed_at,
87 probe_observed_first_at, probe_failure_first_at,
88 soak_due_at, converged_at, failed_at, policy_applied,
89 reverted_at, probes_json, last_event_seq
90 FROM host_rollout_records
91 WHERE hostname = ?1
92 AND state IN ('Pending', 'Activating', 'Deferred', 'Soaking', 'Failed')",
93 )?;
94 let rows = stmt.query_map(params![hostname], row_to_state)?;
95 let mut out = Vec::new();
96 for r in rows {
97 out.push(r?);
98 }
99 Ok(out)
100 }
101}
102
103fn upsert_inner(conn: &Connection, s: &HostRolloutState) -> Result<()> {
104 let probes_json =
105 serde_json::to_string(&s.probes).context("serialize probes_json for upsert")?;
106 let policy_applied_db = s.policy_applied.map(policy_to_db);
107
108 conn.execute(
109 "INSERT INTO host_rollout_records (
110 rollout_id, hostname, channel, state,
111 target_closure, current_closure_at_dispatch, current_closure, reverted_to,
112 dispatched_at, dispatch_acked_at, activation_started_at,
113 activation_completed_at, activation_failed_at,
114 probe_observed_first_at, probe_failure_first_at,
115 soak_due_at, converged_at, failed_at, policy_applied,
116 reverted_at, probes_json, last_event_seq
117 ) VALUES (
118 ?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15,
119 ?16, ?17, ?18, ?19, ?20, ?21, ?22
120 )
121 ON CONFLICT (rollout_id, hostname) DO UPDATE SET
122 channel = excluded.channel,
123 state = excluded.state,
124 target_closure = excluded.target_closure,
125 current_closure_at_dispatch = excluded.current_closure_at_dispatch,
126 current_closure = excluded.current_closure,
127 reverted_to = excluded.reverted_to,
128 dispatched_at = excluded.dispatched_at,
129 dispatch_acked_at = excluded.dispatch_acked_at,
130 activation_started_at = excluded.activation_started_at,
131 activation_completed_at = excluded.activation_completed_at,
132 activation_failed_at = excluded.activation_failed_at,
133 probe_observed_first_at = excluded.probe_observed_first_at,
134 probe_failure_first_at = excluded.probe_failure_first_at,
135 soak_due_at = excluded.soak_due_at,
136 converged_at = excluded.converged_at,
137 failed_at = excluded.failed_at,
138 policy_applied = excluded.policy_applied,
139 reverted_at = excluded.reverted_at,
140 probes_json = excluded.probes_json,
141 last_event_seq = excluded.last_event_seq",
142 params![
143 s.rollout_id,
144 s.hostname,
145 s.channel,
146 state_to_db(s.state),
147 s.target_closure,
148 s.current_closure_at_dispatch,
149 s.current_closure,
150 s.reverted_to,
151 s.dispatched_at.to_rfc3339(),
152 s.dispatch_acked_at.map(|t: DateTime<Utc>| t.to_rfc3339()),
153 s.activation_started_at
154 .map(|t: DateTime<Utc>| t.to_rfc3339()),
155 s.activation_completed_at
156 .map(|t: DateTime<Utc>| t.to_rfc3339()),
157 s.activation_failed_at
158 .map(|t: DateTime<Utc>| t.to_rfc3339()),
159 s.probe_observed_first_at
160 .map(|t: DateTime<Utc>| t.to_rfc3339()),
161 s.probe_failure_first_at
162 .map(|t: DateTime<Utc>| t.to_rfc3339()),
163 s.soak_due_at.map(|t: DateTime<Utc>| t.to_rfc3339()),
164 s.converged_at.map(|t: DateTime<Utc>| t.to_rfc3339()),
165 s.failed_at.map(|t: DateTime<Utc>| t.to_rfc3339()),
166 policy_applied_db,
167 s.reverted_at.map(|t: DateTime<Utc>| t.to_rfc3339()),
168 probes_json,
169 s.last_event_seq as i64,
170 ],
171 )
172 .context("upsert host_rollout_records")?;
173 Ok(())
174}
175
176fn load_inner(
177 conn: &Connection,
178 rollout_id: &str,
179 hostname: &str,
180) -> Result<Option<HostRolloutState>> {
181 conn.query_row(
182 "SELECT rollout_id, hostname, channel, state, target_closure,
183 current_closure_at_dispatch, current_closure, reverted_to,
184 dispatched_at, dispatch_acked_at, activation_started_at,
185 activation_completed_at, activation_failed_at,
186 probe_observed_first_at, probe_failure_first_at,
187 soak_due_at, converged_at, failed_at, policy_applied,
188 reverted_at, probes_json, last_event_seq
189 FROM host_rollout_records
190 WHERE rollout_id = ?1 AND hostname = ?2",
191 params![rollout_id, hostname],
192 row_to_state,
193 )
194 .optional()
195 .context("load host_rollout_records")
196}
197
198fn row_to_state(row: &rusqlite::Row<'_>) -> rusqlite::Result<HostRolloutState> {
199 let probes_json: String = row.get(20)?;
200 let probes: HashMap<String, ProbeRecord> = serde_json::from_str(&probes_json).map_err(|e| {
201 rusqlite::Error::FromSqlConversionFailure(20, rusqlite::types::Type::Text, Box::new(e))
202 })?;
203
204 Ok(HostRolloutState {
205 rollout_id: row.get(0)?,
206 hostname: row.get(1)?,
207 channel: row.get(2)?,
208 state: state_from_db(&row.get::<_, String>(3)?).map_err(|e| {
209 rusqlite::Error::FromSqlConversionFailure(
210 3,
211 rusqlite::types::Type::Text,
212 format!("unknown state: {e}").into(),
213 )
214 })?,
215 target_closure: row.get(4)?,
216 current_closure_at_dispatch: row.get(5)?,
217 current_closure: row.get(6)?,
218 reverted_to: row.get(7)?,
219 dispatched_at: parse_rfc3339_required(row, 8, "dispatched_at")?,
220 dispatch_acked_at: parse_rfc3339_optional(row, 9)?,
221 activation_started_at: parse_rfc3339_optional(row, 10)?,
222 activation_completed_at: parse_rfc3339_optional(row, 11)?,
223 activation_failed_at: parse_rfc3339_optional(row, 12)?,
224 probe_observed_first_at: parse_rfc3339_optional(row, 13)?,
225 probe_failure_first_at: parse_rfc3339_optional(row, 14)?,
226 soak_due_at: parse_rfc3339_optional(row, 15)?,
227 converged_at: parse_rfc3339_optional(row, 16)?,
228 failed_at: parse_rfc3339_optional(row, 17)?,
229 policy_applied: row
230 .get::<_, Option<String>>(18)?
231 .map(|s| {
232 policy_from_db(&s).map_err(|e| {
233 rusqlite::Error::FromSqlConversionFailure(
234 18,
235 rusqlite::types::Type::Text,
236 format!("unknown policy_applied: {e}").into(),
237 )
238 })
239 })
240 .transpose()?,
241 reverted_at: parse_rfc3339_optional(row, 19)?,
242 probes,
243 last_event_seq: row.get::<_, i64>(21)? as u64,
244 })
245}
246
247fn parse_rfc3339_required(
248 row: &rusqlite::Row<'_>,
249 idx: usize,
250 field: &'static str,
251) -> rusqlite::Result<DateTime<Utc>> {
252 let s: String = row.get(idx)?;
253 DateTime::parse_from_rfc3339(&s)
254 .map(|dt| dt.with_timezone(&Utc))
255 .map_err(|e| {
256 rusqlite::Error::FromSqlConversionFailure(
257 idx,
258 rusqlite::types::Type::Text,
259 format!("parse {field}: {e}").into(),
260 )
261 })
262}
263
264fn parse_rfc3339_optional(
265 row: &rusqlite::Row<'_>,
266 idx: usize,
267) -> rusqlite::Result<Option<DateTime<Utc>>> {
268 row.get::<_, Option<String>>(idx)?
269 .map(|s| {
270 DateTime::parse_from_rfc3339(&s)
271 .map(|dt| dt.with_timezone(&Utc))
272 .map_err(|e| {
273 rusqlite::Error::FromSqlConversionFailure(
274 idx,
275 rusqlite::types::Type::Text,
276 format!("parse rfc3339: {e}").into(),
277 )
278 })
279 })
280 .transpose()
281}
282
283fn state_to_db(s: HostState) -> &'static str {
284 match s {
285 HostState::Pending => "Pending",
286 HostState::Activating => "Activating",
287 HostState::Deferred => "Deferred",
288 HostState::Soaking => "Soaking",
289 HostState::Converged => "Converged",
290 HostState::Failed => "Failed",
291 HostState::Reverted => "Reverted",
292 }
293}
294
295fn state_from_db(s: &str) -> Result<HostState, String> {
296 match s {
297 "Pending" => Ok(HostState::Pending),
298 "Activating" => Ok(HostState::Activating),
299 "Deferred" => Ok(HostState::Deferred),
300 "Soaking" => Ok(HostState::Soaking),
301 "Converged" => Ok(HostState::Converged),
302 "Failed" => Ok(HostState::Failed),
303 "Reverted" => Ok(HostState::Reverted),
304 other => Err(other.to_string()),
305 }
306}
307
308fn policy_to_db(p: OnHealthFailure) -> &'static str {
309 match p {
310 OnHealthFailure::Halt => "halt",
311 OnHealthFailure::RollbackAndHalt => "rollback-and-halt",
312 }
313}
314
315fn policy_from_db(s: &str) -> Result<OnHealthFailure, String> {
316 match s {
317 "halt" => Ok(OnHealthFailure::Halt),
318 "rollback-and-halt" => Ok(OnHealthFailure::RollbackAndHalt),
319 other => Err(other.to_string()),
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::Db;
    use chrono::{Duration, TimeZone};

    /// Fixed reference instant shared by every fixture.
    fn t0() -> DateTime<Utc> {
        Utc.with_ymd_and_hms(2026, 5, 16, 1, 0, 0).unwrap()
    }

    /// Opens an in-memory database and applies the migrations.
    fn fresh_db() -> Db {
        let db = Db::open_in_memory().unwrap();
        db.migrate().unwrap();
        db
    }

    /// Builds a `new_pending` record on channel "stable" targeting closure
    /// "abc123", with `t0()` and `t0() + 5 minutes` as its two timestamps.
    fn pending(rollout_id: &str, hostname: &str) -> HostRolloutState {
        HostRolloutState::new_pending(
            rollout_id.into(),
            hostname.into(),
            "stable".into(),
            "abc123".into(),
            t0(),
            t0() + Duration::minutes(5),
        )
    }

    #[test]
    fn upsert_and_load_round_trip() {
        let db = fresh_db();
        let table = HostRolloutRecords { conn: &db.conn };

        let mut original = pending("r1", "h1");
        original.policy_applied = Some(OnHealthFailure::RollbackAndHalt);
        original.last_event_seq = 7;
        table.upsert(&original).unwrap();

        let loaded = table.load("r1", "h1").unwrap().unwrap();
        assert_eq!(loaded.state, HostState::Pending);
        assert_eq!(loaded.target_closure, "abc123");
        assert_eq!(loaded.last_event_seq, 7);
        assert_eq!(
            loaded.policy_applied,
            Some(OnHealthFailure::RollbackAndHalt)
        );
        assert_eq!(loaded.dispatched_at, original.dispatched_at);
        assert_eq!(loaded.soak_due_at, original.soak_due_at);
    }

    #[test]
    fn upsert_overwrites_state_transition() {
        let db = fresh_db();
        let table = HostRolloutRecords { conn: &db.conn };

        // First write: the freshly created pending record.
        let mut record = pending("r1", "h1");
        table.upsert(&record).unwrap();

        // Second write under the same key must replace the first.
        record.state = HostState::Activating;
        record.dispatch_acked_at = Some(t0() + Duration::seconds(1));
        record.last_event_seq = 1;
        table.upsert(&record).unwrap();

        let loaded = table.load("r1", "h1").unwrap().unwrap();
        assert_eq!(loaded.state, HostState::Activating);
        assert_eq!(loaded.last_event_seq, 1);
    }

    #[test]
    fn load_missing_returns_none() {
        let db = fresh_db();
        let table = HostRolloutRecords { conn: &db.conn };
        assert!(table.load("nope", "nope").unwrap().is_none());
    }

    #[test]
    fn all_for_rollout_returns_multiple_hosts() {
        let db = fresh_db();
        let table = HostRolloutRecords { conn: &db.conn };

        for host in ["h1", "h2", "h3"] {
            table.upsert(&pending("r1", host)).unwrap();
        }

        let got = table.all_for_rollout("r1").unwrap();
        assert_eq!(got.len(), 3);
    }

    #[test]
    fn active_for_host_includes_failed_excludes_terminal() {
        let db = fresh_db();
        let table = HostRolloutRecords { conn: &db.conn };

        // (state, whether `active_for_host` is expected to return it)
        let states_and_expected = [
            (HostState::Pending, true),
            (HostState::Activating, true),
            (HostState::Deferred, true),
            (HostState::Soaking, true),
            (HostState::Failed, true),
            (HostState::Converged, false),
            (HostState::Reverted, false),
        ];

        // One rollout per state, all on the same host.
        for (idx, (state, _)) in states_and_expected.iter().enumerate() {
            let mut record = pending(&format!("r{idx}"), "h1");
            record.state = *state;

            let activated = matches!(state, HostState::Soaking | HostState::Converged);
            let failed = matches!(state, HostState::Failed | HostState::Reverted);
            let reverted = matches!(state, HostState::Reverted);
            let converged = matches!(state, HostState::Converged);

            if activated {
                record.current_closure = Some("abc123".into());
                record.activation_completed_at = Some(t0() + Duration::seconds(5));
            }
            if failed {
                record.failed_at = Some(t0() + Duration::seconds(125));
                record.current_closure_at_dispatch = Some("prior-closure".into());
            }
            if reverted {
                record.reverted_at = Some(t0() + Duration::seconds(135));
                record.reverted_to = Some("prior-closure".into());
                record.current_closure = Some("prior-closure".into());
            }
            if converged {
                record.converged_at = Some(t0() + Duration::minutes(6));
            }
            table.upsert(&record).unwrap();
        }

        let returned_states: Vec<HostState> = table
            .active_for_host("h1")
            .unwrap()
            .iter()
            .map(|r| r.state)
            .collect();

        for (state, expected_included) in states_and_expected {
            let included = returned_states.contains(&state);
            assert_eq!(
                included, expected_included,
                "state {state:?}: expected included={expected_included}, got {included}",
            );
        }
    }

    #[test]
    fn check_constraint_rejects_invalid_state() {
        let db = fresh_db();
        let conn = super::super::lock_conn(&db.conn).unwrap();

        // "Healthy" is not one of the states written by `state_to_db`.
        let outcome = conn.execute(
            "INSERT INTO host_rollout_records (
                 rollout_id, hostname, channel, state,
                 target_closure, dispatched_at, probes_json, last_event_seq
             ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, '{}', 0)",
            params!["r1", "h1", "stable", "Healthy", "abc", t0().to_rfc3339()],
        );
        let err = outcome.unwrap_err();

        let s = format!("{err:?}");
        assert!(s.contains("CHECK"), "expected CHECK violation, got {s}");
    }
}