From 8303970258d52993f60229bfb9563640b3ff9969 Mon Sep 17 00:00:00 2001 From: Affaan Mustafa Date: Tue, 24 Mar 2026 03:39:53 -0700 Subject: [PATCH] feat(ecc2): add tool risk scoring and actions --- ecc2/src/config/mod.rs | 28 ++++ ecc2/src/observability/mod.rs | 279 ++++++++++++++++++++++++++++++---- ecc2/src/session/manager.rs | 1 + 3 files changed, 282 insertions(+), 26 deletions(-) diff --git a/ecc2/src/config/mod.rs b/ecc2/src/config/mod.rs index 16a7cf86..ec510fd9 100644 --- a/ecc2/src/config/mod.rs +++ b/ecc2/src/config/mod.rs @@ -11,6 +11,14 @@ pub enum PaneLayout { Grid, } +#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)] +#[serde(default)] +pub struct RiskThresholds { + pub review: f64, + pub confirm: f64, + pub block: f64, +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(default)] pub struct Config { @@ -25,6 +33,7 @@ pub struct Config { pub token_budget: u64, pub theme: Theme, pub pane_layout: PaneLayout, + pub risk_thresholds: RiskThresholds, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -48,11 +57,18 @@ impl Default for Config { token_budget: 500_000, theme: Theme::Dark, pane_layout: PaneLayout::Horizontal, + risk_thresholds: Self::RISK_THRESHOLDS, } } } impl Config { + pub const RISK_THRESHOLDS: RiskThresholds = RiskThresholds { + review: 0.35, + confirm: 0.60, + block: 0.85, + }; + pub fn load() -> Result { let config_path = dirs::home_dir() .unwrap_or_else(|| PathBuf::from(".")) @@ -69,6 +85,12 @@ impl Config { } } +impl Default for RiskThresholds { + fn default() -> Self { + Config::RISK_THRESHOLDS + } +} + #[cfg(test)] mod tests { use super::{Config, PaneLayout}; @@ -100,6 +122,7 @@ theme = "Dark" assert_eq!(config.cost_budget_usd, defaults.cost_budget_usd); assert_eq!(config.token_budget, defaults.token_budget); assert_eq!(config.pane_layout, defaults.pane_layout); + assert_eq!(config.risk_thresholds, defaults.risk_thresholds); } #[test] @@ -113,4 +136,9 @@ theme = "Dark" assert_eq!(config.pane_layout, PaneLayout::Grid); } + + #[test] + fn default_risk_thresholds_are_applied() { + assert_eq!(Config::default().risk_thresholds, Config::RISK_THRESHOLDS); + } } diff --git a/ecc2/src/observability/mod.rs b/ecc2/src/observability/mod.rs index 39128e2e..80d0c8a2 100644 --- a/ecc2/src/observability/mod.rs +++ b/ecc2/src/observability/mod.rs @@ -1,6 +1,7 @@ use anyhow::{bail, Result}; use serde::{Deserialize, Serialize}; +use crate::config::{Config, RiskThresholds}; use crate::session::store::StateStore; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -13,6 +14,22 @@ pub struct ToolCallEvent { pub risk_score: f64, } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub struct RiskAssessment { + pub score: f64, + pub reasons: Vec, + pub suggested_action: SuggestedAction, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum SuggestedAction { + Allow, + Review, + RequireConfirmation, + Block, +} + impl ToolCallEvent { pub fn new( session_id: impl Into, @@ -26,7 +43,8 @@ impl ToolCallEvent { Self { session_id: session_id.into(), - risk_score: Self::compute_risk(&tool_name, &input_summary), + risk_score: Self::compute_risk(&tool_name, &input_summary, &Config::RISK_THRESHOLDS) + .score, tool_name, input_summary, output_summary: output_summary.into(), @@ -34,35 +52,186 @@ impl ToolCallEvent { } } - /// Compute risk score based on tool type and input patterns. - pub fn compute_risk(tool_name: &str, input: &str) -> f64 { - let mut score: f64 = 0.0; + /// Compute risk from the tool type and input characteristics. + pub fn compute_risk( + tool_name: &str, + input: &str, + thresholds: &RiskThresholds, + ) -> RiskAssessment { + let normalized_tool = tool_name.to_ascii_lowercase(); + let normalized_input = input.to_ascii_lowercase(); + let mut score = 0.0; + let mut reasons = Vec::new(); - // Destructive tools get higher base risk - match tool_name { - "Bash" => score += 0.3, - "Write" => score += 0.2, - "Edit" => score += 0.1, - _ => score += 0.05, + let (base_score, base_reason) = base_tool_risk(&normalized_tool); + score += base_score; + if let Some(reason) = base_reason { + reasons.push(reason.to_string()); } - // Dangerous patterns in bash commands - if tool_name == "Bash" { - if input.contains("rm -rf") || input.contains("--force") { - score += 0.4; - } - if input.contains("git push") || input.contains("git reset") { - score += 0.3; - } - if input.contains("sudo") || input.contains("chmod 777") { - score += 0.5; - } + let (file_sensitivity_score, file_sensitivity_reason) = + assess_file_sensitivity(&normalized_input); + score += file_sensitivity_score; + if let Some(reason) = file_sensitivity_reason { + reasons.push(reason); } - score.min(1.0) + let (blast_radius_score, blast_radius_reason) = assess_blast_radius(&normalized_input); + score += blast_radius_score; + if let Some(reason) = blast_radius_reason { + reasons.push(reason); + } + + let (irreversibility_score, irreversibility_reason) = + assess_irreversibility(&normalized_input); + score += irreversibility_score; + if let Some(reason) = irreversibility_reason { + reasons.push(reason); + } + + let score = score.clamp(0.0, 1.0); + let suggested_action = SuggestedAction::from_score(score, thresholds); + + RiskAssessment { + score, + reasons, + suggested_action, + } } } +impl SuggestedAction { + fn from_score(score: f64, thresholds: &RiskThresholds) -> Self { + if score >= thresholds.block { + Self::Block + } else if score >= thresholds.confirm { + Self::RequireConfirmation + } else if score >= thresholds.review { + Self::Review + } else { + Self::Allow + } + } +} + +fn base_tool_risk(tool_name: &str) -> (f64, Option<&'static str>) { + match tool_name { + "bash" => ( + 0.20, + Some("shell execution can modify local or shared state"), + ), + "write" | "multiedit" => (0.15, Some("writes files directly")), + "edit" => (0.10, Some("modifies existing files")), + _ => (0.05, None), + } +} + +fn assess_file_sensitivity(input: &str) -> (f64, Option) { + const SECRET_PATTERNS: &[&str] = &[ + ".env", + "secret", + "credential", + "token", + "api_key", + "apikey", + "auth", + "id_rsa", + ".pem", + ".key", + ]; + const SHARED_INFRA_PATTERNS: &[&str] = &[ + "cargo.toml", + "package.json", + "dockerfile", + ".github/workflows", + "schema", + "migration", + "production", + ]; + + if contains_any(input, SECRET_PATTERNS) { + ( + 0.25, + Some("targets a sensitive file or credential surface".to_string()), + ) + } else if contains_any(input, SHARED_INFRA_PATTERNS) { + ( + 0.15, + Some("targets shared infrastructure or release-critical files".to_string()), + ) + } else { + (0.0, None) + } +} + +fn assess_blast_radius(input: &str) -> (f64, Option) { + const LARGE_SCOPE_PATTERNS: &[&str] = &[ + "**", + "/*", + "--all", + "--recursive", + "entire repo", + "all files", + "across src/", + "find ", + " xargs ", + ]; + const SHARED_STATE_PATTERNS: &[&str] = &[ + "git push --force", + "git push -f", + "origin main", + "origin master", + "rm -rf .", + "rm -rf /", + ]; + + if contains_any(input, SHARED_STATE_PATTERNS) { + ( + 0.35, + Some("has a broad blast radius across shared state or history".to_string()), + ) + } else if contains_any(input, LARGE_SCOPE_PATTERNS) { + ( + 0.25, + Some("has a broad blast radius across multiple files or directories".to_string()), + ) + } else { + (0.0, None) + } +} + +fn assess_irreversibility(input: &str) -> (f64, Option) { + const HIGH_IRREVERSIBILITY_PATTERNS: &[&str] = &[ + "rm -rf", + "git reset --hard", + "git clean -fd", + "drop database", + "drop table", + "truncate ", + "shred ", + ]; + const MODERATE_IRREVERSIBILITY_PATTERNS: &[&str] = + &["rm -f", "git push --force", "git push -f", "delete from"]; + + if contains_any(input, HIGH_IRREVERSIBILITY_PATTERNS) { + ( + 0.45, + Some("includes an irreversible or destructive operation".to_string()), + ) + } else if contains_any(input, MODERATE_IRREVERSIBILITY_PATTERNS) { + ( + 0.40, + Some("includes an irreversible or difficult-to-undo operation".to_string()), + ) + } else { + (0.0, None) + } +} + +fn contains_any(input: &str, patterns: &[&str]) -> bool { + patterns.iter().any(|pattern| input.contains(pattern)) +} + #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolLogEntry { pub id: i64, @@ -121,7 +290,8 @@ pub fn log_tool_call(db: &StateStore, event: &ToolCallEvent) -> Result= Config::RISK_THRESHOLDS.review); + assert_eq!(assessment.suggested_action, SuggestedAction::Review); + assert!(assessment + .reasons + .iter() + .any(|reason| reason.contains("sensitive file"))); + } + + #[test] + fn computes_blast_radius_risk() { + let assessment = ToolCallEvent::compute_risk( + "Edit", + "Apply the same replacement across src/**/*.rs", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.review); + assert_eq!(assessment.suggested_action, SuggestedAction::Review); + assert!(assessment + .reasons + .iter() + .any(|reason| reason.contains("blast radius"))); + } + + #[test] + fn computes_irreversible_risk() { + let assessment = ToolCallEvent::compute_risk( + "Bash", + "rm -f /tmp/ecc-temp.txt", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.confirm); + assert_eq!( + assessment.suggested_action, + SuggestedAction::RequireConfirmation, + ); + assert!(assessment + .reasons + .iter() + .any(|reason| reason.contains("irreversible"))); + } + + #[test] + fn blocks_combined_high_risk_operations() { + let assessment = ToolCallEvent::compute_risk( + "Bash", + "rm -rf . && git push --force origin main", + &Config::RISK_THRESHOLDS, + ); + + assert!(assessment.score >= Config::RISK_THRESHOLDS.block); + assert_eq!(assessment.suggested_action, SuggestedAction::Block); } #[test] diff --git a/ecc2/src/session/manager.rs b/ecc2/src/session/manager.rs index 67099780..612965ec 100644 --- a/ecc2/src/session/manager.rs +++ b/ecc2/src/session/manager.rs @@ -443,6 +443,7 @@ mod tests { token_budget: 500_000, theme: Theme::Dark, pane_layout: PaneLayout::Horizontal, + risk_thresholds: Config::RISK_THRESHOLDS, } }