vendor: update cargo-cxcloud-bot-0.1.0
This commit is contained in:
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "927a31cbf65f78c3ef6b729631b2fc35335afe06",
|
||||
"dirty": true
|
||||
},
|
||||
"path_in_vcs": "services/cxcloud-rs/crates/bot"
|
||||
}
|
||||
Generated
+2767
File diff suppressed because it is too large
Load Diff
+90
@@ -0,0 +1,90 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2021"
|
||||
name = "cxcloud-bot"
|
||||
version = "0.1.0"
|
||||
build = false
|
||||
publish = ["cxai"]
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
readme = false
|
||||
|
||||
[[bin]]
|
||||
name = "bot-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies.anyhow]
|
||||
version = "1"
|
||||
|
||||
[dependencies.axum]
|
||||
version = "0.7"
|
||||
features = ["macros"]
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxcloud-common]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.cxcloud-proto]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.prost]
|
||||
version = "0.13"
|
||||
|
||||
[dependencies.prost-types]
|
||||
version = "0.13"
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tonic]
|
||||
version = "0.12"
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.url]
|
||||
version = "2"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v7",
|
||||
"serde",
|
||||
]
|
||||
Generated
+27
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "cxcloud-bot"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "bot-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cxcloud-common = { workspace = true }
|
||||
cxcloud-proto = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tonic = { workspace = true }
|
||||
prost = { workspace = true }
|
||||
prost-types = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
url = "2"
|
||||
@@ -0,0 +1,107 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Per-tool circuit breaker. Opens after `threshold` consecutive failures.
|
||||
pub struct CircuitBreaker {
|
||||
breakers: Mutex<HashMap<String, BreakerState>>,
|
||||
threshold: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BreakerState {
|
||||
failures: u32,
|
||||
open: bool,
|
||||
}
|
||||
|
||||
impl CircuitBreaker {
|
||||
pub fn new(threshold: u32) -> Self {
|
||||
Self {
|
||||
breakers: Mutex::new(HashMap::new()),
|
||||
threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool's circuit is open (should not be called).
|
||||
pub fn is_open(&self, tool: &str) -> bool {
|
||||
self.breakers
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(tool)
|
||||
.map(|b| b.open)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Record a success, resetting the failure counter.
|
||||
pub fn record_success(&self, tool: &str) {
|
||||
let mut breakers = self.breakers.lock().unwrap();
|
||||
breakers.insert(
|
||||
tool.to_string(),
|
||||
BreakerState {
|
||||
failures: 0,
|
||||
open: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Record a failure. Opens the circuit if threshold is reached.
|
||||
pub fn record_failure(&self, tool: &str) {
|
||||
let mut breakers = self.breakers.lock().unwrap();
|
||||
let state = breakers.entry(tool.to_string()).or_insert(BreakerState {
|
||||
failures: 0,
|
||||
open: false,
|
||||
});
|
||||
state.failures += 1;
|
||||
if state.failures >= self.threshold {
|
||||
state.open = true;
|
||||
tracing::warn!(tool, failures = state.failures, "Circuit breaker opened");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_tool_is_closed() {
|
||||
let cb = CircuitBreaker::new(3);
|
||||
assert!(!cb.is_open("unknown-tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opens_after_threshold_failures() {
|
||||
let cb = CircuitBreaker::new(3);
|
||||
cb.record_failure("flaky");
|
||||
assert!(!cb.is_open("flaky"));
|
||||
cb.record_failure("flaky");
|
||||
assert!(!cb.is_open("flaky"));
|
||||
cb.record_failure("flaky");
|
||||
assert!(cb.is_open("flaky"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn success_resets_failures() {
|
||||
let cb = CircuitBreaker::new(3);
|
||||
cb.record_failure("tool");
|
||||
cb.record_failure("tool");
|
||||
cb.record_success("tool");
|
||||
cb.record_failure("tool");
|
||||
assert!(!cb.is_open("tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn independent_tools_are_independent() {
|
||||
let cb = CircuitBreaker::new(2);
|
||||
cb.record_failure("a");
|
||||
cb.record_failure("a");
|
||||
assert!(cb.is_open("a"));
|
||||
assert!(!cb.is_open("b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn threshold_of_one_opens_immediately() {
|
||||
let cb = CircuitBreaker::new(1);
|
||||
cb.record_failure("tool");
|
||||
assert!(cb.is_open("tool"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Execute tool calls, respecting dependency order.
|
||||
pub async fn execute_tool_calls(
|
||||
registry: &ToolRegistry,
|
||||
tool_calls: &[serde_json::Value],
|
||||
) -> Vec<serde_json::Value> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for call in tool_calls {
|
||||
let tool_name = call
|
||||
.get("tool_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let params = call
|
||||
.get("parameters")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
// Check circuit breaker
|
||||
if registry.circuit_breaker.is_open(tool_name) {
|
||||
error!(tool = tool_name, "Circuit breaker open, skipping");
|
||||
results.push(serde_json::json!({
|
||||
"tool_call_id": call.get("id").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"tool_name": tool_name,
|
||||
"success": false,
|
||||
"error_message": "Circuit breaker open",
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let result = registry.execute(tool_name, ¶ms).await;
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
match result {
|
||||
Ok(output) => {
|
||||
registry.circuit_breaker.record_success(tool_name);
|
||||
info!(tool = tool_name, duration_ms, "Tool execution succeeded");
|
||||
results.push(serde_json::json!({
|
||||
"tool_call_id": call.get("id").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"tool_name": tool_name,
|
||||
"success": true,
|
||||
"output": output,
|
||||
"duration_ms": duration_ms,
|
||||
}));
|
||||
}
|
||||
Err(e) => {
|
||||
registry.circuit_breaker.record_failure(tool_name);
|
||||
error!(tool = tool_name, error = %e, "Tool execution failed");
|
||||
results.push(serde_json::json!({
|
||||
"tool_call_id": call.get("id").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"tool_name": tool_name,
|
||||
"success": false,
|
||||
"error_message": e.to_string(),
|
||||
"duration_ms": duration_ms,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tonic::{Request, Response, Status};
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_proto::bot::{
|
||||
bot_service_server::{BotService, BotServiceServer},
|
||||
ExecutionRequest, ExecutionResponse, ListToolsRequest, ListToolsResponse, ToolInfo,
|
||||
ToolResult,
|
||||
};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub struct BotServiceImpl {
|
||||
state: Arc<AppState>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl BotService for BotServiceImpl {
|
||||
async fn execute(
|
||||
&self,
|
||||
request: Request<ExecutionRequest>,
|
||||
) -> Result<Response<ExecutionResponse>, Status> {
|
||||
let req = request.into_inner();
|
||||
info!(plan_id = %req.plan_id, tools = req.tool_calls.len(), "Execute request via gRPC");
|
||||
|
||||
let tool_calls: Vec<serde_json::Value> = req
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
serde_json::json!({
|
||||
"id": tc.id,
|
||||
"tool_name": tc.tool_name,
|
||||
"parameters": tc.parameters.as_ref().map(|p| format!("{:?}", p)).unwrap_or_default(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = crate::executor::execute_tool_calls(&self.state.registry, &tool_calls).await;
|
||||
|
||||
let tool_results: Vec<ToolResult> = results
|
||||
.iter()
|
||||
.map(|r| ToolResult {
|
||||
tool_call_id: r.get("tool_call_id").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
tool_name: r.get("tool_name").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
success: r.get("success").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
output: None,
|
||||
error_message: r.get("error_message").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
error_code: String::new(),
|
||||
retries_used: 0,
|
||||
duration_ms: r.get("duration_ms").and_then(|v| v.as_i64()).unwrap_or(0),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let all_success = tool_results.iter().all(|r| r.success);
|
||||
|
||||
Ok(Response::new(ExecutionResponse {
|
||||
execution_id: uuid::Uuid::now_v7().to_string(),
|
||||
results: tool_results,
|
||||
status: if all_success { 1 } else { 2 },
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
duration_ms: 0,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn list_tools(
|
||||
&self,
|
||||
_request: Request<ListToolsRequest>,
|
||||
) -> Result<Response<ListToolsResponse>, Status> {
|
||||
let tools: Vec<ToolInfo> = self
|
||||
.state
|
||||
.registry
|
||||
.list_tools()
|
||||
.iter()
|
||||
.map(|t| ToolInfo {
|
||||
name: t.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
description: t.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
category: t.get("category").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
available: t.get("available").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
schema: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Response::new(ListToolsResponse { tools }))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve(state: Arc<AppState>, port: u16) -> anyhow::Result<()> {
|
||||
let addr = format!("0.0.0.0:{port}").parse()?;
|
||||
let service = BotServiceImpl { state };
|
||||
|
||||
info!(port, "Bot gRPC server starting");
|
||||
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(BotServiceServer::new(service))
|
||||
.serve(addr)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
+119
@@ -0,0 +1,119 @@
|
||||
mod circuit_breaker;
|
||||
mod executor;
|
||||
mod grpc_server;
|
||||
pub mod tools;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_common::{
|
||||
config::{env_or, env_port},
|
||||
health::HealthResponse,
|
||||
telemetry,
|
||||
};
|
||||
|
||||
use tools::registry::ToolRegistry;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub registry: Arc<ToolRegistry>,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let otel_endpoint = env_or("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317");
|
||||
let log_level = env_or("LOG_LEVEL", "info");
|
||||
telemetry::init("bot-service", &otel_endpoint, &log_level);
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
let registry = ToolRegistry::new(http_client.clone());
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
registry: Arc::new(registry),
|
||||
http_client,
|
||||
});
|
||||
|
||||
// Spawn gRPC server
|
||||
let grpc_port = env_port("BOT_GRPC_PORT", 50053);
|
||||
let grpc_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = grpc_server::serve(grpc_state, grpc_port).await {
|
||||
tracing::error!(error = %e, "gRPC server failed");
|
||||
}
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/ready", get(ready))
|
||||
.route("/metrics", get(metrics))
|
||||
.route("/execute", post(execute))
|
||||
.with_state(state);
|
||||
|
||||
let port = env_port("BOT_HTTP_PORT", 8002);
|
||||
info!(port, grpc_port, "Bot service starting");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}")).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
telemetry::shutdown();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health() -> Json<HealthResponse> {
|
||||
Json(HealthResponse::healthy("bot-service"))
|
||||
}
|
||||
|
||||
async fn ready(State(state): State<Arc<AppState>>) -> StatusCode {
|
||||
if state.registry.tool_count() > 0 {
|
||||
StatusCode::OK
|
||||
} else {
|
||||
StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
async fn metrics(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"service": "bot-service",
|
||||
"registered_tools": state.registry.tool_count(),
|
||||
"executions": 0,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ExecutePayload {
|
||||
plan_id: String,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
plan: Option<String>,
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<ExecutePayload>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let results = executor::execute_tool_calls(&state.registry, &body.tool_calls).await;
|
||||
|
||||
Json(serde_json::json!({
|
||||
"execution_id": uuid::Uuid::now_v7().to_string(),
|
||||
"plan_id": body.plan_id,
|
||||
"results": results,
|
||||
"status": if results.iter().all(|r| r.get("success").and_then(|v| v.as_bool()).unwrap_or(false)) {
|
||||
"SUCCESS"
|
||||
} else {
|
||||
"PARTIAL_FAILURE"
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
use anyhow::{bail, Result};
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::info;
|
||||
|
||||
const DATA_DIR: &str = "/data";
|
||||
const MAX_READ_SIZE: u64 = 10 * 1024 * 1024; // 10 MB
|
||||
|
||||
/// Read a file from the sandboxed /data directory.
|
||||
pub async fn read(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let path_str = params
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("path parameter is required"))?;
|
||||
|
||||
let path = sanitize_path(path_str)?;
|
||||
info!(path = %path.display(), "Reading file");
|
||||
|
||||
let metadata = tokio::fs::metadata(&path).await?;
|
||||
if metadata.len() > MAX_READ_SIZE {
|
||||
bail!("File too large: {} bytes (max {})", metadata.len(), MAX_READ_SIZE);
|
||||
}
|
||||
|
||||
let content = tokio::fs::read_to_string(&path).await?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"path": path_str,
|
||||
"size": metadata.len(),
|
||||
"content": content,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Write a file to the sandboxed /data directory.
|
||||
pub async fn write(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let path_str = params
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("path parameter is required"))?;
|
||||
|
||||
let content = params
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("content parameter is required"))?;
|
||||
|
||||
let path = sanitize_path(path_str)?;
|
||||
info!(path = %path.display(), size = content.len(), "Writing file");
|
||||
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
tokio::fs::write(&path, content).await?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"path": path_str,
|
||||
"size": content.len(),
|
||||
"written": true,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Sanitize and validate that the path stays within /data.
|
||||
fn sanitize_path(input: &str) -> Result<PathBuf> {
|
||||
let base = Path::new(DATA_DIR);
|
||||
let full = base.join(input.trim_start_matches('/'));
|
||||
let canonical_base = base.to_path_buf();
|
||||
|
||||
// Prevent path traversal
|
||||
if !full.starts_with(&canonical_base) {
|
||||
bail!("Path traversal detected: {input}");
|
||||
}
|
||||
|
||||
Ok(full)
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use anyhow::{bail, Result};
|
||||
use std::net::IpAddr;
|
||||
use tracing::info;
|
||||
|
||||
const MAX_RESPONSE_SIZE: usize = 1024 * 1024; // 1 MB
|
||||
|
||||
/// Execute an HTTP request tool call with SSRF protection.
|
||||
pub async fn execute(client: &reqwest::Client, params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let url = params
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("url parameter is required"))?;
|
||||
|
||||
let method = params
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_uppercase();
|
||||
|
||||
// SSRF protection: block private/localhost addresses
|
||||
check_ssrf(url)?;
|
||||
|
||||
info!(method, url, "Executing HTTP request");
|
||||
|
||||
let mut req = match method.as_str() {
|
||||
"GET" => client.get(url),
|
||||
"POST" => client.post(url),
|
||||
"PUT" => client.put(url),
|
||||
"PATCH" => client.patch(url),
|
||||
"DELETE" => client.delete(url),
|
||||
"HEAD" => client.head(url),
|
||||
_ => bail!("Unsupported HTTP method: {method}"),
|
||||
};
|
||||
|
||||
// Add headers
|
||||
if let Some(headers) = params.get("headers").and_then(|v| v.as_object()) {
|
||||
for (key, value) in headers {
|
||||
if let Some(val) = value.as_str() {
|
||||
req = req.header(key.as_str(), val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add body
|
||||
if let Some(body) = params.get("body") {
|
||||
req = req.json(body);
|
||||
}
|
||||
|
||||
let resp = req.send().await?;
|
||||
let status = resp.status().as_u16();
|
||||
let headers: std::collections::HashMap<String, String> = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let body = resp.text().await?;
|
||||
let body_truncated = if body.len() > MAX_RESPONSE_SIZE {
|
||||
format!("{}... [truncated at {} bytes]", &body[..MAX_RESPONSE_SIZE], body.len())
|
||||
} else {
|
||||
body
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"status": status,
|
||||
"headers": headers,
|
||||
"body": body_truncated,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Check for SSRF: block requests to private/localhost addresses.
|
||||
fn check_ssrf(url: &str) -> Result<()> {
|
||||
let parsed = url::Url::parse(url).map_err(|_| anyhow::anyhow!("Invalid URL: {url}"))?;
|
||||
|
||||
if let Some(host) = parsed.host_str() {
|
||||
// Block localhost variants
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" || host == "0.0.0.0" {
|
||||
bail!("SSRF blocked: requests to localhost are not allowed");
|
||||
}
|
||||
|
||||
// Block private IP ranges
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
if ipv4.is_private()
|
||||
|| ipv4.is_loopback()
|
||||
|| ipv4.is_link_local()
|
||||
|| ipv4.octets()[0] == 169 && ipv4.octets()[1] == 254
|
||||
{
|
||||
bail!("SSRF blocked: requests to private IP ranges are not allowed");
|
||||
}
|
||||
}
|
||||
IpAddr::V6(ipv6) => {
|
||||
if ipv6.is_loopback() {
|
||||
bail!("SSRF blocked: requests to loopback addresses are not allowed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
pub mod file_ops;
|
||||
pub mod http_request;
|
||||
pub mod notification;
|
||||
pub mod registry;
|
||||
pub mod shell;
|
||||
@@ -0,0 +1,24 @@
|
||||
use anyhow::Result;
|
||||
use tracing::info;
|
||||
|
||||
/// Send a notification.
|
||||
pub async fn execute(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let channel = params
|
||||
.get("channel")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
|
||||
let message = params
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("message parameter is required"))?;
|
||||
|
||||
info!(channel, message, "Sending notification");
|
||||
|
||||
// In a real implementation, this would dispatch to Slack, email, webhook, etc.
|
||||
Ok(serde_json::json!({
|
||||
"channel": channel,
|
||||
"sent": true,
|
||||
"message": message,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
use crate::circuit_breaker::CircuitBreaker;
|
||||
use anyhow::Result;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
pub http_client: reqwest::Client,
|
||||
pub circuit_breaker: CircuitBreaker,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new(http_client: reqwest::Client) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
circuit_breaker: CircuitBreaker::new(3),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_count(&self) -> usize {
|
||||
5 // http_request, file_read, file_write, notification, shell_exec
|
||||
}
|
||||
|
||||
pub fn list_tools(&self) -> Vec<serde_json::Value> {
|
||||
vec![
|
||||
serde_json::json!({"name": "http_request", "description": "Make HTTP requests", "category": "http", "available": true}),
|
||||
serde_json::json!({"name": "file_read", "description": "Read files from /data", "category": "file", "available": true}),
|
||||
serde_json::json!({"name": "file_write", "description": "Write files to /data", "category": "file", "available": true}),
|
||||
serde_json::json!({"name": "notification", "description": "Send notifications", "category": "notification", "available": true}),
|
||||
serde_json::json!({"name": "shell_exec", "description": "Execute shell commands", "category": "system", "available": true}),
|
||||
]
|
||||
}
|
||||
|
||||
/// Dispatch execution to the appropriate tool handler.
|
||||
pub async fn execute(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
params: &serde_json::Value,
|
||||
) -> Result<serde_json::Value> {
|
||||
match tool_name {
|
||||
"http_request" => super::http_request::execute(&self.http_client, params).await,
|
||||
"file_read" => super::file_ops::read(params).await,
|
||||
"file_write" => super::file_ops::write(params).await,
|
||||
"notification" => super::notification::execute(params).await,
|
||||
"shell_exec" => super::shell::execute(params).await,
|
||||
_ => anyhow::bail!("Unknown tool: {tool_name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use anyhow::Result;
|
||||
use tracing::info;
|
||||
|
||||
/// Execute a shell command.
|
||||
pub async fn execute(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let command = params
|
||||
.get("command")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("command parameter is required"))?;
|
||||
|
||||
let timeout_secs = params
|
||||
.get("timeout_seconds")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(30);
|
||||
|
||||
info!(command, timeout_secs, "Executing shell command");
|
||||
|
||||
let output = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
tokio::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Command timed out after {timeout_secs}s"))??;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
let exit_code = output.status.code().unwrap_or(-1);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"exit_code": exit_code,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"success": output.status.success(),
|
||||
}))
|
||||
}
|
||||
Reference in New Issue
Block a user