|
diff --git a/codex-rs/core/src/codex.rs b/codex-rs/core/src/codex.rs |
|
index 8b792887..ee1c10a9 100644 |
|
--- a/codex-rs/core/src/codex.rs |
|
+++ b/codex-rs/core/src/codex.rs |
|
@@ -49,6 +49,7 @@ use crate::client::ModelClient; |
|
use crate::client_common::Prompt; |
|
use crate::client_common::ResponseEvent; |
|
use crate::config::Config; |
|
+use crate::config::MemoryFirstConfig; |
|
use crate::config_types::ShellEnvironmentPolicy; |
|
use crate::conversation_history::ConversationHistory; |
|
use crate::environment_context::EnvironmentContext; |
|
@@ -283,6 +284,7 @@ pub(crate) struct TurnContext { |
|
pub(crate) tools_config: ToolsConfig, |
|
pub(crate) is_review_mode: bool, |
|
pub(crate) final_output_json_schema: Option<Value>, |
|
+ pub(crate) memory_first: Option<MemoryFirstConfig>, |
|
} |
|
|
|
impl TurnContext { |
|
@@ -451,6 +453,7 @@ impl Session { |
|
cwd, |
|
is_review_mode: false, |
|
final_output_json_schema: None, |
|
+ memory_first: config.memory_first.clone(), |
|
}; |
|
let services = SessionServices { |
|
mcp_connection_manager, |
|
@@ -1158,6 +1161,7 @@ async fn submission_loop( |
|
cwd: new_cwd.clone(), |
|
is_review_mode: false, |
|
final_output_json_schema: None, |
|
+ memory_first: prev.memory_first.clone(), |
|
}; |
|
|
|
// Install the new persistent context for subsequent tasks/turns. |
|
@@ -1243,6 +1247,7 @@ async fn submission_loop( |
|
cwd, |
|
is_review_mode: false, |
|
final_output_json_schema, |
|
+ memory_first: config.memory_first.clone(), |
|
}; |
|
|
|
// if the environment context has changed, record it in the conversation history |
|
@@ -1493,6 +1498,7 @@ async fn spawn_review_thread( |
|
cwd: parent_turn_context.cwd.clone(), |
|
is_review_mode: true, |
|
final_output_json_schema: None, |
|
+ memory_first: parent_turn_context.memory_first.clone(), |
|
}; |
|
|
|
// Seed the child task with the review prompt as the initial user message. |
|
@@ -1848,9 +1854,268 @@ async fn run_turn( |
|
&turn_context.tools_config, |
|
Some(sess.services.mcp_connection_manager.list_all_tools()), |
|
); |
|
+ // Optionally run a memory-first pre-turn hook which can inject context |
|
+ // or short-circuit the model call when configured and enabled. |
|
+ let mut adjusted_input = input; |
|
+ if let Some(cfg) = turn_context.memory_first.as_ref().filter(|c| c.enable) { |
|
+ // Gather latest user text from input for the memory query |
|
+ let latest_user_text: Option<String> = |
|
+ adjusted_input.iter().rev().find_map(|ri| match ri { |
|
+ ResponseItem::Message { role, content, .. } if role == "user" => { |
|
+ let mut s = String::new(); |
|
+ for c in content { |
|
+ if let codex_protocol::models::ContentItem::InputText { text } = c { |
|
+ s.push_str(text); |
|
+ s.push('\n'); |
|
+ } |
|
+ } |
|
+ let t = s.trim().to_owned(); |
|
+ if t.is_empty() { None } else { Some(t) } |
|
+ } |
|
+ _ => None, |
|
+ }); |
|
+ |
|
+ if let Some(q) = latest_user_text { |
|
+ use serde_json::json; |
|
+ let server = cfg.server.clone().unwrap_or_else(|| "memory".to_string()); |
|
+ let tool = cfg |
|
+ .tool |
|
+ .clone() |
|
+ .unwrap_or_else(|| "memory_search".to_string()); |
|
+ // Defaults tuned for stability across transports/servers. |
|
+ let timeout_ms = cfg.timeout_ms.unwrap_or(1000); |
|
+ let short_circuit = cfg.short_circuit.unwrap_or(true); |
|
+ // Clamp threshold into [0.0, 1.0]. |
|
+ let mut threshold = cfg.confidence_threshold.unwrap_or(0.80); |
|
+ if !(0.0..=1.0).contains(&threshold) { |
|
+ threshold = threshold.clamp(0.0, 1.0); |
|
+ } |
|
+ let inject_cap = cfg.inject_max_chars.unwrap_or(1500); |
|
+ let mut args = json!({ "q": q, "k": cfg.k.unwrap_or(5) }); |
|
+ if let Some(scope) = &cfg.scope { |
|
+ args["scope"] = json!(scope); |
|
+ } |
|
+ if tool == "memory_recall" |
|
+ && let Some(d) = cfg.depth |
|
+ { |
|
+ args["depth"] = json!(d); |
|
+ } |
|
+ |
|
+ // Timebox the MCP tool call with a best-effort timeout. |
|
+ let call_fut = |
|
+ sess.services |
|
+ .mcp_connection_manager |
|
+ .call_tool(&server, &tool, Some(args)); |
|
+ let call_res = |
|
+ tokio::time::timeout(std::time::Duration::from_millis(timeout_ms), call_fut).await; |
|
+ |
|
+ let call_res = match call_res { |
|
+ Ok(Ok(r)) => r, |
|
+ _ if cfg.required => { |
|
+ return Err(CodexErr::UnsupportedOperation(format!( |
|
+ "memory_first (required) failed for {server}/{tool}" |
|
+ ))); |
|
+ } |
|
+ _ => { |
|
+ tracing::warn!( |
|
+ "memory_first: timeout/MCP error (best-effort; proceeding to model)" |
|
+ ); |
|
+ // fall through unchanged |
|
+ mcp_types::CallToolResult { |
|
+ content: Vec::new(), |
|
+ is_error: None, |
|
+ structured_content: None, |
|
+ } |
|
+ } |
|
+ }; |
|
+ |
|
+ // Honor explicit error flag from MCP tool result. |
|
+ if call_res.is_error.unwrap_or(false) { |
|
+ if cfg.required { |
|
+ return Err(CodexErr::UnsupportedOperation( |
|
+ "memory_first (required) returned error".to_string(), |
|
+ )); |
|
+ } else { |
|
+ tracing::warn!( |
|
+ "memory_first: tool returned is_error=true; proceeding to model" |
|
+ ); |
|
+ } |
|
+ } |
|
+ |
|
+ // Extract top items from either JSON or text content |
|
+ fn extract_top_items( |
|
+ res: &mcp_types::CallToolResult, |
|
+ ) -> Vec<(String, String, f32, String, Option<String>)> { |
|
+ let mut out = Vec::new(); |
|
+ // Prefer structured_content when present |
|
+ if let Some(sc) = &res.structured_content |
|
+ && let Ok(v) = serde_json::to_value(sc) |
|
+ && let Some(arr) = v |
|
+ .get("items") |
|
+ .and_then(|x| x.as_array()) |
|
+ .or_else(|| v.as_array()) |
|
+ { |
|
+ for it in arr { |
|
+ let title = it |
|
+ .get("title") |
|
+ .and_then(|x| x.as_str()) |
|
+ .unwrap_or("") |
|
+ .to_string(); |
|
+ let snippet = it |
|
+ .get("snippet") |
|
+ .and_then(|x| x.as_str()) |
|
+ .unwrap_or("") |
|
+ .to_string(); |
|
+ let score = it |
|
+ .get("score") |
|
+ .and_then(|x| { |
|
+ x.as_f64() |
|
+ .or_else(|| x.as_str().and_then(|s| s.parse::<f64>().ok())) |
|
+ }) |
|
+ .unwrap_or(0.0) as f32; |
|
+ let key = it |
|
+ .get("key") |
|
+ .and_then(|x| x.as_str()) |
|
+ .unwrap_or("") |
|
+ .to_string(); |
|
+ let ts = it |
|
+ .get("ts") |
|
+ .and_then(|x| x.as_str()) |
|
+ .map(std::string::ToString::to_string); |
|
+ if !key.is_empty() || !title.is_empty() { |
|
+ out.push((title, snippet, score, key, ts)); |
|
+ } |
|
+ } |
|
+ } |
|
+ for part in &res.content { |
|
+ if let mcp_types::ContentBlock::TextContent(t) = part |
|
+ && let Ok(v) = serde_json::from_str::<serde_json::Value>(&t.text) |
|
+ && let Some(arr) = v |
|
+ .get("items") |
|
+ .and_then(|x| x.as_array()) |
|
+ .or_else(|| v.as_array()) |
|
+ { |
|
+ for it in arr { |
|
+ let title = it |
|
+ .get("title") |
|
+ .and_then(|x| x.as_str()) |
|
+ .unwrap_or("") |
|
+ .to_string(); |
|
+ let snippet = it |
|
+ .get("snippet") |
|
+ .and_then(|x| x.as_str()) |
|
+ .unwrap_or("") |
|
+ .to_string(); |
|
+ let score = it |
|
+ .get("score") |
|
+ .and_then(|x| { |
|
+ x.as_f64() |
|
+ .or_else(|| x.as_str().and_then(|s| s.parse::<f64>().ok())) |
|
+ }) |
|
+ .unwrap_or(0.0) as f32; |
|
+ let key = it |
|
+ .get("key") |
|
+ .and_then(|x| x.as_str()) |
|
+ .unwrap_or("") |
|
+ .to_string(); |
|
+ let ts = it |
|
+ .get("ts") |
|
+ .and_then(|x| x.as_str()) |
|
+ .map(std::string::ToString::to_string); |
|
+ if !key.is_empty() || !title.is_empty() { |
|
+ out.push((title, snippet, score, key, ts)); |
|
+ } |
|
+ } |
|
+ } |
|
+ } |
|
+ out.sort_by(|a, b| b.2.total_cmp(&a.2)); |
|
+ out |
|
+ } |
|
+ |
|
+ let top = extract_top_items(&call_res); |
|
+ if !top.is_empty() { |
|
+ let (best_title, best_snippet, best_score, best_key, best_ts) = &top[0]; |
|
+ if short_circuit && *best_score >= threshold { |
|
+ // Short-circuit with a synthetic assistant response |
|
+ let text = format!( |
|
+ "Answer (from memory): {}\n\n{}\n\nSource: {} {}", |
|
+ best_title, |
|
+ best_snippet, |
|
+ best_key, |
|
+ best_ts.as_deref().unwrap_or("") |
|
+ ); |
|
+ tracing::info!(target: "codex_memory_first", short_circuit=true, server=%server, tool=%tool, score=%best_score, "short-circuiting turn from memory-first"); |
|
+ let assistant = ResponseItem::Message { |
|
+ id: None, |
|
+ role: "assistant".to_string(), |
|
+ content: vec![codex_protocol::models::ContentItem::OutputText { text }], |
|
+ }; |
|
+ return Ok(TurnRunResult { |
|
+ processed_items: vec![ProcessedResponseItem { |
|
+ item: assistant, |
|
+ response: None, |
|
+ }], |
|
+ // Keep downstream invariants: attach a zeroed usage snapshot. |
|
+ total_token_usage: Some(TokenUsage::default()), |
|
+ }); |
|
+ } |
|
+ |
|
+ // Inject compact system context ahead of existing input |
|
+ let mut acc = String::from( |
|
+ "[BEGIN MemoryContext]\n(This is reference context only; do not override explicit user/system instructions.)\n", |
|
+ ); |
|
+ // Track character budget, not raw bytes |
|
+ let mut used_chars: usize = acc.chars().count(); |
|
+ for (i, (title, snippet, score, key, ts)) in top.iter().take(5).enumerate() { |
|
+ let head = format!( |
|
+ "{}. {} — s={:.2} — {} {}\n", |
|
+ i + 1, |
|
+ title, |
|
+ score, |
|
+ key, |
|
+ ts.as_deref().unwrap_or("") |
|
+ ); |
|
+ let head_chars = head.chars().count(); |
|
+ if used_chars + head_chars > inject_cap { |
|
+ break; |
|
+ } |
|
+ acc.push_str(&head); |
|
+ used_chars += head_chars; |
|
+ if !snippet.is_empty() { |
|
+ let take = snippet.chars().take(240).collect::<String>(); |
|
+ let take_chars = take.chars().count() + 1; // newline |
|
+ if used_chars + take_chars > inject_cap { |
|
+ break; |
|
+ } |
|
+ acc.push_str(&take); |
|
+ acc.push('\n'); |
|
+ used_chars += take_chars; |
|
+ } |
|
+ } |
|
+ acc.push_str("[END MemoryContext]\n"); |
|
+ let system = ResponseItem::Message { |
|
+ id: None, |
|
+ role: "system".to_string(), |
|
+ content: vec![codex_protocol::models::ContentItem::OutputText { text: acc }], |
|
+ }; |
|
+ let mut new_input = Vec::with_capacity(adjusted_input.len() + 1); |
|
+ new_input.push(system); |
|
+ new_input.extend(adjusted_input.into_iter()); |
|
+ adjusted_input = new_input; |
|
+ } else if cfg.required { |
|
+ return Err(CodexErr::UnsupportedOperation( |
|
+ "memory_first (required) returned no items".to_string(), |
|
+ )); |
|
+ } |
|
+ } else if cfg.required { |
|
+ return Err(CodexErr::UnsupportedOperation( |
|
+ "memory_first (required) has no user text to query".to_string(), |
|
+ )); |
|
+ } |
|
+ } |
|
|
|
let prompt = Prompt { |
|
- input, |
|
+ input: adjusted_input, |
|
tools, |
|
base_instructions_override: turn_context.base_instructions.clone(), |
|
output_schema: turn_context.final_output_json_schema.clone(), |
|
@@ -3409,6 +3674,7 @@ mod tests { |
|
tools_config, |
|
is_review_mode: false, |
|
final_output_json_schema: None, |
|
+ memory_first: config.memory_first.clone(), |
|
}; |
|
let services = SessionServices { |
|
mcp_connection_manager: McpConnectionManager::default(), |
|
@@ -3476,6 +3742,7 @@ mod tests { |
|
tools_config, |
|
is_review_mode: false, |
|
final_output_json_schema: None, |
|
+ memory_first: config.memory_first.clone(), |
|
}); |
|
let services = SessionServices { |
|
mcp_connection_manager: McpConnectionManager::default(), |
|
@@ -3599,6 +3866,91 @@ mod tests { |
|
); |
|
} |
|
|
|
+ #[tokio::test] |
|
+ async fn memory_first_required_errors_when_no_user_text() { |
|
+ let (session, mut turn_context) = make_session_and_context(); |
|
+ // Enable memory_first with required=true |
|
+ turn_context.memory_first = Some(super::MemoryFirstConfig { |
|
+ enable: true, |
|
+ required: true, |
|
+ server: None, |
|
+ tool: None, |
|
+ scope: None, |
|
+ k: None, |
|
+ depth: None, |
|
+ timeout_ms: Some(50), |
|
+ short_circuit: Some(false), |
|
+ confidence_threshold: Some(0.8), |
|
+ inject_max_chars: Some(256), |
|
+ }); |
|
+ |
|
+ // No user text in input → should fail‑closed |
|
+ let input: Vec<ResponseItem> = Vec::new(); |
|
+ let mut tracker = TurnDiffTracker::new(); |
|
+ let res = super::run_turn( |
|
+ &session, |
|
+ &turn_context, |
|
+ &mut tracker, |
|
+ "sub-mem-none".to_string(), |
|
+ input, |
|
+ ) |
|
+ .await; |
|
+ match res { |
|
+ Err(super::CodexErr::UnsupportedOperation(msg)) => { |
|
+ assert!(msg.contains("no user text"), "unexpected msg: {msg}"); |
|
+ } |
|
+ other => panic!("expected fail-closed for no user text, got: {other:?}"), |
|
+ } |
|
+ } |
|
+ |
|
+ #[tokio::test] |
|
+ async fn memory_first_required_errors_when_no_items() { |
|
+ let (session, mut turn_context) = make_session_and_context(); |
|
+ turn_context.memory_first = Some(super::MemoryFirstConfig { |
|
+ enable: true, |
|
+ required: true, |
|
+ server: Some("memory".to_string()), // unknown server → immediate error path |
|
+ tool: Some("memory_search".to_string()), |
|
+ scope: None, |
|
+ k: Some(5), |
|
+ depth: None, |
|
+ timeout_ms: Some(50), |
|
+ short_circuit: Some(false), |
|
+ confidence_threshold: Some(0.9), |
|
+ inject_max_chars: Some(256), |
|
+ }); |
|
+ |
|
+ // Provide user text but the MCP call will fail → treated as empty results in best‑effort path, |
|
+ // but here required=true should surface a fail‑closed error later when no items are found. |
|
+ let input: Vec<ResponseItem> = vec![ResponseItem::Message { |
|
+ id: None, |
|
+ role: "user".to_string(), |
|
+ content: vec![ContentItem::InputText { |
|
+ text: "test query".to_string(), |
|
+ }], |
|
+ }]; |
|
+ let mut tracker = TurnDiffTracker::new(); |
|
+ let res = super::run_turn( |
|
+ &session, |
|
+ &turn_context, |
|
+ &mut tracker, |
|
+ "sub-mem-empty".to_string(), |
|
+ input, |
|
+ ) |
|
+ .await; |
|
+ match res { |
|
+ Err(super::CodexErr::UnsupportedOperation(msg)) => { |
|
+ // Depending on transport timing, this may fail early (unknown server) |
|
+ // or later (no items). Accept either phrasing. |
|
+ assert!( |
|
+ msg.contains("returned no items") || msg.contains("failed for"), |
|
+ "unexpected msg: {msg}" |
|
+ ); |
|
+ } |
|
+ other => panic!("expected fail-closed for empty/failed results, got: {other:?}"), |
|
+ } |
|
+ } |
|
+ |
|
fn sample_rollout( |
|
session: &Session, |
|
turn_context: &TurnContext, |
|
diff --git a/codex-rs/core/src/config.rs b/codex-rs/core/src/config.rs |
|
index 292b9f7b..ec4f4e59 100644 |
|
--- a/codex-rs/core/src/config.rs |
|
+++ b/codex-rs/core/src/config.rs |
|
@@ -199,6 +199,12 @@ pub struct Config { |
|
/// All characters are inserted as they are received, and no buffering |
|
/// or placeholder replacement will occur for fast keypress bursts. |
|
pub disable_paste_burst: bool, |
|
+ |
|
+ /// Optional memory-first pre-turn hook configuration. |
|
+ pub memory_first: Option<MemoryFirstConfig>, |
|
+ |
|
+ /// Optional generic pre-hooks to run before sending the initial request. |
|
+ pub pre_hooks: Option<PreHooksConfig>, |
|
} |
|
|
|
impl Config { |
|
@@ -719,6 +725,14 @@ pub struct ConfigToml { |
|
/// All characters are inserted as they are received, and no buffering |
|
/// or placeholder replacement will occur for fast keypress bursts. |
|
pub disable_paste_burst: Option<bool>, |
|
+ |
|
+ /// Optional memory-first pre-turn hook configuration. |
|
+ #[serde(default)] |
|
+ pub memory_first: Option<MemoryFirstToml>, |
|
+ |
|
+ /// Optional generic pre-hooks that run prior to submitting the prompt. |
|
+ #[serde(default)] |
|
+ pub pre_hooks: Option<PreHooksToml>, |
|
} |
|
|
|
impl From<ConfigToml> for UserSavedConfig { |
|
@@ -842,6 +856,128 @@ impl ConfigToml { |
|
} |
|
} |
|
|
|
+/// TOML representation of the memory-first pre-turn hook config. |
|
+#[derive(Deserialize, Debug, Clone, Default, PartialEq)] |
|
+pub struct MemoryFirstToml { |
|
+ #[serde(default)] |
|
+ pub enable: Option<bool>, |
|
+ #[serde(default)] |
|
+ pub required: Option<bool>, |
|
+ pub server: Option<String>, |
|
+ pub tool: Option<String>, |
|
+ pub scope: Option<String>, |
|
+ pub k: Option<u32>, |
|
+ pub depth: Option<u32>, |
|
+ pub timeout_ms: Option<u64>, |
|
+ pub short_circuit: Option<bool>, |
|
+ pub confidence_threshold: Option<f32>, |
|
+ pub inject_max_chars: Option<usize>, |
|
+} |
|
+ |
|
+/// Runtime configuration for the optional memory-first pre-turn hook. |
|
+#[derive(Debug, Clone, PartialEq)] |
|
+pub struct MemoryFirstConfig { |
|
+ pub enable: bool, |
|
+ pub required: bool, |
|
+ pub server: Option<String>, |
|
+ pub tool: Option<String>, |
|
+ pub scope: Option<String>, |
|
+ pub k: Option<u32>, |
|
+ pub depth: Option<u32>, |
|
+ pub timeout_ms: Option<u64>, |
|
+ pub short_circuit: Option<bool>, |
|
+ pub confidence_threshold: Option<f32>, |
|
+ pub inject_max_chars: Option<usize>, |
|
+} |
|
+ |
|
+impl From<MemoryFirstToml> for MemoryFirstConfig { |
|
+ fn from(t: MemoryFirstToml) -> Self { |
|
+ Self { |
|
+ enable: t.enable.unwrap_or(false), |
|
+ required: t.required.unwrap_or(false), |
|
+ server: t.server, |
|
+ tool: t.tool, |
|
+ scope: t.scope, |
|
+ k: t.k, |
|
+ depth: t.depth, |
|
+ timeout_ms: t.timeout_ms, |
|
+ short_circuit: t.short_circuit, |
|
+ confidence_threshold: t.confidence_threshold, |
|
+ inject_max_chars: t.inject_max_chars, |
|
+ } |
|
+ } |
|
+} |
|
+ |
|
+/// TOML representation of generic pre-hooks. |
|
+#[derive(Deserialize, Debug, Clone, Default, PartialEq)] |
|
+pub struct PreHooksToml { |
|
+ /// When true, pre-hooks are enabled. |
|
+ #[serde(default)] |
|
+ pub enable: Option<bool>, |
|
+ /// When true, any failing step is treated as fatal (unless a step overrides required=false). |
|
+ #[serde(default)] |
|
+ pub required: Option<bool>, |
|
+ /// Steps to execute in order; each step is a command and args. |
|
+ #[serde(default)] |
|
+ pub steps: Vec<PreHookStepToml>, |
|
+} |
|
+ |
|
+#[derive(Deserialize, Debug, Clone, Default, PartialEq)] |
|
+pub struct PreHookStepToml { |
|
+ /// Command and arguments to execute. |
|
+ #[serde(default)] |
|
+ pub cmd: Vec<String>, |
|
+ /// If set, overrides the global `required` for this step. |
|
+ #[serde(default)] |
|
+ pub required: Option<bool>, |
|
+ /// Optional working directory for the step. |
|
+ pub cwd: Option<PathBuf>, |
|
+ /// Optional environment variables for the step. |
|
+ #[serde(default)] |
|
+ pub env: HashMap<String, String>, |
|
+ /// Optional timeout for the step in milliseconds. |
|
+ pub timeout_ms: Option<u64>, |
|
+} |
|
+ |
|
+/// Runtime configuration for generic pre-hooks. |
|
+#[derive(Debug, Clone, PartialEq)] |
|
+pub struct PreHooksConfig { |
|
+ pub enable: bool, |
|
+ pub required: bool, |
|
+ pub steps: Vec<PreHookStep>, |
|
+} |
|
+ |
|
+#[derive(Debug, Clone, PartialEq)] |
|
+pub struct PreHookStep { |
|
+ pub cmd: Vec<String>, |
|
+ pub required: bool, |
|
+ pub cwd: Option<PathBuf>, |
|
+ pub env: HashMap<String, String>, |
|
+ pub timeout_ms: Option<u64>, |
|
+} |
|
+ |
|
+impl From<PreHooksToml> for PreHooksConfig { |
|
+ fn from(t: PreHooksToml) -> Self { |
|
+ let global_required = t.required.unwrap_or(false); |
|
+ let steps = t |
|
+ .steps |
|
+ .into_iter() |
|
+ .map(|s| PreHookStep { |
|
+ required: s.required.unwrap_or(global_required), |
|
+ cmd: s.cmd, |
|
+ cwd: s.cwd, |
|
+ env: s.env, |
|
+ timeout_ms: s.timeout_ms, |
|
+ }) |
|
+ .collect(); |
|
+ Self { |
|
+ enable: t.enable.unwrap_or(false), |
|
+ required: global_required, |
|
+ steps, |
|
+ } |
|
+ } |
|
+} |
|
+ |
|
/// Optional overrides for user configuration (e.g., from CLI flags). |
|
#[derive(Default, Debug, Clone)] |
|
pub struct ConfigOverrides { |
|
@@ -1068,6 +1204,8 @@ impl Config { |
|
.as_ref() |
|
.map(|t| t.notifications.clone()) |
|
.unwrap_or_default(), |
|
+ memory_first: cfg.memory_first.clone().map(Into::into), |
|
+ pre_hooks: cfg.pre_hooks.clone().map(Into::into), |
|
}; |
|
Ok(config) |
|
} |
|
@@ -1809,6 +1947,8 @@ model_verbosity = "high" |
|
active_profile: Some("o3".to_string()), |
|
disable_paste_burst: false, |
|
tui_notifications: Default::default(), |
|
+ memory_first: None, |
|
+ pre_hooks: None, |
|
}, |
|
o3_profile_config |
|
); |
|
@@ -1868,6 +2008,8 @@ model_verbosity = "high" |
|
active_profile: Some("gpt3".to_string()), |
|
disable_paste_burst: false, |
|
tui_notifications: Default::default(), |
|
+ memory_first: None, |
|
+ pre_hooks: None, |
|
}; |
|
|
|
assert_eq!(expected_gpt3_profile_config, gpt3_profile_config); |
|
@@ -1942,6 +2084,8 @@ model_verbosity = "high" |
|
active_profile: Some("zdr".to_string()), |
|
disable_paste_burst: false, |
|
tui_notifications: Default::default(), |
|
+ memory_first: None, |
|
+ pre_hooks: None, |
|
}; |
|
|
|
assert_eq!(expected_zdr_profile_config, zdr_profile_config); |
|
@@ -2002,6 +2146,8 @@ model_verbosity = "high" |
|
active_profile: Some("gpt5".to_string()), |
|
disable_paste_burst: false, |
|
tui_notifications: Default::default(), |
|
+ memory_first: None, |
|
+ pre_hooks: None, |
|
}; |
|
|
|
assert_eq!(expected_gpt5_profile_config, gpt5_profile_config); |
|
diff --git a/codex-rs/exec/src/cli.rs b/codex-rs/exec/src/cli.rs |
|
index 0df114cb..9c8740d6 100644 |
|
--- a/codex-rs/exec/src/cli.rs |
|
+++ b/codex-rs/exec/src/cli.rs |
|
@@ -56,6 +56,18 @@ pub struct Cli { |
|
#[arg(long = "output-schema", value_name = "FILE")] |
|
pub output_schema: Option<PathBuf>, |
|
|
|
+ /// Enable generic pre-hooks (runs before sending the prompt). |
|
+ #[arg(long = "pre-hooks-enable", default_value_t = false)] |
|
+ pub pre_hooks_enable: bool, |
|
+ |
|
+ /// Treat pre-hook failures as fatal (unless a step marks required=false). |
|
+ #[arg(long = "pre-hooks-required", default_value_t = false)] |
|
+ pub pre_hooks_required: bool, |
|
+ |
|
+ /// Command to run as a pre-hook step (may be repeated). |
|
+ #[arg(long = "pre-hook", value_name = "CMD", action = clap::ArgAction::Append)] |
|
+ pub pre_hook: Vec<String>, |
|
+ |
|
#[clap(skip)] |
|
pub config_overrides: CliConfigOverrides, |
|
|
|
diff --git a/codex-rs/exec/src/lib.rs b/codex-rs/exec/src/lib.rs |
|
index da23fb1b..7e4a2bda 100644 |
|
--- a/codex-rs/exec/src/lib.rs |
|
+++ b/codex-rs/exec/src/lib.rs |
|
@@ -4,6 +4,7 @@ mod event_processor_with_human_output; |
|
pub mod event_processor_with_json_output; |
|
pub mod exec_events; |
|
pub mod experimental_event_processor_with_json_output; |
|
+mod pre_hooks; |
|
|
|
use std::io::IsTerminal; |
|
use std::io::Read; |
|
@@ -57,8 +58,11 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any |
|
sandbox_mode: sandbox_mode_cli_arg, |
|
prompt, |
|
output_schema: output_schema_path, |
|
+ pre_hooks_enable, |
|
+ pre_hooks_required, |
|
+ pre_hook, |
|
include_plan_tool, |
|
- config_overrides, |
|
+ mut config_overrides, |
|
} = cli; |
|
|
|
// Determine the prompt source (parent or subcommand) and read from stdin if needed. |
|
@@ -172,7 +176,36 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any |
|
show_raw_agent_reasoning: oss.then_some(true), |
|
tools_web_search_request: None, |
|
}; |
|
- // Parse `-c` overrides. |
|
+ // Translate explicit pre-hook CLI flags into `-c` style overrides. |
|
+ if pre_hooks_enable { |
|
+ config_overrides |
|
+ .raw_overrides |
|
+ .push("pre_hooks.enable=true".to_string()); |
|
+ } |
|
+ if pre_hooks_required { |
|
+ config_overrides |
|
+ .raw_overrides |
|
+ .push("pre_hooks.required=true".to_string()); |
|
+ } |
|
+ if !pre_hook.is_empty() { |
|
+ // Convert repeated --pre-hook strings into a TOML array of arrays of strings |
|
+ // e.g. pre_hooks.steps = [["echo","hi"],["cargo","check"]] |
|
+ let mut parts: Vec<String> = Vec::new(); |
|
+ for raw in pre_hook.into_iter() { |
|
+ let tokens = shlex::split(&raw).unwrap_or_else(|| vec![raw.clone()]); |
|
+ let quoted: Vec<String> = tokens |
|
+ .into_iter() |
|
+ .map(|t| format!("\"{}\"", t.replace('\\', "\\\\").replace('"', "\\\""))) |
|
+ .collect(); |
|
+ parts.push(format!("[{}]", quoted.join(","))); |
|
+ } |
|
+ let array = format!("[{}]", parts.join(",")); |
|
+ config_overrides |
|
+ .raw_overrides |
|
+ .push(format!("pre_hooks.steps={array}")); |
|
+ } |
|
+ |
|
+ // Parse `-c` overrides (including translated pre-hook flags). |
|
let cli_kv_overrides = match config_overrides.parse_overrides() { |
|
Ok(v) => v, |
|
Err(e) => { |
|
@@ -253,6 +286,12 @@ pub async fn run_main(cli: Cli, codex_linux_sandbox_exe: Option<PathBuf>) -> any |
|
|
|
info!("Codex initialized with event: {session_configured:?}"); |
|
|
|
+ // Run generic pre-hooks, if enabled, before sending any input to the agent. |
|
+ if let Err(e) = pre_hooks::run_pre_hooks(&config).await { |
|
+ eprintln!("Pre-hooks failed: {e}"); |
|
+ std::process::exit(1); |
|
+ } |
|
+ |
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Event>(); |
|
{ |
|
let conversation = conversation.clone(); |
|
diff --git a/docs/config.md b/docs/config.md |
|
index ba204ee0..8f9acf7d 100644 |
|
--- a/docs/config.md |
|
+++ b/docs/config.md |
|
@@ -130,6 +130,32 @@ Number of times Codex will attempt to reconnect when a streaming response is int |
|
|
|
How long Codex will wait for activity on a streaming response before treating the connection as lost. Defaults to `300_000` (5 minutes). |
|
|
|
+## memory_first (pre‑turn hook) |
|
+ |
|
+Run a memory lookup before any LLM call each turn. When enabled, Codex performs a short, bounded MCP tool call using the latest user text and either short‑circuits with a high‑confidence answer or injects a compact context block at the front of the prompt. Disabled by default. |
|
+ |
|
+```toml |
|
+[memory_first] |
|
+enable = true |
|
+required = true # fail‑closed: do not call the model if the memory step errors/times out |
|
+server = "memory" # MCP server name |
|
+tool = "memory_search" # or "memory_recall" |
|
+scope = "project" # optional filter, tool‑specific |
|
+k = 5 # top‑k items to fetch |
|
+depth = 1 # for recall, optional |
|
+# Defaults tuned for stability across transports; you may lower locally. |
|
+timeout_ms = 1000 # per‑call timebox (ms) |
|
+short_circuit = true # return a synthetic answer when score ≥ threshold |
|
+confidence_threshold = 0.8 # clamped into [0.0, 1.0] |
|
+inject_max_chars = 1500 # cap for injected context characters |
|
+``` |
|
+ |
|
+Notes |
|
+- MCP server must be defined under `[mcp_servers]`; see the MCP section in this document for examples. |
|
+- When `required = true`, failures (timeout, is_error=true, empty results) abort the turn and no model call is made. |
|
+- The injected block is wrapped in `[BEGIN/END MemoryContext]` and marked as reference‑only to avoid overriding explicit instructions. |
|
+- The builder tracks characters, not bytes, to avoid overruns with multi‑byte text. |
|
+ |
|
## model_provider |
|
|
|
Identifies which provider to use from the `model_providers` map. Defaults to `"openai"`. You can override the `base_url` for the built-in `openai` provider via the `OPENAI_BASE_URL` environment variable. |