From 72497ccae7708513f5fb55bbd8495b74083e9b44 Mon Sep 17 00:00:00 2001 From: Affaan Mustafa Date: Tue, 24 Mar 2026 03:39:53 -0700 Subject: [PATCH] feat(ecc2): add tool risk scoring and actions --- ecc2/src/comms/mod.rs | 5 +- ecc2/src/config/mod.rs | 54 +++++++ ecc2/src/main.rs | 15 +- ecc2/src/observability/mod.rs | 279 +++++++++++++++++++++++++++++++--- ecc2/src/session/manager.rs | 2 +- ecc2/src/session/store.rs | 12 +- ecc2/src/tui/dashboard.rs | 32 +++- ecc2/src/worktree/mod.rs | 6 +- 8 files changed, 360 insertions(+), 45 deletions(-) diff --git a/ecc2/src/comms/mod.rs b/ecc2/src/comms/mod.rs index be176e96..8be89f2b 100644 --- a/ecc2/src/comms/mod.rs +++ b/ecc2/src/comms/mod.rs @@ -13,7 +13,10 @@ pub enum MessageType { /// Response to a query Response { answer: String }, /// Notification of completion - Completed { summary: String, files_changed: Vec }, + Completed { + summary: String, + files_changed: Vec, + }, /// Conflict detected (e.g., two agents editing the same file) Conflict { file: String, description: String }, } diff --git a/ecc2/src/config/mod.rs b/ecc2/src/config/mod.rs index 1e7eeab7..6abef4df 100644 --- a/ecc2/src/config/mod.rs +++ b/ecc2/src/config/mod.rs @@ -2,7 +2,16 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; use std::path::PathBuf; +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[serde(default)] +pub struct RiskThresholds { + pub review: f64, + pub confirm: f64, + pub block: f64, +} + #[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(default)] pub struct Config { pub db_path: PathBuf, pub worktree_root: PathBuf, @@ -12,6 +21,7 @@ pub struct Config { pub heartbeat_interval_secs: u64, pub default_agent: String, pub theme: Theme, + pub risk_thresholds: RiskThresholds, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -32,11 +42,18 @@ impl Default for Config { heartbeat_interval_secs: 30, default_agent: "claude".to_string(), theme: Theme::Dark, + risk_thresholds: Self::RISK_THRESHOLDS, } } } impl Config { + pub const RISK_THRESHOLDS: RiskThresholds = RiskThresholds { + review: 0.35, + confirm: 0.60, + block: 0.85, + }; + pub fn load() -> Result { let config_path = dirs::home_dir() .unwrap_or_else(|| PathBuf::from(".")) @@ -52,3 +69,40 @@ impl Config { } } } + +impl Default for RiskThresholds { + fn default() -> Self { + Config::RISK_THRESHOLDS + } +} + +#[cfg(test)] +mod tests { + use super::Config; + + #[test] + fn default_config_uses_default_risk_thresholds() { + let config = Config::default(); + + assert_eq!(config.risk_thresholds, Config::RISK_THRESHOLDS); + } + + #[test] + fn deserialization_defaults_risk_thresholds() { + let config: Config = toml::from_str( + r#" +db_path = "/tmp/ecc2.db" +worktree_root = "/tmp/ecc-worktrees" +max_parallel_sessions = 8 +max_parallel_worktrees = 6 +session_timeout_secs = 3600 +heartbeat_interval_secs = 30 +default_agent = "claude" +theme = "Dark" +"#, + ) + .expect("config should deserialize"); + + assert_eq!(config.risk_thresholds, Config::RISK_THRESHOLDS); + } +} diff --git a/ecc2/src/main.rs b/ecc2/src/main.rs index 850b7b49..afa50a2f 100644 --- a/ecc2/src/main.rs +++ b/ecc2/src/main.rs @@ -1,9 +1,9 @@ +mod comms; mod config; +mod observability; mod session; mod tui; mod worktree; -mod observability; -mod comms; use anyhow::Result; use clap::Parser; @@ -63,10 +63,13 @@ async fn main() -> Result<()> { Some(Commands::Dashboard) | None => { tui::app::run(db, cfg).await?; } - Some(Commands::Start { task, agent, worktree: use_worktree }) => { - let session_id = session::manager::create_session( - &db, &cfg, &task, &agent, use_worktree, - ).await?; + Some(Commands::Start { + task, + agent, + worktree: use_worktree, + }) => { + let session_id = + session::manager::create_session(&db, &cfg, &task, &agent, use_worktree).await?; println!("Session started: {session_id}"); } Some(Commands::Sessions) => { diff --git a/ecc2/src/observability/mod.rs b/ecc2/src/observability/mod.rs index 5f7a9645..ff5e7349 100644 --- a/ecc2/src/observability/mod.rs +++ b/ecc2/src/observability/mod.rs @@ -1,6 +1,7 @@ use anyhow::Result; use serde::{Deserialize, Serialize}; +use crate::config::RiskThresholds; use crate::session::store::StateStore; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -13,36 +14,203 @@ pub struct ToolCallEvent { pub risk_score: f64, } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RiskAssessment { + pub score: f64, + pub reasons: Vec, + pub suggested_action: SuggestedAction, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SuggestedAction { + Allow, + Review, + RequireConfirmation, + Block, +} + impl ToolCallEvent { - /// Compute risk score based on tool type and input patterns. - pub fn compute_risk(tool_name: &str, input: &str) -> f64 { - let mut score: f64 = 0.0; + /// Compute risk from the tool type and input characteristics. + pub fn compute_risk( + tool_name: &str, + input: &str, + thresholds: &RiskThresholds, + ) -> RiskAssessment { + let normalized_tool = tool_name.to_ascii_lowercase(); + let normalized_input = input.to_ascii_lowercase(); + let mut score = 0.0; + let mut reasons = Vec::new(); - // Destructive tools get higher base risk - match tool_name { - "Bash" => score += 0.3, - "Write" => score += 0.2, - "Edit" => score += 0.1, - _ => score += 0.05, + let (base_score, base_reason) = base_tool_risk(&normalized_tool); + score += base_score; + if let Some(reason) = base_reason { + reasons.push(reason.to_string()); } - // Dangerous patterns in bash commands - if tool_name == "Bash" { - if input.contains("rm -rf") || input.contains("--force") { - score += 0.4; - } - if input.contains("git push") || input.contains("git reset") { - score += 0.3; - } - if input.contains("sudo") || input.contains("chmod 777") { - score += 0.5; - } + let (file_sensitivity_score, file_sensitivity_reason) = + assess_file_sensitivity(&normalized_input); + score += file_sensitivity_score; + if let Some(reason) = file_sensitivity_reason { + reasons.push(reason); } - score.min(1.0) + let (blast_radius_score, blast_radius_reason) = assess_blast_radius(&normalized_input); + score += blast_radius_score; + if let Some(reason) = blast_radius_reason { + reasons.push(reason); + } + + let (irreversibility_score, irreversibility_reason) = + assess_irreversibility(&normalized_input); + score += irreversibility_score; + if let Some(reason) = irreversibility_reason { + reasons.push(reason); + } + + let score = score.clamp(0.0, 1.0); + let suggested_action = SuggestedAction::from_score(score, thresholds); + + RiskAssessment { + score, + reasons, + suggested_action, + } } } +impl SuggestedAction { + fn from_score(score: f64, thresholds: &RiskThresholds) -> Self { + if score >= thresholds.block { + Self::Block + } else if score >= thresholds.confirm { + Self::RequireConfirmation + } else if score >= thresholds.review { + Self::Review + } else { + Self::Allow + } + } +} + +fn base_tool_risk(tool_name: &str) -> (f64, Option<&'static str>) { + match tool_name { + "bash" => ( + 0.20, + Some("shell execution can modify local or shared state"), + ), + "write" | "multiedit" => (0.15, Some("writes files directly")), + "edit" => (0.10, Some("modifies existing files")), + _ => (0.05, None), + } +} + +fn assess_file_sensitivity(input: &str) -> (f64, Option) { + const SECRET_PATTERNS: &[&str] = &[ + ".env", + "secret", + "credential", + "token", + "api_key", + "apikey", + "auth", + "id_rsa", + ".pem", + ".key", + ]; + const SHARED_INFRA_PATTERNS: &[&str] = &[ + "cargo.toml", + "package.json", + "dockerfile", + ".github/workflows", + "schema", + "migration", + "production", + ]; + + if contains_any(input, SECRET_PATTERNS) { + ( + 0.25, + Some("targets a sensitive file or credential surface".to_string()), + ) + } else if contains_any(input, SHARED_INFRA_PATTERNS) { + ( + 0.15, + Some("targets shared infrastructure or release-critical files".to_string()), + ) + } else { + (0.0, None) + } +} + +fn assess_blast_radius(input: &str) -> (f64, Option) { + const LARGE_SCOPE_PATTERNS: &[&str] = &[ + "**", + "/*", + "--all", + "--recursive", + "entire repo", + "all files", + "across src/", + "find ", + " xargs ", + ]; + const SHARED_STATE_PATTERNS: &[&str] = &[ + "git push --force", + "git push -f", + "origin main", + "origin master", + "rm -rf .", + "rm -rf /", + ]; + + if contains_any(input, SHARED_STATE_PATTERNS) { + ( + 0.35, + Some("has a broad blast radius across shared state or history".to_string()), + ) + } else if contains_any(input, LARGE_SCOPE_PATTERNS) { + ( + 0.25, + Some("has a broad blast radius across multiple files or directories".to_string()), + ) + } else { + (0.0, None) + } +} + +fn assess_irreversibility(input: &str) -> (f64, Option) { + const HIGH_IRREVERSIBILITY_PATTERNS: &[&str] = &[ + "rm -rf", + "git reset --hard", + "git clean -fd", + "drop database", + "drop table", + "truncate ", + "shred ", + ]; + const MODERATE_IRREVERSIBILITY_PATTERNS: &[&str] = + &["rm -f", "git push --force", "git push -f", "delete from"]; + + if contains_any(input, HIGH_IRREVERSIBILITY_PATTERNS) { + ( + 0.45, + Some("includes an irreversible or destructive operation".to_string()), + ) + } else if contains_any(input, MODERATE_IRREVERSIBILITY_PATTERNS) { + ( + 0.40, + Some("includes an irreversible or difficult-to-undo operation".to_string()), + ) + } else { + (0.0, None) + } +} + +fn contains_any(input: &str, patterns: &[&str]) -> bool { + patterns.iter().any(|pattern| input.contains(pattern)) +} + pub fn log_tool_call(db: &StateStore, event: &ToolCallEvent) -> Result<()> { db.send_message( &event.session_id, @@ -52,3 +220,72 @@ pub fn log_tool_call(db: &StateStore, event: &ToolCallEvent) -> Result<()> { )?; Ok(()) } + +#[cfg(test)] +mod tests { + use super::{SuggestedAction, ToolCallEvent}; + use crate::config::Config; + + #[test] + fn computes_sensitive_file_risk() { + let assessment = ToolCallEvent::compute_risk( + "Write", + "Update .env.production with rotated API token", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.review); + assert_eq!(assessment.suggested_action, SuggestedAction::Review); + assert!(assessment + .reasons + .iter() + .any(|reason| reason.contains("sensitive file"))); + } + + #[test] + fn computes_blast_radius_risk() { + let assessment = ToolCallEvent::compute_risk( + "Edit", + "Apply the same replacement across src/**/*.rs", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.review); + assert_eq!(assessment.suggested_action, SuggestedAction::Review); + assert!(assessment + .reasons + .iter() + .any(|reason| reason.contains("blast radius"))); + } + + #[test] + fn computes_irreversible_risk() { + let assessment = ToolCallEvent::compute_risk( + "Bash", + "rm -f /tmp/ecc-temp.txt", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.confirm); + assert_eq!( + assessment.suggested_action, + SuggestedAction::RequireConfirmation, + ); + assert!(assessment + .reasons + .iter() + .any(|reason| reason.contains("irreversible"))); + } + + #[test] + fn blocks_combined_high_risk_operations() { + let assessment = ToolCallEvent::compute_risk( + "Bash", + "rm -rf . && git push --force origin main", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.block); + assert_eq!(assessment.suggested_action, SuggestedAction::Block); + } +} diff --git a/ecc2/src/session/manager.rs b/ecc2/src/session/manager.rs index c08c5f0d..60d092ba 100644 --- a/ecc2/src/session/manager.rs +++ b/ecc2/src/session/manager.rs @@ -1,8 +1,8 @@ use anyhow::Result; use std::fmt; -use super::{Session, SessionMetrics, SessionState}; use super::store::StateStore; +use super::{Session, SessionMetrics, SessionState}; use crate::config::Config; use crate::worktree; diff --git a/ecc2/src/session/store.rs b/ecc2/src/session/store.rs index b412f188..515bcc98 100644 --- a/ecc2/src/session/store.rs +++ b/ecc2/src/session/store.rs @@ -170,16 +170,12 @@ impl StateStore { pub fn get_session(&self, id: &str) -> Result> { let sessions = self.list_sessions()?; - Ok(sessions.into_iter().find(|s| s.id == id || s.id.starts_with(id))) + Ok(sessions + .into_iter() + .find(|s| s.id == id || s.id.starts_with(id))) } - pub fn send_message( - &self, - from: &str, - to: &str, - content: &str, - msg_type: &str, - ) -> Result<()> { + pub fn send_message(&self, from: &str, to: &str, content: &str, msg_type: &str) -> Result<()> { self.conn.execute( "INSERT INTO messages (from_session, to_session, content, msg_type, timestamp) VALUES (?1, ?2, ?3, ?4, ?5)", diff --git a/ecc2/src/tui/dashboard.rs b/ecc2/src/tui/dashboard.rs index aca1e995..9c36a1ce 100644 --- a/ecc2/src/tui/dashboard.rs +++ b/ecc2/src/tui/dashboard.rs @@ -4,8 +4,8 @@ use ratatui::{ }; use crate::config::Config; -use crate::session::{Session, SessionState}; use crate::session::store::StateStore; +use crate::session::{Session, SessionState}; pub struct Dashboard { db: StateStore, @@ -42,7 +42,7 @@ impl Dashboard { let chunks = Layout::default() .direction(Direction::Vertical) .constraints([ - Constraint::Length(3), // Header + Constraint::Length(3), // Header Constraint::Min(10), // Main content Constraint::Length(3), // Status bar ]) @@ -79,7 +79,11 @@ impl Dashboard { } fn render_header(&self, frame: &mut Frame, area: Rect) { - let running = self.sessions.iter().filter(|s| s.state == SessionState::Running).count(); + let running = self + .sessions + .iter() + .filter(|s| s.state == SessionState::Running) + .count(); let total = self.sessions.len(); let title = format!(" ECC 2.0 | {running} running / {total} total "); @@ -90,7 +94,11 @@ impl Dashboard { Pane::Output => 1, Pane::Metrics => 2, }) - .highlight_style(Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD)); + .highlight_style( + Style::default() + .fg(Color::Cyan) + .add_modifier(Modifier::BOLD), + ); frame.render_widget(tabs, area); } @@ -110,11 +118,18 @@ impl Dashboard { SessionState::Pending => "◌", }; let style = if i == self.selected_session { - Style::default().fg(Color::Cyan).add_modifier(Modifier::BOLD) + Style::default() + .fg(Color::Cyan) + .add_modifier(Modifier::BOLD) } else { Style::default() }; - let text = format!("{state_icon} {} [{}] {}", &s.id[..8.min(s.id.len())], s.agent_type, s.task); + let text = format!( + "{state_icon} {} [{}] {}", + &s.id[..8.min(s.id.len())], + s.agent_type, + s.task + ); ListItem::new(text).style(style) }) .collect(); @@ -136,7 +151,10 @@ impl Dashboard { fn render_output(&self, frame: &mut Frame, area: Rect) { let content = if let Some(session) = self.sessions.get(self.selected_session) { - format!("Agent output for session {}...\n\n(Live streaming coming soon)", session.id) + format!( + "Agent output for session {}...\n\n(Live streaming coming soon)", + session.id + ) } else { "No sessions. Press 'n' to start one.".to_string() }; diff --git a/ecc2/src/worktree/mod.rs b/ecc2/src/worktree/mod.rs index 50306f2a..8ac1974b 100644 --- a/ecc2/src/worktree/mod.rs +++ b/ecc2/src/worktree/mod.rs @@ -28,7 +28,11 @@ pub fn create_for_session(session_id: &str, cfg: &Config) -> Result