vendor: update cxos-vendor-cargo
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "d878deb8441897ecdd416011b49d2d2f6112e867"
|
||||
},
|
||||
"path_in_vcs": "crates/cxai-bridge"
|
||||
}
|
||||
Generated
+2071
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,84 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "cxai-bridge"
|
||||
version = "0.1.0"
|
||||
authors = ["CxAI-LLM <agent@cxai-studio.com>"]
|
||||
build = false
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
description = "FFI bridge for C++ interop, JSON-RPC proxy, and SignalR WebSocket client"
|
||||
homepage = "https://cxllm.io"
|
||||
readme = false
|
||||
license = "MIT"
|
||||
repository = "https://git.cxllm-studio.com/CxAI-LLM/CxAI.Rust"
|
||||
resolver = "2"
|
||||
|
||||
[lib]
|
||||
name = "cxai_bridge"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies.cxai-sdk]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.futures-util]
|
||||
version = "0.3"
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
"stream",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tokio-tungstenite]
|
||||
version = "0.26"
|
||||
features = ["rustls-tls-webpki-roots"]
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v4",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[dev-dependencies.tokio]
|
||||
version = "1"
|
||||
features = [
|
||||
"full",
|
||||
"test-util",
|
||||
"macros",
|
||||
]
|
||||
Generated
+24
@@ -0,0 +1,24 @@
|
||||
[package]
|
||||
name = "cxai-bridge"
|
||||
description = "FFI bridge for C++ interop, JSON-RPC proxy, and SignalR WebSocket client"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cxai-sdk = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tokio-tungstenite = { workspace = true }
|
||||
futures-util = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util", "macros"] }
|
||||
@@ -0,0 +1,7 @@
|
||||
pub mod pubsub;
|
||||
pub mod rpc;
|
||||
pub mod signalr;
|
||||
|
||||
pub use pubsub::PubSubClient;
|
||||
pub use rpc::RpcClient;
|
||||
pub use signalr::SignalRClient;
|
||||
@@ -0,0 +1,81 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{RwLock, broadcast};
|
||||
use tracing::debug;
|
||||
|
||||
type MessageHandler = broadcast::Sender<serde_json::Value>;
|
||||
|
||||
/// In-process pub/sub hub matching the .NET PubSubService.
|
||||
pub struct PubSubClient {
|
||||
topics: Arc<RwLock<HashMap<String, MessageHandler>>>,
|
||||
channel_capacity: usize,
|
||||
}
|
||||
|
||||
impl PubSubClient {
|
||||
pub fn new(channel_capacity: usize) -> Self {
|
||||
Self {
|
||||
topics: Arc::new(RwLock::new(HashMap::new())),
|
||||
channel_capacity,
|
||||
}
|
||||
}
|
||||
|
||||
/// Subscribe to a topic, returning a receiver for messages.
|
||||
pub async fn subscribe(&self, topic: &str) -> broadcast::Receiver<serde_json::Value> {
|
||||
let mut topics = self.topics.write().await;
|
||||
let sender = topics.entry(topic.to_string()).or_insert_with(|| {
|
||||
debug!(topic, "Creating new topic channel");
|
||||
let (tx, _) = broadcast::channel(self.channel_capacity);
|
||||
tx
|
||||
});
|
||||
sender.subscribe()
|
||||
}
|
||||
|
||||
/// Publish a message to a topic.
|
||||
pub async fn publish(&self, topic: &str, message: serde_json::Value) -> usize {
|
||||
let topics = self.topics.read().await;
|
||||
if let Some(sender) = topics.get(topic) {
|
||||
match sender.send(message) {
|
||||
Ok(count) => {
|
||||
debug!(topic, count, "Published message");
|
||||
count
|
||||
}
|
||||
Err(_) => 0,
|
||||
}
|
||||
} else {
|
||||
0
|
||||
}
|
||||
}
|
||||
|
||||
/// List all active topics.
|
||||
pub async fn topics(&self) -> Vec<String> {
|
||||
self.topics.read().await.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PubSubClient {
|
||||
fn default() -> Self {
|
||||
Self::new(256)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn pubsub_roundtrip() {
|
||||
let hub = PubSubClient::default();
|
||||
let mut rx = hub.subscribe("test.topic").await;
|
||||
let msg = serde_json::json!({"hello": "world"});
|
||||
hub.publish("test.topic", msg.clone()).await;
|
||||
let received = rx.recv().await.unwrap();
|
||||
assert_eq!(received, msg);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pubsub_no_subscribers() {
|
||||
let hub = PubSubClient::default();
|
||||
let count = hub.publish("empty.topic", serde_json::json!({})).await;
|
||||
assert_eq!(count, 0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, error};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// JSON-RPC 2.0 client for the C++ bridge backend.
|
||||
pub struct RpcClient {
|
||||
http: reqwest::Client,
|
||||
endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct JsonRpcRequest {
|
||||
jsonrpc: &'static str,
|
||||
method: String,
|
||||
params: Option<serde_json::Value>,
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct JsonRpcResponse {
|
||||
result: Option<serde_json::Value>,
|
||||
error: Option<JsonRpcError>,
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct JsonRpcError {
|
||||
code: i64,
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl RpcClient {
|
||||
/// Create a new JSON-RPC client pointing to the C++ bridge backend.
|
||||
pub fn new(endpoint: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: reqwest::Client::new(),
|
||||
endpoint: endpoint.into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Call an RPC method with optional parameters.
|
||||
pub async fn call(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> Result<serde_json::Value, cxai_sdk::CxError> {
|
||||
let request_id = Uuid::new_v4().to_string();
|
||||
let rpc_req = JsonRpcRequest {
|
||||
jsonrpc: "2.0",
|
||||
method: method.to_string(),
|
||||
params,
|
||||
id: request_id.clone(),
|
||||
};
|
||||
|
||||
debug!(method, id = %request_id, "RPC call");
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.post(&self.endpoint)
|
||||
.json(&rpc_req)
|
||||
.send()
|
||||
.await
|
||||
.map_err(cxai_sdk::CxError::Http)?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(cxai_sdk::CxError::from_response(resp).await);
|
||||
}
|
||||
|
||||
let rpc_resp: JsonRpcResponse = resp.json().await.map_err(cxai_sdk::CxError::Http)?;
|
||||
|
||||
if let Some(err) = rpc_resp.error {
|
||||
error!(code = err.code, message = %err.message, "RPC error");
|
||||
return Err(cxai_sdk::CxError::Api {
|
||||
status: err.code as u16,
|
||||
message: err.message,
|
||||
});
|
||||
}
|
||||
|
||||
rpc_resp.result.ok_or_else(|| cxai_sdk::CxError::Api {
|
||||
status: 500,
|
||||
message: "RPC response missing result".into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Discover available modules and methods on the C++ server.
|
||||
pub async fn discover(&self) -> Result<serde_json::Value, cxai_sdk::CxError> {
|
||||
self.call("system.discover", None).await
|
||||
}
|
||||
|
||||
/// Health check the C++ backend.
|
||||
pub async fn health(&self) -> Result<serde_json::Value, cxai_sdk::CxError> {
|
||||
self.call("system.health", None).await
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn rpc_client_constructs() {
|
||||
let client = RpcClient::new("http://localhost:9090/rpc");
|
||||
assert_eq!(client.endpoint, "http://localhost:9090/rpc");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_tungstenite::tungstenite::Message;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
/// SignalR WebSocket client for real-time bridge events.
|
||||
pub struct SignalRClient {
|
||||
url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct SignalRMessage {
|
||||
#[serde(rename = "type")]
|
||||
msg_type: u8,
|
||||
target: String,
|
||||
arguments: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SignalRIncoming {
|
||||
#[serde(rename = "type")]
|
||||
msg_type: u8,
|
||||
#[allow(dead_code)]
|
||||
target: Option<String>,
|
||||
arguments: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
impl SignalRClient {
|
||||
/// Create a new SignalR client for the bridge hub.
|
||||
pub fn new(base_url: impl Into<String>) -> Self {
|
||||
let base = base_url.into();
|
||||
let ws_url = base
|
||||
.replace("http://", "ws://")
|
||||
.replace("https://", "wss://");
|
||||
Self {
|
||||
url: format!("{ws_url}/hubs/bridge"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Connect and subscribe to a topic, returning a channel of messages.
|
||||
pub async fn subscribe(
|
||||
&self,
|
||||
topic: &str,
|
||||
) -> Result<mpsc::Receiver<serde_json::Value>, cxai_sdk::CxError> {
|
||||
let (tx, rx) = mpsc::channel(256);
|
||||
let url = self.url.clone();
|
||||
let topic = topic.to_string();
|
||||
|
||||
tokio::spawn(async move {
|
||||
match tokio_tungstenite::connect_async(&url).await {
|
||||
Ok((ws_stream, _)) => {
|
||||
info!(url = %url, "SignalR connected");
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
|
||||
// Send SignalR handshake
|
||||
let handshake = "{\"protocol\":\"json\",\"version\":1}\x1e".to_string();
|
||||
if let Err(e) = write.send(Message::Text(handshake.into())).await {
|
||||
error!("Handshake failed: {e}");
|
||||
return;
|
||||
}
|
||||
|
||||
// Subscribe to topic
|
||||
let sub_msg = SignalRMessage {
|
||||
msg_type: 1,
|
||||
target: "Subscribe".into(),
|
||||
arguments: vec![serde_json::json!(topic)],
|
||||
};
|
||||
let payload = format!("{}\x1e", serde_json::to_string(&sub_msg).unwrap());
|
||||
if let Err(e) = write.send(Message::Text(payload.into())).await {
|
||||
error!("Subscribe failed: {e}");
|
||||
return;
|
||||
}
|
||||
debug!(topic = %topic, "Subscribed");
|
||||
|
||||
// Read messages
|
||||
while let Some(msg) = read.next().await {
|
||||
match msg {
|
||||
Ok(Message::Text(text)) => {
|
||||
for part in text.split('\x1e').filter(|s| !s.is_empty()) {
|
||||
if let Ok(incoming) =
|
||||
serde_json::from_str::<SignalRIncoming>(part)
|
||||
&& incoming.msg_type == 1
|
||||
&& let Some(args) = incoming.arguments
|
||||
{
|
||||
for arg in args {
|
||||
if tx.send(arg).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Message::Close(_)) => {
|
||||
info!("SignalR connection closed");
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("WebSocket error: {e}");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to connect: {e}");
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(rx)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn signalr_url_construction() {
|
||||
let client = SignalRClient::new("http://localhost:9100");
|
||||
assert_eq!(client.url, "ws://localhost:9100/hubs/bridge");
|
||||
|
||||
let client = SignalRClient::new("https://api.cxllm.io");
|
||||
assert_eq!(client.url, "wss://api.cxllm.io/hubs/bridge");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "d878deb8441897ecdd416011b49d2d2f6112e867"
|
||||
},
|
||||
"path_in_vcs": "crates/cxai-energy"
|
||||
}
|
||||
Generated
+1942
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,74 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "cxai-energy"
|
||||
version = "0.1.0"
|
||||
authors = ["CxAI-LLM <agent@cxai-studio.com>"]
|
||||
build = false
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
description = "ERCOT energy market client — DAM/RT prices, demand, wind/solar, SCADA, forecasting"
|
||||
homepage = "https://cxllm.io"
|
||||
readme = false
|
||||
license = "MIT"
|
||||
repository = "https://git.cxllm-studio.com/CxAI-LLM/CxAI.Rust"
|
||||
resolver = "2"
|
||||
|
||||
[lib]
|
||||
name = "cxai_energy"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxai-sdk]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
"stream",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dev-dependencies.tokio]
|
||||
version = "1"
|
||||
features = [
|
||||
"full",
|
||||
"test-util",
|
||||
"macros",
|
||||
]
|
||||
Generated
+22
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "cxai-energy"
|
||||
description = "ERCOT energy market client — DAM/RT prices, demand, wind/solar, SCADA, forecasting"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cxai-sdk = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util", "macros"] }
|
||||
@@ -0,0 +1,63 @@
|
||||
use cxai_sdk::models::energy::*;
|
||||
use tracing::debug;
|
||||
|
||||
/// High-level ERCOT energy market client.
|
||||
pub struct ErcotClient {
|
||||
platform: cxai_sdk::CxPlatform,
|
||||
}
|
||||
|
||||
impl ErcotClient {
|
||||
pub fn new(platform: cxai_sdk::CxPlatform) -> Self {
|
||||
Self { platform }
|
||||
}
|
||||
|
||||
/// Get day-ahead market prices for a settlement point.
|
||||
pub async fn dam_prices(&self, settlement_point: &str) -> cxai_sdk::CxResult<Vec<DamPrice>> {
|
||||
debug!(settlement_point, "Fetching DAM prices");
|
||||
self.platform.energy().dam_prices(settlement_point).await
|
||||
}
|
||||
|
||||
/// Get real-time prices for a settlement point.
|
||||
pub async fn rt_prices(&self, settlement_point: &str) -> cxai_sdk::CxResult<Vec<RtPrice>> {
|
||||
debug!(settlement_point, "Fetching RT prices");
|
||||
self.platform.energy().rt_prices(settlement_point).await
|
||||
}
|
||||
|
||||
/// Get current system demand.
|
||||
pub async fn system_demand(&self) -> cxai_sdk::CxResult<SystemDemand> {
|
||||
self.platform.energy().system_demand().await
|
||||
}
|
||||
|
||||
/// Get current wind generation.
|
||||
pub async fn wind(&self) -> cxai_sdk::CxResult<WindGeneration> {
|
||||
self.platform.energy().wind_generation().await
|
||||
}
|
||||
|
||||
/// Get current solar generation.
|
||||
pub async fn solar(&self) -> cxai_sdk::CxResult<SolarGeneration> {
|
||||
self.platform.energy().solar_generation().await
|
||||
}
|
||||
|
||||
/// Get full market summary.
|
||||
pub async fn market_summary(&self) -> cxai_sdk::CxResult<MarketSummary> {
|
||||
self.platform.energy().market_summary().await
|
||||
}
|
||||
|
||||
/// Get ML prediction for energy data.
|
||||
pub async fn predict(
|
||||
&self,
|
||||
prediction_type: &str,
|
||||
model: &str,
|
||||
target_date: &str,
|
||||
target_hour: u8,
|
||||
) -> cxai_sdk::CxResult<ErcotMlPrediction> {
|
||||
debug!(
|
||||
prediction_type,
|
||||
model, target_date, target_hour, "ERCOT predict"
|
||||
);
|
||||
self.platform
|
||||
.energy()
|
||||
.predict(prediction_type, model, target_date, target_hour)
|
||||
.await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
pub mod ercot;
|
||||
pub mod market;
|
||||
pub mod scada;
|
||||
|
||||
pub use ercot::ErcotClient;
|
||||
pub use market::MarketAnalyzer;
|
||||
pub use scada::ScadaClient;
|
||||
@@ -0,0 +1,130 @@
|
||||
use cxai_sdk::models::energy::*;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Market analysis utilities for ERCOT data.
|
||||
pub struct MarketAnalyzer;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SpreadAnalysis {
|
||||
pub avg_dam: f64,
|
||||
pub avg_rt: f64,
|
||||
pub spread: f64,
|
||||
pub max_spread: f64,
|
||||
pub arbitrage_opportunities: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RenewableMix {
|
||||
pub wind_mw: f64,
|
||||
pub solar_mw: f64,
|
||||
pub total_renewable_mw: f64,
|
||||
pub demand_mw: f64,
|
||||
pub renewable_percentage: f64,
|
||||
}
|
||||
|
||||
impl MarketAnalyzer {
|
||||
/// Calculate DAM/RT price spread.
|
||||
pub fn spread(dam_prices: &[DamPrice], rt_prices: &[RtPrice]) -> SpreadAnalysis {
|
||||
let avg_dam = if dam_prices.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
dam_prices.iter().map(|p| p.price).sum::<f64>() / dam_prices.len() as f64
|
||||
};
|
||||
|
||||
let avg_rt = if rt_prices.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
rt_prices.iter().map(|p| p.price).sum::<f64>() / rt_prices.len() as f64
|
||||
};
|
||||
|
||||
let spreads: Vec<f64> = dam_prices
|
||||
.iter()
|
||||
.zip(rt_prices.iter())
|
||||
.map(|(d, r)| (d.price - r.price).abs())
|
||||
.collect();
|
||||
|
||||
let max_spread = spreads.iter().cloned().fold(0.0_f64, f64::max);
|
||||
let arbitrage_opportunities = spreads.iter().filter(|&&s| s > 10.0).count();
|
||||
|
||||
SpreadAnalysis {
|
||||
avg_dam,
|
||||
avg_rt,
|
||||
spread: avg_dam - avg_rt,
|
||||
max_spread,
|
||||
arbitrage_opportunities,
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate renewable energy mix.
|
||||
pub fn renewable_mix(
|
||||
wind: &WindGeneration,
|
||||
solar: &SolarGeneration,
|
||||
demand: &SystemDemand,
|
||||
) -> RenewableMix {
|
||||
let total = wind.generation_mw + solar.generation_mw;
|
||||
let pct = if demand.demand_mw > 0.0 {
|
||||
(total / demand.demand_mw) * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
RenewableMix {
|
||||
wind_mw: wind.generation_mw,
|
||||
solar_mw: solar.generation_mw,
|
||||
total_renewable_mw: total,
|
||||
demand_mw: demand.demand_mw,
|
||||
renewable_percentage: pct,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::{NaiveDate, Utc};
|
||||
|
||||
#[test]
|
||||
fn spread_calculation() {
|
||||
let dam = vec![DamPrice {
|
||||
settlement_point: "HB_HOUSTON".into(),
|
||||
delivery_date: NaiveDate::from_ymd_opt(2026, 4, 20).unwrap(),
|
||||
hour_ending: 14,
|
||||
price: 45.0,
|
||||
timestamp: None,
|
||||
}];
|
||||
let rt = vec![RtPrice {
|
||||
settlement_point: "HB_HOUSTON".into(),
|
||||
price: 52.0,
|
||||
interval: "15min".into(),
|
||||
timestamp: Utc::now(),
|
||||
}];
|
||||
|
||||
let analysis = MarketAnalyzer::spread(&dam, &rt);
|
||||
assert_eq!(analysis.avg_dam, 45.0);
|
||||
assert_eq!(analysis.avg_rt, 52.0);
|
||||
assert!((analysis.spread - (-7.0)).abs() < f64::EPSILON);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn renewable_mix_calculation() {
|
||||
let wind = WindGeneration {
|
||||
generation_mw: 15000.0,
|
||||
capacity_mw: 40000.0,
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let solar = SolarGeneration {
|
||||
generation_mw: 8000.0,
|
||||
capacity_mw: 20000.0,
|
||||
timestamp: Utc::now(),
|
||||
};
|
||||
let demand = SystemDemand {
|
||||
demand_mw: 60000.0,
|
||||
timestamp: Utc::now(),
|
||||
forecast_mw: None,
|
||||
};
|
||||
|
||||
let mix = MarketAnalyzer::renewable_mix(&wind, &solar, &demand);
|
||||
assert_eq!(mix.total_renewable_mw, 23000.0);
|
||||
assert!((mix.renewable_percentage - 38.333333333333336).abs() < 0.001);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
use cxai_sdk::models::energy::ScadaSnapshot;
|
||||
use tracing::debug;
|
||||
|
||||
/// SCADA telemetry client for battery storage systems.
|
||||
pub struct ScadaClient {
|
||||
platform: cxai_sdk::CxPlatform,
|
||||
}
|
||||
|
||||
impl ScadaClient {
|
||||
pub fn new(platform: cxai_sdk::CxPlatform) -> Self {
|
||||
Self { platform }
|
||||
}
|
||||
|
||||
/// Get the latest SCADA snapshot for a device.
|
||||
pub async fn snapshot(&self, device_id: &str) -> cxai_sdk::CxResult<ScadaSnapshot> {
|
||||
debug!(device_id, "Fetching SCADA snapshot");
|
||||
self.platform.energy().scada_snapshot(device_id).await
|
||||
}
|
||||
|
||||
/// Check if a battery is within safe operating parameters.
|
||||
pub fn is_safe(snapshot: &ScadaSnapshot) -> bool {
|
||||
snapshot.soc_percent >= 5.0
|
||||
&& snapshot.soc_percent <= 95.0
|
||||
&& snapshot.temperature_c < 55.0
|
||||
&& snapshot.soh_percent > 70.0
|
||||
}
|
||||
|
||||
/// Calculate charge/discharge efficiency from a snapshot.
|
||||
pub fn efficiency(snapshot: &ScadaSnapshot) -> f64 {
|
||||
if snapshot.voltage_v == 0.0 || snapshot.current_a == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
let power_calc = snapshot.voltage_v * snapshot.current_a / 1000.0;
|
||||
if power_calc == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
(snapshot.power_kw / power_calc).abs().min(1.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use chrono::Utc;
|
||||
|
||||
fn test_snapshot() -> ScadaSnapshot {
|
||||
ScadaSnapshot {
|
||||
device_id: "BESS-001".into(),
|
||||
soc_percent: 65.0,
|
||||
soh_percent: 92.0,
|
||||
power_kw: 450.0,
|
||||
voltage_v: 800.0,
|
||||
current_a: 580.0,
|
||||
temperature_c: 32.0,
|
||||
timestamp: Utc::now(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn safe_battery() {
|
||||
assert!(ScadaClient::is_safe(&test_snapshot()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unsafe_temperature() {
|
||||
let mut s = test_snapshot();
|
||||
s.temperature_c = 60.0;
|
||||
assert!(!ScadaClient::is_safe(&s));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn efficiency_calculation() {
|
||||
let s = test_snapshot();
|
||||
let eff = ScadaClient::efficiency(&s);
|
||||
assert!(eff > 0.0 && eff <= 1.0);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "d878deb8441897ecdd416011b49d2d2f6112e867"
|
||||
},
|
||||
"path_in_vcs": "crates/cxai-ml"
|
||||
}
|
||||
Generated
+1942
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,74 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "cxai-ml"
|
||||
version = "0.1.0"
|
||||
authors = ["CxAI-LLM <agent@cxai-studio.com>"]
|
||||
build = false
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
description = "ML inference client — model catalog, prediction, anomaly detection, regime classification"
|
||||
homepage = "https://cxllm.io"
|
||||
readme = false
|
||||
license = "MIT"
|
||||
repository = "https://git.cxllm-studio.com/CxAI-LLM/CxAI.Rust"
|
||||
resolver = "2"
|
||||
|
||||
[lib]
|
||||
name = "cxai_ml"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxai-sdk]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
"stream",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dev-dependencies.tokio]
|
||||
version = "1"
|
||||
features = [
|
||||
"full",
|
||||
"test-util",
|
||||
"macros",
|
||||
]
|
||||
Generated
+22
@@ -0,0 +1,22 @@
|
||||
[package]
|
||||
name = "cxai-ml"
|
||||
description = "ML inference client — model catalog, prediction, anomaly detection, regime classification"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cxai-sdk = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util", "macros"] }
|
||||
@@ -0,0 +1,315 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The 26-model catalog across 8 domains, matching the .NET ModelCatalogService.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ModelInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub domain: String,
|
||||
pub model_type: String,
|
||||
pub description: String,
|
||||
pub status: ModelStatus,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ModelStatus {
|
||||
Active,
|
||||
Training,
|
||||
Deprecated,
|
||||
Offline,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DomainInfo {
|
||||
pub name: String,
|
||||
pub model_count: usize,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
pub struct ModelCatalog;
|
||||
|
||||
impl ModelCatalog {
|
||||
/// Built-in model definitions across all domains.
|
||||
pub fn all_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
// Price Forecasting
|
||||
model(
|
||||
"dam-xgb",
|
||||
"DAM XGBoost",
|
||||
"price_forecasting",
|
||||
"xgboost",
|
||||
"Day-ahead market price prediction",
|
||||
),
|
||||
model(
|
||||
"dam-lstm",
|
||||
"DAM LSTM",
|
||||
"price_forecasting",
|
||||
"lstm",
|
||||
"DAM price via LSTM neural network",
|
||||
),
|
||||
model(
|
||||
"dam-transformer",
|
||||
"DAM Transformer",
|
||||
"price_forecasting",
|
||||
"transformer",
|
||||
"DAM price via transformer model",
|
||||
),
|
||||
model(
|
||||
"rt-xgb",
|
||||
"RT XGBoost",
|
||||
"price_forecasting",
|
||||
"xgboost",
|
||||
"Real-time price prediction",
|
||||
),
|
||||
model(
|
||||
"rt-lstm",
|
||||
"RT LSTM",
|
||||
"price_forecasting",
|
||||
"lstm",
|
||||
"RT price via LSTM",
|
||||
),
|
||||
// Demand Forecasting
|
||||
model(
|
||||
"demand-prophet",
|
||||
"Demand Prophet",
|
||||
"demand_forecasting",
|
||||
"prophet",
|
||||
"System demand forecasting",
|
||||
),
|
||||
model(
|
||||
"demand-xgb",
|
||||
"Demand XGBoost",
|
||||
"demand_forecasting",
|
||||
"xgboost",
|
||||
"Demand via gradient boosting",
|
||||
),
|
||||
model(
|
||||
"demand-nn",
|
||||
"Demand Neural",
|
||||
"demand_forecasting",
|
||||
"neural_net",
|
||||
"Demand via neural network",
|
||||
),
|
||||
// Renewable Forecasting
|
||||
model(
|
||||
"wind-xgb",
|
||||
"Wind XGBoost",
|
||||
"renewable_forecasting",
|
||||
"xgboost",
|
||||
"Wind generation prediction",
|
||||
),
|
||||
model(
|
||||
"wind-lstm",
|
||||
"Wind LSTM",
|
||||
"renewable_forecasting",
|
||||
"lstm",
|
||||
"Wind via LSTM",
|
||||
),
|
||||
model(
|
||||
"solar-xgb",
|
||||
"Solar XGBoost",
|
||||
"renewable_forecasting",
|
||||
"xgboost",
|
||||
"Solar generation prediction",
|
||||
),
|
||||
model(
|
||||
"solar-lstm",
|
||||
"Solar LSTM",
|
||||
"renewable_forecasting",
|
||||
"lstm",
|
||||
"Solar via LSTM",
|
||||
),
|
||||
// Anomaly Detection
|
||||
model(
|
||||
"anomaly-iforest",
|
||||
"IsolationForest",
|
||||
"anomaly_detection",
|
||||
"isolation_forest",
|
||||
"Market anomaly detection",
|
||||
),
|
||||
model(
|
||||
"anomaly-autoencoder",
|
||||
"Autoencoder",
|
||||
"anomaly_detection",
|
||||
"autoencoder",
|
||||
"Deep anomaly detection",
|
||||
),
|
||||
// Regime Classification
|
||||
model(
|
||||
"regime-rf",
|
||||
"Regime RF",
|
||||
"regime_classification",
|
||||
"random_forest",
|
||||
"4-class price regime",
|
||||
),
|
||||
model(
|
||||
"regime-xgb",
|
||||
"Regime XGBoost",
|
||||
"regime_classification",
|
||||
"xgboost",
|
||||
"Regime via gradient boosting",
|
||||
),
|
||||
// Battery Optimization
|
||||
model(
|
||||
"bess-dispatch",
|
||||
"BESS Dispatch",
|
||||
"battery_optimization",
|
||||
"linear_program",
|
||||
"Battery dispatch optimization",
|
||||
),
|
||||
model(
|
||||
"bess-degradation",
|
||||
"BESS Degradation",
|
||||
"battery_optimization",
|
||||
"regression",
|
||||
"Battery degradation model",
|
||||
),
|
||||
model(
|
||||
"bess-revenue",
|
||||
"BESS Revenue",
|
||||
"battery_optimization",
|
||||
"mixed_integer",
|
||||
"Revenue maximization",
|
||||
),
|
||||
// Congestion Analysis
|
||||
model(
|
||||
"congestion-rf",
|
||||
"Congestion RF",
|
||||
"congestion_analysis",
|
||||
"random_forest",
|
||||
"Congestion prediction",
|
||||
),
|
||||
model(
|
||||
"congestion-gnn",
|
||||
"Congestion GNN",
|
||||
"congestion_analysis",
|
||||
"graph_neural_net",
|
||||
"Grid topology model",
|
||||
),
|
||||
// Risk Analysis
|
||||
model(
|
||||
"var-historical",
|
||||
"VaR Historical",
|
||||
"risk_analysis",
|
||||
"statistical",
|
||||
"Historical value-at-risk",
|
||||
),
|
||||
model(
|
||||
"var-monte-carlo",
|
||||
"VaR Monte Carlo",
|
||||
"risk_analysis",
|
||||
"simulation",
|
||||
"Monte Carlo VaR",
|
||||
),
|
||||
model(
|
||||
"cvar-parametric",
|
||||
"CVaR Parametric",
|
||||
"risk_analysis",
|
||||
"statistical",
|
||||
"Conditional VaR",
|
||||
),
|
||||
model(
|
||||
"volatility-garch",
|
||||
"GARCH Volatility",
|
||||
"risk_analysis",
|
||||
"garch",
|
||||
"Price volatility model",
|
||||
),
|
||||
model(
|
||||
"correlation-dcc",
|
||||
"DCC Correlation",
|
||||
"risk_analysis",
|
||||
"dcc_garch",
|
||||
"Dynamic conditional correlation",
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
/// List all 8 model domains.
|
||||
pub fn domains() -> Vec<DomainInfo> {
|
||||
vec![
|
||||
DomainInfo {
|
||||
name: "price_forecasting".into(),
|
||||
model_count: 5,
|
||||
description: "DAM and RT price prediction".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "demand_forecasting".into(),
|
||||
model_count: 3,
|
||||
description: "System demand forecasting".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "renewable_forecasting".into(),
|
||||
model_count: 4,
|
||||
description: "Wind and solar generation".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "anomaly_detection".into(),
|
||||
model_count: 2,
|
||||
description: "Market anomaly detection".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "regime_classification".into(),
|
||||
model_count: 2,
|
||||
description: "Price regime classification".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "battery_optimization".into(),
|
||||
model_count: 3,
|
||||
description: "BESS dispatch and revenue".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "congestion_analysis".into(),
|
||||
model_count: 2,
|
||||
description: "Grid congestion prediction".into(),
|
||||
},
|
||||
DomainInfo {
|
||||
name: "risk_analysis".into(),
|
||||
model_count: 5,
|
||||
description: "VaR, CVaR, and volatility".into(),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
/// Get models for a specific domain.
|
||||
pub fn by_domain(domain: &str) -> Vec<ModelInfo> {
|
||||
Self::all_models()
|
||||
.into_iter()
|
||||
.filter(|m| m.domain == domain)
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
fn model(id: &str, name: &str, domain: &str, model_type: &str, desc: &str) -> ModelInfo {
|
||||
ModelInfo {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
domain: domain.into(),
|
||||
model_type: model_type.into(),
|
||||
description: desc.into(),
|
||||
status: ModelStatus::Active,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn catalog_has_26_models() {
|
||||
assert_eq!(ModelCatalog::all_models().len(), 26);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn catalog_has_8_domains() {
|
||||
assert_eq!(ModelCatalog::domains().len(), 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn filter_by_domain() {
|
||||
let price_models = ModelCatalog::by_domain("price_forecasting");
|
||||
assert_eq!(price_models.len(), 5);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,110 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
|
||||
/// High-level ML inference client wrapping the platform ML service.
|
||||
pub struct MlInferenceClient {
|
||||
platform: cxai_sdk::CxPlatform,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictionRequest {
|
||||
pub model: String,
|
||||
pub features: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PredictionResult {
|
||||
pub model: String,
|
||||
pub prediction: serde_json::Value,
|
||||
pub confidence: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AnomalyResult {
|
||||
pub is_anomaly: bool,
|
||||
pub score: f64,
|
||||
pub features: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RegimeResult {
|
||||
pub regime: String,
|
||||
pub class_id: u8,
|
||||
pub probabilities: Vec<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ForecastResult {
|
||||
pub settlement_point: String,
|
||||
pub hours_ahead: u32,
|
||||
pub predictions: Vec<HourlyForecast>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct HourlyForecast {
|
||||
pub hour: u8,
|
||||
pub price: f64,
|
||||
pub lower_bound: f64,
|
||||
pub upper_bound: f64,
|
||||
}
|
||||
|
||||
impl MlInferenceClient {
|
||||
pub fn new(platform: cxai_sdk::CxPlatform) -> Self {
|
||||
Self { platform }
|
||||
}
|
||||
|
||||
/// Single model prediction.
|
||||
pub async fn predict(
|
||||
&self,
|
||||
request: &PredictionRequest,
|
||||
) -> cxai_sdk::CxResult<PredictionResult> {
|
||||
debug!(model = %request.model, "ML predict");
|
||||
let raw = self
|
||||
.platform
|
||||
.ml()
|
||||
.predict(&request.model, request.features.clone())
|
||||
.await?;
|
||||
Ok(serde_json::from_value(raw)?)
|
||||
}
|
||||
|
||||
/// Detect anomalies in feature data.
|
||||
pub async fn detect_anomalies(
|
||||
&self,
|
||||
features: serde_json::Value,
|
||||
) -> cxai_sdk::CxResult<AnomalyResult> {
|
||||
debug!("ML detect anomalies");
|
||||
let raw = self.platform.ml().detect_anomalies(features).await?;
|
||||
Ok(serde_json::from_value(raw)?)
|
||||
}
|
||||
|
||||
/// Classify price regime.
|
||||
pub async fn classify_regime(
|
||||
&self,
|
||||
features: serde_json::Value,
|
||||
) -> cxai_sdk::CxResult<RegimeResult> {
|
||||
debug!("ML classify regime");
|
||||
let raw = self.platform.ml().classify_regime(features).await?;
|
||||
Ok(serde_json::from_value(raw)?)
|
||||
}
|
||||
|
||||
/// Price forecast for a settlement point.
|
||||
pub async fn price_forecast(
|
||||
&self,
|
||||
settlement_point: &str,
|
||||
hours_ahead: u32,
|
||||
) -> cxai_sdk::CxResult<ForecastResult> {
|
||||
debug!(settlement_point, hours_ahead, "ML price forecast");
|
||||
let raw = self
|
||||
.platform
|
||||
.ml()
|
||||
.price_forecast(settlement_point, hours_ahead)
|
||||
.await?;
|
||||
Ok(serde_json::from_value(raw)?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
pub mod catalog;
|
||||
pub mod inference;
|
||||
|
||||
pub use catalog::ModelCatalog;
|
||||
pub use inference::MlInferenceClient;
|
||||
@@ -0,0 +1,6 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "d878deb8441897ecdd416011b49d2d2f6112e867"
|
||||
},
|
||||
"path_in_vcs": "crates/cxai-sdk"
|
||||
}
|
||||
Generated
+1926
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,86 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "cxai-sdk"
|
||||
version = "0.1.0"
|
||||
authors = ["CxAI-LLM <agent@cxai-studio.com>"]
|
||||
build = false
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
description = "Core SDK for the CxAI Platform — typed HTTP clients, models, resilience, and configuration"
|
||||
homepage = "https://cxllm.io"
|
||||
readme = false
|
||||
license = "MIT"
|
||||
repository = "https://git.cxllm-studio.com/CxAI-LLM/CxAI.Rust"
|
||||
resolver = "2"
|
||||
|
||||
[lib]
|
||||
name = "cxai_sdk"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies.anyhow]
|
||||
version = "1"
|
||||
|
||||
[dependencies.async-trait]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.backon]
|
||||
version = "1"
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
"stream",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v4",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[dev-dependencies.tokio]
|
||||
version = "1"
|
||||
features = [
|
||||
"full",
|
||||
"test-util",
|
||||
"macros",
|
||||
]
|
||||
Generated
+25
@@ -0,0 +1,25 @@
|
||||
[package]
|
||||
name = "cxai-sdk"
|
||||
description = "Core SDK for the CxAI Platform — typed HTTP clients, models, resilience, and configuration"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
authors.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
backon = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true, features = ["test-util", "macros"] }
|
||||
@@ -0,0 +1,66 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::common::HealthResponse;
|
||||
|
||||
/// App service client — application gateway, auth, chat, inference.
|
||||
#[derive(Clone)]
|
||||
pub struct AppClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl AppClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/app{}",
|
||||
self.config.service_url(self.config.app_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn health(&self) -> CxResult<HealthResponse> {
|
||||
let resp = self.http.get(self.url("/healthz")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn chat(
|
||||
&self,
|
||||
message: &str,
|
||||
session_id: Option<&str>,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let mut body = serde_json::json!({ "message": message });
|
||||
if let Some(sid) = session_id {
|
||||
body["sessionId"] = serde_json::json!(sid);
|
||||
}
|
||||
let resp = self.http.post(self.url("/chat")).json(&body).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn infer(
|
||||
&self,
|
||||
request: &crate::models::inference::CxInferenceRequest,
|
||||
) -> CxResult<crate::models::inference::InferenceResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/infer"))
|
||||
.json(request)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,87 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// Bridge service client — C++ IPC gateway, JSON-RPC, pub/sub.
|
||||
#[derive(Clone)]
|
||||
pub struct BridgeClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl BridgeClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/bridge{}",
|
||||
self.config.service_url(self.config.bridge_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
/// Invoke an RPC method on the C++ backend.
|
||||
pub async fn rpc(
|
||||
&self,
|
||||
method: &str,
|
||||
params: Option<serde_json::Value>,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({
|
||||
"method": method,
|
||||
"params": params,
|
||||
});
|
||||
let resp = self.http.post(self.url("/rpc")).json(&body).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Publish a message to a topic.
|
||||
pub async fn publish(&self, topic: &str, message: serde_json::Value) -> CxResult<()> {
|
||||
let body = serde_json::json!({
|
||||
"topic": topic,
|
||||
"message": message,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/publish"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Discover available C++ modules and methods.
|
||||
pub async fn discover(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/discover")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Get bridge connection status.
|
||||
pub async fn status(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/status")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Get RPC call metrics.
|
||||
pub async fn metrics(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/metrics")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// Codegen service client — MCP server, code intelligence.
|
||||
#[derive(Clone)]
|
||||
pub struct CodegenClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl CodegenClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/codegen{}",
|
||||
self.config.service_url(self.config.codegen_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn analyze(&self, code: &str, language: &str) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({
|
||||
"code": code,
|
||||
"language": language,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/analyze"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn generate(&self, prompt: &str, language: &str) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({
|
||||
"prompt": prompt,
|
||||
"language": language,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/generate"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_resources(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/resources")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,94 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::commissioning::*;
|
||||
|
||||
/// Commissioning service client — BESS lifecycle management.
|
||||
#[derive(Clone)]
|
||||
pub struct CommissioningClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl CommissioningClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/commissioning{}",
|
||||
self.config.service_url(self.config.commissioning_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn list_projects(&self) -> CxResult<Vec<CommissioningProject>> {
|
||||
let resp = self.http.get(self.url("/projects")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn get_project(&self, id: &str) -> CxResult<CommissioningProject> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/projects/{id}")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_steps(&self, project_id: &str) -> CxResult<Vec<WorkflowStep>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/projects/{project_id}/steps")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_equipment(&self, project_id: &str) -> CxResult<Vec<Equipment>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/projects/{project_id}/equipment")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_incidents(&self, project_id: &str) -> CxResult<Vec<Incident>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/projects/{project_id}/incidents")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_compliance(&self, project_id: &str) -> CxResult<Vec<ComplianceRecord>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/projects/{project_id}/compliance")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// CxDocs service client — document engine.
|
||||
#[derive(Clone)]
|
||||
pub struct CxDocsClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl CxDocsClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/cxdocs{}",
|
||||
self.config.service_url(self.config.cxdocs_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn list_documents(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/documents")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn get_document(&self, id: &str) -> CxResult<serde_json::Value> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/documents/{id}")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn generate(&self, template: &str, data: serde_json::Value) -> CxResult<Vec<u8>> {
|
||||
let body = serde_json::json!({
|
||||
"template": template,
|
||||
"data": data,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/generate"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.bytes().await?.to_vec())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::energy::*;
|
||||
|
||||
/// Energy service client — ERCOT market data, SCADA, ML predictions.
|
||||
#[derive(Clone)]
|
||||
pub struct EnergyClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl EnergyClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/energy{}",
|
||||
self.config.service_url(self.config.energy_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn dam_prices(&self, settlement_point: &str) -> CxResult<Vec<DamPrice>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url("/dam/prices"))
|
||||
.query(&[("settlementPoint", settlement_point)])
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn rt_prices(&self, settlement_point: &str) -> CxResult<Vec<RtPrice>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url("/rt/prices"))
|
||||
.query(&[("settlementPoint", settlement_point)])
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn system_demand(&self) -> CxResult<SystemDemand> {
|
||||
let resp = self.http.get(self.url("/demand")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn wind_generation(&self) -> CxResult<WindGeneration> {
|
||||
let resp = self.http.get(self.url("/wind")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn solar_generation(&self) -> CxResult<SolarGeneration> {
|
||||
let resp = self.http.get(self.url("/solar")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn market_summary(&self) -> CxResult<MarketSummary> {
|
||||
let resp = self.http.get(self.url("/summary")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn scada_snapshot(&self, device_id: &str) -> CxResult<ScadaSnapshot> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/scada/{device_id}")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn predict(
|
||||
&self,
|
||||
prediction_type: &str,
|
||||
model: &str,
|
||||
target_date: &str,
|
||||
target_hour: u8,
|
||||
) -> CxResult<ErcotMlPrediction> {
|
||||
let body = serde_json::json!({
|
||||
"predictionType": prediction_type,
|
||||
"model": model,
|
||||
"targetDate": target_date,
|
||||
"targetHour": target_hour,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/predict"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::forma::*;
|
||||
|
||||
/// Forma service client — Autodesk Forma/ACC integration.
|
||||
#[derive(Clone)]
|
||||
pub struct FormaClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl FormaClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/forma{}",
|
||||
self.config.service_url(self.config.forma_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn list_projects(&self, token: &str) -> CxResult<Vec<FormaProjectInfo>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url("/projects"))
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_elements(
|
||||
&self,
|
||||
token: &str,
|
||||
project_id: &str,
|
||||
) -> CxResult<Vec<FormaElementInfo>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/projects/{project_id}/elements")))
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn trigger_analysis(
|
||||
&self,
|
||||
token: &str,
|
||||
project_id: &str,
|
||||
analysis_type: &str,
|
||||
parameters: serde_json::Value,
|
||||
) -> CxResult<FormaAnalysisResult> {
|
||||
let body = serde_json::json!({
|
||||
"analysisType": analysis_type,
|
||||
"parameters": parameters,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url(&format!("/projects/{project_id}/analysis")))
|
||||
.bearer_auth(token)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn check_compliance(
|
||||
&self,
|
||||
token: &str,
|
||||
element_urn: &str,
|
||||
) -> CxResult<FormaComplianceResult> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/compliance/{element_urn}")))
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_issues(&self, token: &str, container_id: &str) -> CxResult<Vec<AccIssue>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/acc/issues/{container_id}")))
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_checklists(
|
||||
&self,
|
||||
token: &str,
|
||||
project_id: &str,
|
||||
) -> CxResult<Vec<AccChecklist>> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/acc/checklists/{project_id}")))
|
||||
.bearer_auth(token)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// ML service client — model catalog, prediction, anomaly detection.
|
||||
#[derive(Clone)]
|
||||
pub struct MlClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl MlClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/ml{}",
|
||||
self.config.service_url(self.config.ml_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
/// Single model prediction.
|
||||
pub async fn predict(
|
||||
&self,
|
||||
model: &str,
|
||||
features: serde_json::Value,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({
|
||||
"model": model,
|
||||
"features": features,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/predict"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Batch prediction across multiple models.
|
||||
pub async fn predict_batch(
|
||||
&self,
|
||||
requests: Vec<serde_json::Value>,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/predict/batch"))
|
||||
.json(&requests)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Anomaly detection using IsolationForest.
|
||||
pub async fn detect_anomalies(
|
||||
&self,
|
||||
features: serde_json::Value,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({ "features": features });
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/detect-anomalies"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Classify price regime (4-class).
|
||||
pub async fn classify_regime(
|
||||
&self,
|
||||
features: serde_json::Value,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({ "features": features });
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/classify-regime"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// List all 26 models in the catalog.
|
||||
pub async fn list_models(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/models")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// List model domains.
|
||||
pub async fn list_domains(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/domains")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// ERCOT price forecast.
|
||||
pub async fn price_forecast(
|
||||
&self,
|
||||
settlement_point: &str,
|
||||
hours_ahead: u32,
|
||||
) -> CxResult<serde_json::Value> {
|
||||
let body = serde_json::json!({
|
||||
"settlementPoint": settlement_point,
|
||||
"hoursAhead": hours_ahead,
|
||||
});
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/intelligence/price-forecast"))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Backend health check.
|
||||
pub async fn backend_health(&self) -> CxResult<serde_json::Value> {
|
||||
let resp = self.http.get(self.url("/backend-health")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
mod app;
|
||||
mod bridge;
|
||||
mod codegen;
|
||||
mod commissioning;
|
||||
mod cxdocs;
|
||||
mod energy;
|
||||
mod forma;
|
||||
mod ml;
|
||||
mod nemo;
|
||||
mod nvidia;
|
||||
mod studio;
|
||||
mod supabase;
|
||||
|
||||
pub use app::AppClient;
|
||||
pub use bridge::BridgeClient;
|
||||
pub use codegen::CodegenClient;
|
||||
pub use commissioning::CommissioningClient;
|
||||
pub use cxdocs::CxDocsClient;
|
||||
pub use energy::EnergyClient;
|
||||
pub use forma::FormaClient;
|
||||
pub use ml::MlClient;
|
||||
pub use nemo::NemoClient;
|
||||
pub use nvidia::NvidiaClient;
|
||||
pub use studio::StudioClient;
|
||||
pub use supabase::SupabaseClient;
|
||||
@@ -0,0 +1,126 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::inference::*;
|
||||
use crate::models::nemo::*;
|
||||
|
||||
/// Nemo service client — AI agent orchestration.
|
||||
#[derive(Clone)]
|
||||
pub struct NemoClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl NemoClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/nemo{}",
|
||||
self.config.service_url(self.config.nemo_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn run_agent(&self, request: &AgentRunRequest) -> CxResult<AgentRunResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/agent/run"))
|
||||
.json(request)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn chat(&self, request: &NemoChatRequest) -> CxResult<NemoChatResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/chat"))
|
||||
.json(request)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_sessions(&self) -> CxResult<Vec<NemoSessionSummary>> {
|
||||
let resp = self.http.get(self.url("/sessions")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn get_session(&self, session_id: &str) -> CxResult<NemoSession> {
|
||||
let resp = self
|
||||
.http
|
||||
.get(self.url(&format!("/sessions/{session_id}")))
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_tools(&self) -> CxResult<Vec<ToolDefinition>> {
|
||||
let resp = self.http.get(self.url("/tools")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn execute_tool(
|
||||
&self,
|
||||
request: &ToolExecutionRequest,
|
||||
) -> CxResult<ToolExecutionResult> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/tools/execute"))
|
||||
.json(request)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_profiles(&self) -> CxResult<Vec<InferenceProfile>> {
|
||||
let resp = self.http.get(self.url("/profiles")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn route(&self, request: &RouteRequest) -> CxResult<RouteResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(self.url("/route"))
|
||||
.json(request)
|
||||
.send()
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn metrics(&self) -> CxResult<NemoMetrics> {
|
||||
let resp = self.http.get(self.url("/metrics")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::ChatMessage;
|
||||
use crate::models::inference::*;
|
||||
|
||||
/// NVIDIA NIM inference client — chat, embedding, reranking, vision, guardrails.
|
||||
#[derive(Clone)]
|
||||
pub struct NvidiaClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl NvidiaClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
#[allow(clippy::misnamed_getters)]
|
||||
fn base_url(&self) -> &str {
|
||||
&self.config.nvidia_base_url
|
||||
}
|
||||
|
||||
fn api_key(&self) -> CxResult<&str> {
|
||||
self.config
|
||||
.nvidia_api_key
|
||||
.as_deref()
|
||||
.ok_or_else(|| CxError::Config("NVIDIA_API_KEY not set".into()))
|
||||
}
|
||||
|
||||
/// Run chat completion.
|
||||
pub async fn chat(&self, request: &InferenceRequest) -> CxResult<InferenceResponse> {
|
||||
let resp = self
|
||||
.http
|
||||
.post(format!("{}/chat/completions", self.base_url()))
|
||||
.bearer_auth(self.api_key()?)
|
||||
.json(request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Generate embeddings for text.
|
||||
pub async fn embed(&self, text: &str) -> CxResult<Vec<f32>> {
|
||||
let body = serde_json::json!({
|
||||
"input": text,
|
||||
"model": "nvidia/nv-embedqa-e5-v5"
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.post(format!("{}/embeddings", self.base_url()))
|
||||
.bearer_auth(self.api_key()?)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
|
||||
let data: serde_json::Value = resp.json().await?;
|
||||
let embedding = data["data"][0]["embedding"]
|
||||
.as_array()
|
||||
.ok_or_else(|| CxError::Api {
|
||||
status: 500,
|
||||
message: "missing embedding data".into(),
|
||||
})?
|
||||
.iter()
|
||||
.filter_map(|v| v.as_f64().map(|f| f as f32))
|
||||
.collect();
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Rerank documents against a query.
|
||||
pub async fn rerank(&self, query: &str, documents: &[String]) -> CxResult<InferenceResponse> {
|
||||
let body = serde_json::json!({
|
||||
"model": "nvidia/nv-rerankqa-mistral-4b-v3",
|
||||
"query": { "text": query },
|
||||
"passages": documents.iter().map(|d| serde_json::json!({"text": d})).collect::<Vec<_>>()
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.post(format!("{}/ranking", self.base_url()))
|
||||
.bearer_auth(self.api_key()?)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Vision analysis on an image URL.
|
||||
pub async fn vision(&self, prompt: &str, image_url: &str) -> CxResult<InferenceResponse> {
|
||||
let request = InferenceRequest {
|
||||
model: "meta/llama-4-maverick-17b-128e-instruct".into(),
|
||||
messages: vec![ChatMessage {
|
||||
role: "user".into(),
|
||||
content: format!(
|
||||
r#"[{{"type":"text","text":"{prompt}"}},{{"type":"image_url","image_url":{{"url":"{image_url}"}}}}]"#
|
||||
),
|
||||
timestamp: None,
|
||||
}],
|
||||
temperature: Some(0.2),
|
||||
max_tokens: Some(1024),
|
||||
stream: Some(false),
|
||||
top_p: None,
|
||||
};
|
||||
|
||||
self.chat(&request).await
|
||||
}
|
||||
|
||||
/// Run guardrail safety check.
|
||||
pub async fn guardrail(&self, text: &str) -> CxResult<GuardrailResult> {
|
||||
let body = serde_json::json!({
|
||||
"model": "nvidia/llama-3.1-nemoguard-8b-content-safety",
|
||||
"messages": [{"role": "user", "content": text}]
|
||||
});
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.post(format!("{}/chat/completions", self.base_url()))
|
||||
.bearer_auth(self.api_key()?)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
use crate::models::common::*;
|
||||
use crate::models::studio::*;
|
||||
|
||||
/// Studio service client — platform hub, service registry, deployments.
|
||||
#[derive(Clone)]
|
||||
pub struct StudioClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl StudioClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn url(&self, path: &str) -> String {
|
||||
format!(
|
||||
"{}/api/studio{}",
|
||||
self.config.service_url(self.config.studio_port),
|
||||
path
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn health(&self) -> CxResult<HealthResponse> {
|
||||
let resp = self.http.get(self.url("/healthz")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn info(&self) -> CxResult<PlatformInfo> {
|
||||
let resp = self.http.get(self.url("/info")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_services(&self) -> CxResult<Vec<Service>> {
|
||||
let resp = self.http.get(self.url("/services")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_deployments(&self) -> CxResult<Vec<Deployment>> {
|
||||
let resp = self.http.get(self.url("/deployments")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
pub async fn list_conversations(&self) -> CxResult<Vec<Conversation>> {
|
||||
let resp = self.http.get(self.url("/conversations")).send().await?;
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::config::PlatformConfig;
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// Supabase client — PostgreSQL queries and Storage operations.
|
||||
#[derive(Clone)]
|
||||
pub struct SupabaseClient {
|
||||
http: reqwest::Client,
|
||||
config: Arc<PlatformConfig>,
|
||||
}
|
||||
|
||||
impl SupabaseClient {
|
||||
pub fn new(http: reqwest::Client, config: Arc<PlatformConfig>) -> Self {
|
||||
Self { http, config }
|
||||
}
|
||||
|
||||
fn base_url(&self) -> CxResult<&str> {
|
||||
self.config
|
||||
.supabase_url
|
||||
.as_deref()
|
||||
.ok_or_else(|| CxError::Config("SUPABASE_URL not set".into()))
|
||||
}
|
||||
|
||||
fn service_key(&self) -> CxResult<&str> {
|
||||
self.config
|
||||
.supabase_service_role_key
|
||||
.as_deref()
|
||||
.ok_or_else(|| CxError::Config("SUPABASE_SERVICE_ROLE_KEY not set".into()))
|
||||
}
|
||||
|
||||
/// Query rows from a table.
|
||||
pub async fn query<T: serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
table: &str,
|
||||
select: Option<&str>,
|
||||
filter: Option<&str>,
|
||||
) -> CxResult<Vec<T>> {
|
||||
let mut url = format!("{}/rest/v1/{}", self.base_url()?, table);
|
||||
let mut params = vec![];
|
||||
if let Some(s) = select {
|
||||
params.push(format!("select={s}"));
|
||||
}
|
||||
if let Some(f) = filter {
|
||||
params.push(f.to_string());
|
||||
}
|
||||
if !params.is_empty() {
|
||||
url = format!("{}?{}", url, params.join("&"));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.header("apikey", self.service_key()?)
|
||||
.header("Authorization", format!("Bearer {}", self.service_key()?))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(resp.json().await?)
|
||||
}
|
||||
|
||||
/// Insert a row.
|
||||
pub async fn insert<T: serde::Serialize + serde::de::DeserializeOwned>(
|
||||
&self,
|
||||
table: &str,
|
||||
data: &T,
|
||||
) -> CxResult<T> {
|
||||
let url = format!("{}/rest/v1/{}", self.base_url()?, table);
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("apikey", self.service_key()?)
|
||||
.header("Authorization", format!("Bearer {}", self.service_key()?))
|
||||
.header("Prefer", "return=representation")
|
||||
.json(data)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
|
||||
let mut items: Vec<T> = resp.json().await?;
|
||||
items.pop().ok_or_else(|| CxError::Api {
|
||||
status: 500,
|
||||
message: "empty insert response".into(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Delete rows matching a filter.
|
||||
pub async fn delete(&self, table: &str, filter: &str) -> CxResult<()> {
|
||||
let url = format!("{}/rest/v1/{}?{}", self.base_url()?, table, filter);
|
||||
let resp = self
|
||||
.http
|
||||
.delete(&url)
|
||||
.header("apikey", self.service_key()?)
|
||||
.header("Authorization", format!("Bearer {}", self.service_key()?))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Upload a file to Supabase Storage.
|
||||
pub async fn upload_file(
|
||||
&self,
|
||||
bucket: &str,
|
||||
path: &str,
|
||||
data: Vec<u8>,
|
||||
content_type: &str,
|
||||
) -> CxResult<()> {
|
||||
let url = format!("{}/storage/v1/object/{}/{}", self.base_url()?, bucket, path);
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("apikey", self.service_key()?)
|
||||
.header("Authorization", format!("Bearer {}", self.service_key()?))
|
||||
.header("Content-Type", content_type)
|
||||
.body(data)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a signed download URL.
|
||||
pub async fn signed_url(&self, bucket: &str, path: &str, expires_in: u64) -> CxResult<String> {
|
||||
let url = format!(
|
||||
"{}/storage/v1/object/sign/{}/{}",
|
||||
self.base_url()?,
|
||||
bucket,
|
||||
path
|
||||
);
|
||||
let body = serde_json::json!({ "expiresIn": expires_in });
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("apikey", self.service_key()?)
|
||||
.header("Authorization", format!("Bearer {}", self.service_key()?))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
return Err(CxError::from_response(resp).await);
|
||||
}
|
||||
|
||||
let data: serde_json::Value = resp.json().await?;
|
||||
data["signedURL"]
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
.ok_or_else(|| CxError::Api {
|
||||
status: 500,
|
||||
message: "missing signedURL".into(),
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,118 @@
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// Platform-wide configuration.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PlatformConfig {
|
||||
pub base_url: String,
|
||||
pub api_token: Option<String>,
|
||||
pub timeout_secs: u64,
|
||||
pub pool_size: usize,
|
||||
|
||||
// Service ports
|
||||
pub studio_port: u16,
|
||||
pub commissioning_port: u16,
|
||||
pub energy_port: u16,
|
||||
pub forma_port: u16,
|
||||
pub cxdocs_port: u16,
|
||||
pub app_port: u16,
|
||||
pub nemo_port: u16,
|
||||
pub codegen_port: u16,
|
||||
pub bridge_port: u16,
|
||||
pub ml_port: u16,
|
||||
|
||||
// NVIDIA
|
||||
pub nvidia_api_key: Option<String>,
|
||||
pub nvidia_base_url: String,
|
||||
|
||||
// Supabase
|
||||
pub supabase_url: Option<String>,
|
||||
pub supabase_service_role_key: Option<String>,
|
||||
|
||||
// ERCOT
|
||||
pub ercot_username: Option<String>,
|
||||
pub ercot_password: Option<String>,
|
||||
pub ercot_subscription_key: Option<String>,
|
||||
|
||||
// MongoDB
|
||||
pub mongodb_uri: Option<String>,
|
||||
}
|
||||
|
||||
impl Default for PlatformConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
base_url: "http://localhost".into(),
|
||||
api_token: None,
|
||||
timeout_secs: 30,
|
||||
pool_size: 10,
|
||||
studio_port: 8080,
|
||||
commissioning_port: 8081,
|
||||
energy_port: 8082,
|
||||
forma_port: 8083,
|
||||
cxdocs_port: 8084,
|
||||
app_port: 8090,
|
||||
nemo_port: 18800,
|
||||
codegen_port: 18900,
|
||||
bridge_port: 9100,
|
||||
ml_port: 8085,
|
||||
nvidia_api_key: None,
|
||||
nvidia_base_url: "https://integrate.api.nvidia.com/v1".into(),
|
||||
supabase_url: None,
|
||||
supabase_service_role_key: None,
|
||||
ercot_username: None,
|
||||
ercot_password: None,
|
||||
ercot_subscription_key: None,
|
||||
mongodb_uri: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PlatformConfig {
|
||||
/// Load configuration from environment variables.
|
||||
pub fn from_env() -> CxResult<Self> {
|
||||
let env = |key: &str| std::env::var(key).ok();
|
||||
|
||||
Ok(Self {
|
||||
base_url: env("CXAI_BASE_URL").unwrap_or_else(|| "http://localhost".into()),
|
||||
api_token: env("CXAI_API_TOKEN"),
|
||||
timeout_secs: env("CXAI_TIMEOUT_SECS")
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(30),
|
||||
pool_size: env("CXAI_POOL_SIZE")
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(10),
|
||||
studio_port: parse_port("CXAI_STUDIO_PORT", 8080)?,
|
||||
commissioning_port: parse_port("CXAI_COMMISSIONING_PORT", 8081)?,
|
||||
energy_port: parse_port("CXAI_ENERGY_PORT", 8082)?,
|
||||
forma_port: parse_port("CXAI_FORMA_PORT", 8083)?,
|
||||
cxdocs_port: parse_port("CXAI_CXDOCS_PORT", 8084)?,
|
||||
app_port: parse_port("CXAI_APP_PORT", 8090)?,
|
||||
nemo_port: parse_port("CXAI_NEMO_PORT", 18800)?,
|
||||
codegen_port: parse_port("CXAI_CODEGEN_PORT", 18900)?,
|
||||
bridge_port: parse_port("CXAI_BRIDGE_PORT", 9100)?,
|
||||
ml_port: parse_port("CXAI_ML_PORT", 8085)?,
|
||||
nvidia_api_key: env("NVIDIA_API_KEY"),
|
||||
nvidia_base_url: env("NVIDIA_BASE_URL")
|
||||
.unwrap_or_else(|| "https://integrate.api.nvidia.com/v1".into()),
|
||||
supabase_url: env("SUPABASE_URL"),
|
||||
supabase_service_role_key: env("SUPABASE_SERVICE_ROLE_KEY"),
|
||||
ercot_username: env("ERCOT_USERNAME"),
|
||||
ercot_password: env("ERCOT_PASSWORD"),
|
||||
ercot_subscription_key: env("ERCOT_SUBSCRIPTION_KEY"),
|
||||
mongodb_uri: env("MONGODB_URI"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Build the full URL for a service.
|
||||
pub fn service_url(&self, port: u16) -> String {
|
||||
format!("{}:{}", self.base_url, port)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_port(key: &str, default: u16) -> CxResult<u16> {
|
||||
match std::env::var(key) {
|
||||
Ok(v) => v
|
||||
.parse()
|
||||
.map_err(|_| CxError::Config(format!("{key}: invalid port '{v}'"))),
|
||||
Err(_) => Ok(default),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Platform SDK error type.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum CxError {
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(#[from] reqwest::Error),
|
||||
|
||||
#[error("API error ({status}): {message}")]
|
||||
Api { status: u16, message: String },
|
||||
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
|
||||
#[error("Timeout after {0}s")]
|
||||
Timeout(u64),
|
||||
|
||||
#[error("Circuit breaker open for {service}")]
|
||||
CircuitOpen { service: String },
|
||||
}
|
||||
|
||||
pub type CxResult<T> = Result<T, CxError>;
|
||||
|
||||
impl CxError {
|
||||
/// Create from an HTTP response with non-2xx status.
|
||||
pub async fn from_response(response: reqwest::Response) -> Self {
|
||||
let status = response.status().as_u16();
|
||||
let message = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "unknown error".into());
|
||||
|
||||
match status {
|
||||
401 | 403 => Self::Unauthorized(message),
|
||||
404 => Self::NotFound(message),
|
||||
_ => Self::Api { status, message },
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,129 @@
|
||||
pub mod clients;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
pub mod models;
|
||||
pub mod resilience;
|
||||
|
||||
pub use config::PlatformConfig;
|
||||
pub use error::{CxError, CxResult};
|
||||
|
||||
use clients::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Central SDK handle providing access to all platform service clients.
|
||||
pub struct CxPlatform {
|
||||
config: Arc<PlatformConfig>,
|
||||
http: reqwest::Client,
|
||||
}
|
||||
|
||||
impl CxPlatform {
|
||||
/// Create a new platform SDK instance from configuration.
|
||||
pub fn new(config: PlatformConfig) -> CxResult<Self> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
if let Some(ref token) = config.api_token {
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {token}"))
|
||||
.map_err(|e| CxError::Config(format!("invalid token: {e}")))?,
|
||||
);
|
||||
}
|
||||
|
||||
let http = reqwest::Client::builder()
|
||||
.default_headers(headers)
|
||||
.timeout(std::time::Duration::from_secs(config.timeout_secs))
|
||||
.pool_max_idle_per_host(config.pool_size)
|
||||
.build()
|
||||
.map_err(|e| CxError::Config(format!("http client: {e}")))?;
|
||||
|
||||
Ok(Self {
|
||||
config: Arc::new(config),
|
||||
http,
|
||||
})
|
||||
}
|
||||
|
||||
/// Create from environment variables.
|
||||
pub fn from_env() -> CxResult<Self> {
|
||||
Self::new(PlatformConfig::from_env()?)
|
||||
}
|
||||
|
||||
pub fn nvidia(&self) -> NvidiaClient {
|
||||
NvidiaClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn supabase(&self) -> SupabaseClient {
|
||||
SupabaseClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn studio(&self) -> StudioClient {
|
||||
StudioClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn commissioning(&self) -> CommissioningClient {
|
||||
CommissioningClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn energy(&self) -> EnergyClient {
|
||||
EnergyClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn forma(&self) -> FormaClient {
|
||||
FormaClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn cxdocs(&self) -> CxDocsClient {
|
||||
CxDocsClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn app(&self) -> AppClient {
|
||||
AppClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn nemo(&self) -> NemoClient {
|
||||
NemoClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn codegen(&self) -> CodegenClient {
|
||||
CodegenClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn bridge(&self) -> BridgeClient {
|
||||
BridgeClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn ml(&self) -> MlClient {
|
||||
MlClient::new(self.http.clone(), self.config.clone())
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &PlatformConfig {
|
||||
&self.config
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn platform_from_default_config() {
|
||||
let config = PlatformConfig::default();
|
||||
let platform = CxPlatform::new(config).unwrap();
|
||||
assert_eq!(platform.config().base_url, "http://localhost");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn platform_clients_are_constructible() {
|
||||
let platform = CxPlatform::new(PlatformConfig::default()).unwrap();
|
||||
let _ = platform.nvidia();
|
||||
let _ = platform.studio();
|
||||
let _ = platform.commissioning();
|
||||
let _ = platform.energy();
|
||||
let _ = platform.forma();
|
||||
let _ = platform.cxdocs();
|
||||
let _ = platform.app();
|
||||
let _ = platform.nemo();
|
||||
let _ = platform.codegen();
|
||||
let _ = platform.bridge();
|
||||
let _ = platform.ml();
|
||||
let _ = platform.supabase();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CommissioningProject {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub phase: String,
|
||||
pub status: String,
|
||||
pub capacity_mw: Option<f64>,
|
||||
pub manager: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WorkflowStep {
|
||||
pub id: String,
|
||||
pub project_id: String,
|
||||
pub name: String,
|
||||
pub status: String,
|
||||
pub order: u32,
|
||||
pub completed_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Equipment {
|
||||
pub id: String,
|
||||
pub project_id: String,
|
||||
pub name: String,
|
||||
pub equipment_type: String,
|
||||
pub serial_number: Option<String>,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Incident {
|
||||
pub id: String,
|
||||
pub project_id: String,
|
||||
pub title: String,
|
||||
pub severity: String,
|
||||
pub status: String,
|
||||
pub description: Option<String>,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub resolved_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ComplianceRecord {
|
||||
pub id: String,
|
||||
pub project_id: String,
|
||||
pub standard: String,
|
||||
pub status: String,
|
||||
pub verified_at: Option<DateTime<Utc>>,
|
||||
pub notes: Option<String>,
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Generic API response wrapper matching .NET ApiResponse<T>.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ApiResponse<T> {
|
||||
pub success: bool,
|
||||
pub data: Option<T>,
|
||||
pub error: Option<String>,
|
||||
pub errors: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl<T> ApiResponse<T> {
|
||||
pub fn ok(data: T) -> Self {
|
||||
Self {
|
||||
success: true,
|
||||
data: Some(data),
|
||||
error: None,
|
||||
errors: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn err(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
success: false,
|
||||
data: None,
|
||||
error: Some(message.into()),
|
||||
errors: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Paginated response matching .NET PaginatedResponse<T>.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PaginatedResponse<T> {
|
||||
pub items: Vec<T>,
|
||||
pub total: i64,
|
||||
pub page: i32,
|
||||
pub per_page: i32,
|
||||
}
|
||||
|
||||
/// Health status response from /healthz endpoints.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HealthResponse {
|
||||
pub status: String,
|
||||
pub version: String,
|
||||
pub environment: String,
|
||||
pub services: Option<HashMap<String, ServiceHealth>>,
|
||||
}
|
||||
|
||||
/// Individual service health.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ServiceHealth {
|
||||
pub status: String,
|
||||
pub latency_ms: Option<f64>,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
/// Platform info from /info endpoint.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PlatformInfo {
|
||||
pub service: String,
|
||||
pub version: String,
|
||||
pub environment: String,
|
||||
pub uptime: String,
|
||||
pub runtime: String,
|
||||
pub host: String,
|
||||
}
|
||||
|
||||
/// Metrics from /metrics endpoint.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PlatformMetrics {
|
||||
pub uptime_seconds: f64,
|
||||
pub total_requests: u64,
|
||||
pub total_errors: u64,
|
||||
pub memory_mb: f64,
|
||||
pub threads: u32,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
/// Chat message matching .NET ChatMessage.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timestamp: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "system".into(),
|
||||
content: content.into(),
|
||||
timestamp: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "user".into(),
|
||||
content: content.into(),
|
||||
timestamp: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "assistant".into(),
|
||||
content: content.into(),
|
||||
timestamp: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
use chrono::{DateTime, NaiveDate, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct DamPrice {
|
||||
pub settlement_point: String,
|
||||
pub delivery_date: NaiveDate,
|
||||
pub hour_ending: u8,
|
||||
pub price: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timestamp: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RtPrice {
|
||||
pub settlement_point: String,
|
||||
pub price: f64,
|
||||
pub interval: String,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SystemDemand {
|
||||
pub demand_mw: f64,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub forecast_mw: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct WindGeneration {
|
||||
pub generation_mw: f64,
|
||||
pub capacity_mw: f64,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct SolarGeneration {
|
||||
pub generation_mw: f64,
|
||||
pub capacity_mw: f64,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MarketSummary {
|
||||
pub avg_dam_price: f64,
|
||||
pub avg_rt_price: f64,
|
||||
pub peak_demand_mw: f64,
|
||||
pub wind_generation_mw: f64,
|
||||
pub solar_generation_mw: f64,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ErcotMlPrediction {
|
||||
pub prediction_type: String,
|
||||
pub model: String,
|
||||
pub target_date: NaiveDate,
|
||||
pub target_hour: u8,
|
||||
pub predicted_value: f64,
|
||||
pub confidence: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub lower_bound: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub upper_bound: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ScadaSnapshot {
|
||||
pub device_id: String,
|
||||
pub soc_percent: f64,
|
||||
pub soh_percent: f64,
|
||||
pub power_kw: f64,
|
||||
pub voltage_v: f64,
|
||||
pub current_a: f64,
|
||||
pub temperature_c: f64,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CongestionSnapshot {
|
||||
pub constraint_name: String,
|
||||
pub shadow_price: f64,
|
||||
pub limit_mw: f64,
|
||||
pub flow_mw: f64,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct PriceVolatility {
|
||||
pub settlement_point: String,
|
||||
pub volatility: f64,
|
||||
pub period_hours: u32,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FormaTokenResponse {
|
||||
pub access_token: String,
|
||||
pub token_type: String,
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FormaProjectInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub status: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FormaElementInfo {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub element_type: String,
|
||||
pub project_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FormaAnalysisResult {
|
||||
pub id: String,
|
||||
pub analysis_type: String,
|
||||
pub status: String,
|
||||
pub result: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct FormaComplianceResult {
|
||||
pub compliant: bool,
|
||||
pub violations: Vec<String>,
|
||||
pub element_urn: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AccFolder {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub parent_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AccIssue {
|
||||
pub id: String,
|
||||
pub title: String,
|
||||
pub status: String,
|
||||
pub priority: Option<String>,
|
||||
pub assigned_to: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AccChecklist {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub status: String,
|
||||
pub items_total: u32,
|
||||
pub items_completed: u32,
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Inference request matching .NET InferenceRequest.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InferenceRequest {
|
||||
pub model: String,
|
||||
pub messages: Vec<super::ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
}
|
||||
|
||||
/// Inference response matching .NET InferenceResponse.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InferenceResponse {
|
||||
pub content: String,
|
||||
pub model: String,
|
||||
pub usage: Option<TokenUsage>,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
/// Token usage stats.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TokenUsage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
/// CxModel record matching .NET CxModelRecord.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CxModelRecord {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub provider: String,
|
||||
pub model_id: String,
|
||||
pub status: String,
|
||||
pub capabilities: CxModelCapabilities,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub created_at: Option<DateTime<Utc>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
/// Model capabilities.
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CxModelCapabilities {
|
||||
pub chat: bool,
|
||||
pub streaming: bool,
|
||||
pub embedding: bool,
|
||||
pub vision: bool,
|
||||
pub code_gen: bool,
|
||||
pub function_calling: bool,
|
||||
pub max_context_tokens: Option<u32>,
|
||||
}
|
||||
|
||||
/// CxModel inference request with profile routing.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CxInferenceRequest {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub messages: Vec<super::ChatMessage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub profile: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub task: Option<String>,
|
||||
}
|
||||
|
||||
/// Inference profile matching .NET InferenceProfile.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct InferenceProfile {
|
||||
pub id: String,
|
||||
pub provider: String,
|
||||
pub model: String,
|
||||
pub display: String,
|
||||
pub timeout_seconds: u64,
|
||||
pub tier: String,
|
||||
}
|
||||
|
||||
/// Model health report.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CxModelHealthReport {
|
||||
pub status: String,
|
||||
pub providers: Vec<ProviderHealth>,
|
||||
}
|
||||
|
||||
/// Individual provider health.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ProviderHealth {
|
||||
pub name: String,
|
||||
pub status: String,
|
||||
pub latency_ms: Option<f64>,
|
||||
pub models_available: u32,
|
||||
}
|
||||
|
||||
/// Guardrail result from NVIDIA safety checks.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct GuardrailResult {
|
||||
pub safe: bool,
|
||||
pub categories: Vec<String>,
|
||||
pub scores: Vec<f32>,
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
pub mod commissioning;
|
||||
pub mod common;
|
||||
pub mod energy;
|
||||
pub mod forma;
|
||||
pub mod inference;
|
||||
pub mod nemo;
|
||||
pub mod studio;
|
||||
pub mod telemetry;
|
||||
|
||||
pub use common::*;
|
||||
pub use inference::*;
|
||||
@@ -0,0 +1,120 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AgentRunRequest {
|
||||
pub goal: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub profile: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_steps: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AgentRunResponse {
|
||||
pub session_id: String,
|
||||
pub status: String,
|
||||
pub steps: Vec<AgentStep>,
|
||||
pub result: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct AgentStep {
|
||||
pub step_number: u32,
|
||||
pub action: String,
|
||||
pub result: Option<String>,
|
||||
pub duration_ms: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NemoChatRequest {
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub session_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub profile: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NemoChatResponse {
|
||||
pub reply: String,
|
||||
pub session_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub profile: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NemoSession {
|
||||
pub id: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub messages: Vec<super::ChatMessage>,
|
||||
pub profile: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NemoSessionSummary {
|
||||
pub id: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub message_count: u32,
|
||||
pub last_message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolExecutionRequest {
|
||||
pub tool: String,
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ToolExecutionResult {
|
||||
pub tool: String,
|
||||
pub success: bool,
|
||||
pub result: Option<serde_json::Value>,
|
||||
pub error: Option<String>,
|
||||
pub duration_ms: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteRequest {
|
||||
pub message: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub task_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RouteResponse {
|
||||
pub profile: super::InferenceProfile,
|
||||
pub confidence: f64,
|
||||
pub reasoning: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct NemoMetrics {
|
||||
pub total_sessions: u64,
|
||||
pub active_sessions: u64,
|
||||
pub total_agent_runs: u64,
|
||||
pub avg_steps_per_run: f64,
|
||||
pub tool_executions: u64,
|
||||
}
|
||||
@@ -0,0 +1,44 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct User {
|
||||
pub id: String,
|
||||
pub email: String,
|
||||
pub name: Option<String>,
|
||||
pub role: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Conversation {
|
||||
pub id: String,
|
||||
pub user_id: String,
|
||||
pub title: Option<String>,
|
||||
pub message_count: u32,
|
||||
pub created_at: DateTime<Utc>,
|
||||
pub updated_at: Option<DateTime<Utc>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Deployment {
|
||||
pub id: String,
|
||||
pub service: String,
|
||||
pub version: String,
|
||||
pub status: String,
|
||||
pub target: String,
|
||||
pub created_at: DateTime<Utc>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Service {
|
||||
pub name: String,
|
||||
pub port: u16,
|
||||
pub status: String,
|
||||
pub version: String,
|
||||
pub health: String,
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct TelemetryEvent {
|
||||
pub event_type: String,
|
||||
pub service: String,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub data: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct MemorySummary {
|
||||
pub total_mb: f64,
|
||||
pub used_mb: f64,
|
||||
pub gc_gen0: u64,
|
||||
pub gc_gen1: u64,
|
||||
pub gc_gen2: u64,
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
use backon::{ExponentialBuilder, Retryable};
|
||||
use std::future::Future;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::error::{CxError, CxResult};
|
||||
|
||||
/// Retry configuration matching .NET SDK retry/circuit-breaker policies.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryPolicy {
|
||||
pub max_attempts: u32,
|
||||
pub base_delay_ms: u64,
|
||||
pub max_delay_ms: u64,
|
||||
}
|
||||
|
||||
impl Default for RetryPolicy {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_attempts: 3,
|
||||
base_delay_ms: 200,
|
||||
max_delay_ms: 5000,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute an async operation with exponential backoff retry.
|
||||
pub async fn with_retry<F, Fut, T>(policy: &RetryPolicy, operation: F) -> CxResult<T>
|
||||
where
|
||||
F: FnMut() -> Fut,
|
||||
Fut: Future<Output = CxResult<T>>,
|
||||
{
|
||||
let backoff = ExponentialBuilder::default()
|
||||
.with_min_delay(std::time::Duration::from_millis(policy.base_delay_ms))
|
||||
.with_max_delay(std::time::Duration::from_millis(policy.max_delay_ms))
|
||||
.with_max_times(policy.max_attempts as usize);
|
||||
|
||||
operation
|
||||
.retry(backoff)
|
||||
.notify(|err: &CxError, dur| {
|
||||
warn!("Retrying after {dur:?} due to: {err}");
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn retry_succeeds_first_try() {
|
||||
let policy = RetryPolicy::default();
|
||||
let result = with_retry(&policy, || async { Ok::<_, CxError>(42) }).await;
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
}
|
||||
}
|
||||
Generated
+7
@@ -0,0 +1,7 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "cxai-test-crate"
|
||||
version = "0.1.0"
|
||||
@@ -0,0 +1,28 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2024"
|
||||
name = "cxai-test-crate"
|
||||
version = "0.1.0"
|
||||
build = false
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
description = "Test crate for CxAI Cargo registry"
|
||||
readme = false
|
||||
license = "MIT"
|
||||
|
||||
[lib]
|
||||
name = "cxai_test_crate"
|
||||
path = "src/lib.rs"
|
||||
Generated
+6
@@ -0,0 +1,6 @@
|
||||
[package]
|
||||
name = "cxai-test-crate"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
description = "Test crate for CxAI Cargo registry"
|
||||
license = "MIT"
|
||||
@@ -0,0 +1,14 @@
|
||||
pub fn add(left: u64, right: u64) -> u64 {
|
||||
left + right
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
let result = add(2, 2);
|
||||
assert_eq!(result, 4);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "927a31cbf65f78c3ef6b729631b2fc35335afe06",
|
||||
"dirty": true
|
||||
},
|
||||
"path_in_vcs": "services/cxcloud-rs/crates/agent"
|
||||
}
|
||||
Generated
+2767
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,94 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2021"
|
||||
name = "cxcloud-agent"
|
||||
version = "0.1.0"
|
||||
build = false
|
||||
publish = ["cxai"]
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
readme = false
|
||||
|
||||
[[bin]]
|
||||
name = "agent-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies.anyhow]
|
||||
version = "1"
|
||||
|
||||
[dependencies.axum]
|
||||
version = "0.7"
|
||||
features = ["macros"]
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxcloud-common]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.cxcloud-proto]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.prost]
|
||||
version = "0.13"
|
||||
|
||||
[dependencies.prost-types]
|
||||
version = "0.13"
|
||||
|
||||
[dependencies.redis]
|
||||
version = "0.27"
|
||||
features = [
|
||||
"tokio-comp",
|
||||
"connection-manager",
|
||||
]
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tonic]
|
||||
version = "0.12"
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v7",
|
||||
"serde",
|
||||
]
|
||||
Generated
+27
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "cxcloud-agent"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "agent-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cxcloud-common = { workspace = true }
|
||||
cxcloud-proto = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tonic = { workspace = true }
|
||||
prost = { workspace = true }
|
||||
prost-types = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
@@ -0,0 +1,95 @@
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use cxcloud_common::{event::Event, redis_streams};
|
||||
|
||||
use crate::{grpc_clients, memory, planner, verifier, AppState};
|
||||
|
||||
const STREAM: &str = "events.raw";
|
||||
const GROUP: &str = "agent-group";
|
||||
const CONSUMER: &str = "agent-rs-1";
|
||||
const STREAM_OUTPUT: &str = "events.output";
|
||||
|
||||
/// Main autonomous processing loop: consume events → plan → approve → execute → verify → output.
|
||||
pub async fn run(state: Arc<AppState>) {
|
||||
redis_streams::ensure_consumer_group(&state.redis, STREAM, GROUP).await;
|
||||
|
||||
info!("Agent consumer started on {STREAM} (group: {GROUP})");
|
||||
|
||||
loop {
|
||||
let messages = match redis_streams::xreadgroup(&state.redis, GROUP, CONSUMER, STREAM, 1, 5000).await {
|
||||
Ok(msgs) => msgs,
|
||||
Err(e) => {
|
||||
error!(error = %e, "Stream read error");
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for (stream_id, fields) in messages {
|
||||
let event = match Event::from_stream_fields(&fields) {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
warn!(stream_id, "Failed to parse event from stream");
|
||||
let _ = redis_streams::xack(&state.redis, STREAM, GROUP, &stream_id).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
info!(event_id = %event.id, event_type = %event.r#type, "Processing event");
|
||||
|
||||
match process_event(&state, &event).await {
|
||||
Ok(()) => {
|
||||
info!(event_id = %event.id, "Event processed successfully");
|
||||
}
|
||||
Err(e) => {
|
||||
error!(event_id = %event.id, error = %e, "Event processing failed");
|
||||
}
|
||||
}
|
||||
|
||||
let _ = redis_streams::xack(&state.redis, STREAM, GROUP, &stream_id).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn process_event(state: &AppState, event: &Event) -> anyhow::Result<()> {
|
||||
// 1. Retrieve context from memory (ChromaDB RAG)
|
||||
let context = memory::context::retrieve(&state.http_client, &state.chromadb_url, event).await;
|
||||
|
||||
// 2. Generate plan via LLM (Ollama)
|
||||
let plan = planner::create_plan(&state.http_client, &state.ollama_url, event, &context).await?;
|
||||
|
||||
// 3. Request approval from Human Simulator
|
||||
let approval = grpc_clients::request_approval(&plan).await;
|
||||
match &approval {
|
||||
Ok(decision) if decision == "REJECTED" => {
|
||||
warn!(event_id = %event.id, "Plan rejected by policy engine");
|
||||
return Ok(());
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(event_id = %event.id, error = %e, "Approval request failed, auto-approving");
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// 4. Execute via Bot Service
|
||||
let execution = grpc_clients::execute_plan(&plan).await;
|
||||
|
||||
// 5. Verify execution results
|
||||
let verified = verifier::verify(&state.http_client, &state.ollama_url, &plan, &execution).await;
|
||||
|
||||
// 6. Store successful outcome in memory
|
||||
if verified {
|
||||
memory::store::save_outcome(&state.http_client, &state.chromadb_url, event, &plan).await;
|
||||
}
|
||||
|
||||
// 7. Publish to output stream
|
||||
let output_fields = vec![
|
||||
("event_id".to_string(), event.id.clone()),
|
||||
("plan".to_string(), plan.clone()),
|
||||
("status".to_string(), if verified { "completed" } else { "failed" }.to_string()),
|
||||
];
|
||||
redis_streams::xadd(&state.redis, STREAM_OUTPUT, &output_fields).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
use cxcloud_common::config::{env_or, env_port};
|
||||
use tracing::info;
|
||||
|
||||
/// Request approval from the Human Simulator service.
|
||||
/// Falls back to HTTP if gRPC is unavailable.
|
||||
pub async fn request_approval(plan: &str) -> Result<String, String> {
|
||||
let host = env_or("HUMAN_SIMULATOR_HOST", "localhost");
|
||||
let port = env_port("HUMAN_SIMULATOR_HTTP_PORT", 3002);
|
||||
let url = format!("http://{host}:{port}/approve");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({
|
||||
"request_id": uuid::Uuid::now_v7().to_string(),
|
||||
"plan": plan,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Human Simulator request failed: {e}"))?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
let body: serde_json::Value = resp.json().await.map_err(|e| format!("Parse error: {e}"))?;
|
||||
let decision = body
|
||||
.get("decision")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("APPROVED")
|
||||
.to_string();
|
||||
info!(decision, "Approval decision received");
|
||||
Ok(decision)
|
||||
} else {
|
||||
Err(format!("Human Simulator returned {}", resp.status()))
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a plan via the Bot Service.
|
||||
/// Falls back to HTTP if gRPC is unavailable.
|
||||
pub async fn execute_plan(plan: &str) -> Result<String, String> {
|
||||
let host = env_or("BOT_HOST", "localhost");
|
||||
let port = env_port("BOT_HTTP_PORT", 8002);
|
||||
let url = format!("http://{host}:{port}/execute");
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let resp = client
|
||||
.post(&url)
|
||||
.json(&serde_json::json!({
|
||||
"plan_id": uuid::Uuid::now_v7().to_string(),
|
||||
"plan": plan,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Bot execution request failed: {e}"))?;
|
||||
|
||||
if resp.status().is_success() {
|
||||
let body: serde_json::Value = resp.json().await.map_err(|e| format!("Parse error: {e}"))?;
|
||||
info!("Execution completed");
|
||||
Ok(serde_json::to_string(&body).unwrap_or_default())
|
||||
} else {
|
||||
Err(format!("Bot Service returned {}", resp.status()))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,68 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tonic::{Request, Response, Status};
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_proto::agent::{
|
||||
agent_service_server::{AgentService, AgentServiceServer},
|
||||
GetEventStatusRequest, GetEventStatusResponse, ProcessEventRequest, ProcessEventResponse,
|
||||
};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub struct AgentServiceImpl {
|
||||
state: Arc<AppState>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl AgentService for AgentServiceImpl {
|
||||
async fn process_event(
|
||||
&self,
|
||||
request: Request<ProcessEventRequest>,
|
||||
) -> Result<Response<ProcessEventResponse>, Status> {
|
||||
let req = request.into_inner();
|
||||
let event = req
|
||||
.event
|
||||
.ok_or_else(|| Status::invalid_argument("event is required"))?;
|
||||
|
||||
info!(event_id = %event.id, "Processing event via gRPC");
|
||||
|
||||
// In a full implementation, this would trigger the same pipeline as the consumer
|
||||
let plan_id = uuid::Uuid::now_v7().to_string();
|
||||
|
||||
Ok(Response::new(ProcessEventResponse {
|
||||
plan_id: plan_id.clone(),
|
||||
plan: None, // Plan would be populated after LLM generation
|
||||
}))
|
||||
}
|
||||
|
||||
async fn get_event_status(
|
||||
&self,
|
||||
request: Request<GetEventStatusRequest>,
|
||||
) -> Result<Response<GetEventStatusResponse>, Status> {
|
||||
let req = request.into_inner();
|
||||
|
||||
Ok(Response::new(GetEventStatusResponse {
|
||||
event_id: req.event_id,
|
||||
plan_id: String::new(),
|
||||
status: 0, // UNSPECIFIED
|
||||
summary: "Status lookup not yet implemented".to_string(),
|
||||
retry_count: 0,
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve(state: Arc<AppState>, port: u16) -> anyhow::Result<()> {
|
||||
let addr = format!("0.0.0.0:{port}").parse()?;
|
||||
let service = AgentServiceImpl { state };
|
||||
|
||||
info!(port, "Agent gRPC server starting");
|
||||
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(AgentServiceServer::new(service))
|
||||
.serve(addr)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,115 @@
|
||||
mod consumer;
|
||||
mod grpc_clients;
|
||||
mod grpc_server;
|
||||
pub mod memory;
|
||||
mod planner;
|
||||
mod verifier;
|
||||
|
||||
use axum::{extract::State, response::Json, routing::get, Router};
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_common::{
|
||||
config::{env_or, env_port},
|
||||
health::HealthResponse,
|
||||
redis_streams, telemetry,
|
||||
};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub redis: redis::aio::ConnectionManager,
|
||||
pub ollama_url: String,
|
||||
pub chromadb_url: String,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let redis_url = env_or("REDIS_URL", "redis://localhost:6379");
|
||||
let otel_endpoint = env_or("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317");
|
||||
let log_level = env_or("LOG_LEVEL", "info");
|
||||
let ollama_url = env_or("OLLAMA_URL", "http://localhost:11434");
|
||||
let chromadb_url = env_or("CHROMADB_URL", "http://localhost:8000");
|
||||
|
||||
telemetry::init("agent-service", &otel_endpoint, &log_level);
|
||||
|
||||
let redis_conn = redis_streams::connect(&redis_url).await?;
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.build()?;
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
redis: redis_conn.clone(),
|
||||
ollama_url: ollama_url.clone(),
|
||||
chromadb_url: chromadb_url.clone(),
|
||||
http_client: http_client.clone(),
|
||||
});
|
||||
|
||||
// Spawn Redis stream consumer (autonomous processing loop)
|
||||
let consumer_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
consumer::run(consumer_state).await;
|
||||
});
|
||||
|
||||
// Spawn gRPC server
|
||||
let grpc_port = env_port("AGENT_GRPC_PORT", 50051);
|
||||
let grpc_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = grpc_server::serve(grpc_state, grpc_port).await {
|
||||
tracing::error!(error = %e, "gRPC server failed");
|
||||
}
|
||||
});
|
||||
|
||||
// HTTP server
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/metrics", get(metrics))
|
||||
.with_state(state);
|
||||
|
||||
let http_port = env_port("AGENT_HTTP_PORT", 8001);
|
||||
info!(http_port, grpc_port, "Agent service starting");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{http_port}")).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
telemetry::shutdown();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health(State(state): State<Arc<AppState>>) -> Json<HealthResponse> {
|
||||
let mut conn = state.redis.clone();
|
||||
let redis_ok = redis_streams::ping(&mut conn).await;
|
||||
|
||||
let ollama_ok = state
|
||||
.http_client
|
||||
.get(format!("{}/api/tags", state.ollama_url))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false);
|
||||
|
||||
let chromadb_ok = state
|
||||
.http_client
|
||||
.get(format!("{}/api/v1/heartbeat", state.chromadb_url))
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
.unwrap_or(false);
|
||||
|
||||
let resp = HealthResponse::healthy("agent-service")
|
||||
.with_dependency("redis", if redis_ok { "healthy" } else { "unhealthy" })
|
||||
.with_dependency("ollama", if ollama_ok { "healthy" } else { "unhealthy" })
|
||||
.with_dependency("chromadb", if chromadb_ok { "healthy" } else { "unhealthy" })
|
||||
.compute_status();
|
||||
|
||||
Json(resp)
|
||||
}
|
||||
|
||||
async fn metrics() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"service": "agent-service",
|
||||
"events_processed": 0,
|
||||
"plans_created": 0,
|
||||
"plans_rejected": 0,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
use cxcloud_common::event::Event;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
const COLLECTION: &str = "cxcloud_memory";
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct QueryRequest {
|
||||
query_texts: Vec<String>,
|
||||
n_results: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct QueryResponse {
|
||||
documents: Option<Vec<Vec<String>>>,
|
||||
}
|
||||
|
||||
/// Retrieve relevant context from ChromaDB via similarity search.
|
||||
pub async fn retrieve(
|
||||
client: &reqwest::Client,
|
||||
chromadb_url: &str,
|
||||
event: &Event,
|
||||
) -> String {
|
||||
let query_text = format!(
|
||||
"Event type: {} source: {} payload: {}",
|
||||
event.r#type,
|
||||
event.source,
|
||||
serde_json::to_string(&event.payload).unwrap_or_default()
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.post(format!(
|
||||
"{chromadb_url}/api/v1/collections/{COLLECTION}/query"
|
||||
))
|
||||
.json(&QueryRequest {
|
||||
query_texts: vec![query_text],
|
||||
n_results: 5,
|
||||
})
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(r) if r.status().is_success() => {
|
||||
if let Ok(body) = r.json::<QueryResponse>().await {
|
||||
if let Some(docs) = body.documents {
|
||||
let context: Vec<String> = docs.into_iter().flatten().collect();
|
||||
if !context.is_empty() {
|
||||
debug!(count = context.len(), "Retrieved context from memory");
|
||||
return context.join("\n---\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(r) => {
|
||||
debug!(status = %r.status(), "ChromaDB query returned non-success");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to query ChromaDB for context");
|
||||
}
|
||||
}
|
||||
|
||||
String::new()
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
pub mod context;
|
||||
pub mod store;
|
||||
@@ -0,0 +1,57 @@
|
||||
use cxcloud_common::event::Event;
|
||||
use serde::Serialize;
|
||||
use tracing::{debug, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
const COLLECTION: &str = "cxcloud_memory";
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct AddRequest {
|
||||
ids: Vec<String>,
|
||||
documents: Vec<String>,
|
||||
metadatas: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Save a successful outcome to ChromaDB for future RAG retrieval.
|
||||
pub async fn save_outcome(
|
||||
client: &reqwest::Client,
|
||||
chromadb_url: &str,
|
||||
event: &Event,
|
||||
plan: &str,
|
||||
) {
|
||||
let doc = format!(
|
||||
"Event: {} ({})\nPlan: {}",
|
||||
event.r#type, event.source, plan
|
||||
);
|
||||
|
||||
let metadata = serde_json::json!({
|
||||
"event_id": event.id,
|
||||
"event_type": event.r#type,
|
||||
"source": event.source,
|
||||
"timestamp": event.timestamp.to_rfc3339(),
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.post(format!(
|
||||
"{chromadb_url}/api/v1/collections/{COLLECTION}/add"
|
||||
))
|
||||
.json(&AddRequest {
|
||||
ids: vec![Uuid::now_v7().to_string()],
|
||||
documents: vec![doc],
|
||||
metadatas: vec![metadata],
|
||||
})
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(r) if r.status().is_success() => {
|
||||
debug!(event_id = %event.id, "Saved outcome to memory");
|
||||
}
|
||||
Ok(r) => {
|
||||
warn!(status = %r.status(), "Failed to save to ChromaDB");
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "Failed to connect to ChromaDB for save");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
use cxcloud_common::event::Event;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OllamaRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaResponse {
|
||||
response: String,
|
||||
}
|
||||
|
||||
/// Create a plan by sending the event + context to the local LLM (Ollama).
|
||||
pub async fn create_plan(
|
||||
client: &reqwest::Client,
|
||||
ollama_url: &str,
|
||||
event: &Event,
|
||||
context: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "mistral:7b".to_string());
|
||||
|
||||
let prompt = format!(
|
||||
r#"You are an autonomous agent. Analyze this event and create an execution plan.
|
||||
|
||||
Event:
|
||||
- ID: {}
|
||||
- Source: {}
|
||||
- Type: {}
|
||||
- Payload: {}
|
||||
|
||||
Retrieved Context:
|
||||
{}
|
||||
|
||||
Respond with a JSON plan containing:
|
||||
- "goal": high-level description
|
||||
- "reasoning": array of reasoning steps
|
||||
- "tool_calls": array of tool invocations (tool_name, parameters)
|
||||
|
||||
Allowed tools: http_request, file_write, file_read, db_query, notification"#,
|
||||
event.id,
|
||||
event.source,
|
||||
event.r#type,
|
||||
serde_json::to_string_pretty(&event.payload).unwrap_or_default(),
|
||||
if context.is_empty() { "No prior context available." } else { context },
|
||||
);
|
||||
|
||||
info!(event_id = %event.id, model, "Generating plan via LLM");
|
||||
|
||||
let resp = client
|
||||
.post(format!("{ollama_url}/api/generate"))
|
||||
.json(&OllamaRequest {
|
||||
model,
|
||||
prompt,
|
||||
stream: false,
|
||||
})
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Ollama returned {status}: {body}");
|
||||
}
|
||||
|
||||
let ollama_resp: OllamaResponse = resp.json().await?;
|
||||
info!(event_id = %event.id, "Plan generated");
|
||||
|
||||
Ok(ollama_resp.response)
|
||||
}
|
||||
@@ -0,0 +1,70 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OllamaRequest {
|
||||
model: String,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaResponse {
|
||||
response: String,
|
||||
}
|
||||
|
||||
/// Verify execution results by asking the LLM to evaluate success.
|
||||
pub async fn verify(
|
||||
client: &reqwest::Client,
|
||||
ollama_url: &str,
|
||||
plan: &str,
|
||||
execution_result: &Result<String, String>,
|
||||
) -> bool {
|
||||
let model = std::env::var("OLLAMA_MODEL").unwrap_or_else(|_| "mistral:7b".to_string());
|
||||
|
||||
let exec_summary = match execution_result {
|
||||
Ok(result) => format!("Execution succeeded:\n{result}"),
|
||||
Err(err) => format!("Execution failed:\n{err}"),
|
||||
};
|
||||
|
||||
let prompt = format!(
|
||||
r#"You are a verification agent. Determine if this plan was executed successfully.
|
||||
|
||||
Plan:
|
||||
{plan}
|
||||
|
||||
Execution Result:
|
||||
{exec_summary}
|
||||
|
||||
Respond with ONLY "PASS" or "FAIL" followed by a brief reason."#,
|
||||
);
|
||||
|
||||
let resp = client
|
||||
.post(format!("{ollama_url}/api/generate"))
|
||||
.json(&OllamaRequest {
|
||||
model,
|
||||
prompt,
|
||||
stream: false,
|
||||
})
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok(r) if r.status().is_success() => {
|
||||
if let Ok(body) = r.json::<OllamaResponse>().await {
|
||||
let passed = body.response.trim().starts_with("PASS");
|
||||
info!(passed, reason = %body.response.lines().next().unwrap_or(""), "Verification result");
|
||||
return passed;
|
||||
}
|
||||
}
|
||||
Ok(r) => {
|
||||
tracing::warn!(status = %r.status(), "Verification LLM returned error");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(error = %e, "Verification request failed");
|
||||
}
|
||||
}
|
||||
|
||||
// Default to pass if verification is unavailable
|
||||
true
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "927a31cbf65f78c3ef6b729631b2fc35335afe06",
|
||||
"dirty": true
|
||||
},
|
||||
"path_in_vcs": "services/cxcloud-rs/crates/api-gateway"
|
||||
}
|
||||
Generated
+3027
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,100 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2021"
|
||||
name = "cxcloud-api-gateway"
|
||||
version = "0.1.0"
|
||||
build = false
|
||||
publish = ["cxai"]
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
readme = false
|
||||
|
||||
[[bin]]
|
||||
name = "api-gateway"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies.anyhow]
|
||||
version = "1"
|
||||
|
||||
[dependencies.axum]
|
||||
version = "0.7"
|
||||
features = ["macros"]
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxcloud-common]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.cxcloud-proto]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.governor]
|
||||
version = "0.6"
|
||||
|
||||
[dependencies.jsonwebtoken]
|
||||
version = "9"
|
||||
|
||||
[dependencies.redis]
|
||||
version = "0.27"
|
||||
features = [
|
||||
"tokio-comp",
|
||||
"connection-manager",
|
||||
]
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tower]
|
||||
version = "0.4"
|
||||
|
||||
[dependencies.tower-http]
|
||||
version = "0.5"
|
||||
features = [
|
||||
"fs",
|
||||
"cors",
|
||||
"trace",
|
||||
"timeout",
|
||||
]
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v7",
|
||||
"serde",
|
||||
]
|
||||
+27
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "cxcloud-api-gateway"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "api-gateway"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cxcloud-common = { workspace = true }
|
||||
cxcloud-proto = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
governor = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
@@ -0,0 +1,79 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use chrono::Utc;
|
||||
use jsonwebtoken::{encode, EncodingKey, Header};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::warn;
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
iss: String,
|
||||
exp: usize,
|
||||
iat: usize,
|
||||
}
|
||||
|
||||
/// Middleware that validates X-API-Key and injects a short-lived JWT.
|
||||
pub async fn auth_middleware(
|
||||
State(state): State<Arc<AppState>>,
|
||||
mut req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let path = req.uri().path();
|
||||
|
||||
// Skip auth for health/ready/metrics
|
||||
if matches!(path, "/health" | "/ready" | "/metrics") {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
let api_key = extract_api_key(req.headers());
|
||||
match api_key {
|
||||
Some(key) if key == state.config.api_key => {
|
||||
// Mint a short-lived JWT for downstream services
|
||||
let now = Utc::now().timestamp() as usize;
|
||||
let claims = Claims {
|
||||
sub: "api-client".to_string(),
|
||||
iss: state.config.jwt_issuer.clone(),
|
||||
exp: now + state.config.jwt_expiry_seconds as usize,
|
||||
iat: now,
|
||||
};
|
||||
|
||||
if let Ok(token) = encode(
|
||||
&Header::default(),
|
||||
&claims,
|
||||
&EncodingKey::from_secret(state.config.jwt_secret.as_bytes()),
|
||||
) {
|
||||
req.headers_mut().insert(
|
||||
"Authorization",
|
||||
format!("Bearer {token}").parse().unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(next.run(req).await)
|
||||
}
|
||||
Some(_) => {
|
||||
warn!("Invalid API key");
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
None => {
|
||||
warn!("Missing X-API-Key header");
|
||||
Err(StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_api_key(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("x-api-key")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
use cxcloud_common::config::{env_or, env_port, env_u64, ServiceConfig};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GatewayConfig {
|
||||
pub base: ServiceConfig,
|
||||
pub port: u16,
|
||||
pub jwt_secret: String,
|
||||
pub jwt_issuer: String,
|
||||
pub jwt_expiry_seconds: u64,
|
||||
pub api_key: String,
|
||||
pub rate_limit_rps: u32,
|
||||
pub rate_limit_burst: u32,
|
||||
pub upstream_input_gateway: String,
|
||||
pub upstream_agent: String,
|
||||
pub upstream_bot: String,
|
||||
pub upstream_human_simulator: String,
|
||||
pub upstream_output: String,
|
||||
}
|
||||
|
||||
impl GatewayConfig {
|
||||
pub fn from_env() -> Self {
|
||||
let agent_host = env_or("AGENT_HOST", "localhost");
|
||||
let agent_port = env_port("AGENT_HTTP_PORT", 8001);
|
||||
let bot_host = env_or("BOT_HOST", "localhost");
|
||||
let bot_port = env_port("BOT_HTTP_PORT", 8002);
|
||||
let input_host = env_or("INPUT_GATEWAY_HOST", "localhost");
|
||||
let input_port = env_port("INPUT_GATEWAY_PORT", 3001);
|
||||
let human_host = env_or("HUMAN_SIMULATOR_HOST", "localhost");
|
||||
let human_port = env_port("HUMAN_SIMULATOR_HTTP_PORT", 3002);
|
||||
let output_host = env_or("OUTPUT_HOST", "localhost");
|
||||
let output_port = env_port("OUTPUT_HTTP_PORT", 8003);
|
||||
|
||||
Self {
|
||||
base: ServiceConfig::from_env("api-gateway"),
|
||||
port: env_port("API_GATEWAY_PORT", 8080),
|
||||
jwt_secret: env_or("JWT_SECRET", "dev-secret-change-me"),
|
||||
jwt_issuer: env_or("JWT_ISSUER", "cxcloud-api-gateway"),
|
||||
jwt_expiry_seconds: env_u64("JWT_EXPIRY_SECONDS", 300),
|
||||
api_key: env_or("API_KEY_DEV", "dev-test-key"),
|
||||
rate_limit_rps: env_u64("RATE_LIMIT_RPS", 10) as u32,
|
||||
rate_limit_burst: env_u64("RATE_LIMIT_BURST", 20) as u32,
|
||||
upstream_input_gateway: format!("http://{input_host}:{input_port}"),
|
||||
upstream_agent: format!("http://{agent_host}:{agent_port}"),
|
||||
upstream_bot: format!("http://{bot_host}:{bot_port}"),
|
||||
upstream_human_simulator: format!("http://{human_host}:{human_port}"),
|
||||
upstream_output: format!("http://{output_host}:{output_port}"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn upstream_for_prefix(&self, prefix: &str) -> Option<&str> {
|
||||
match prefix {
|
||||
"webhook" => Some(&self.upstream_input_gateway),
|
||||
"agent" => Some(&self.upstream_agent),
|
||||
"bot" => Some(&self.upstream_bot),
|
||||
"policies" => Some(&self.upstream_human_simulator),
|
||||
"deliver" => Some(&self.upstream_output),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
mod auth;
|
||||
mod config;
|
||||
mod proxy;
|
||||
mod ratelimit;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{any, get},
|
||||
Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_common::{config::env_port, health::HealthResponse, telemetry};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub config: config::GatewayConfig,
|
||||
pub redis: redis::aio::ConnectionManager,
|
||||
pub http_client: reqwest::Client,
|
||||
pub rate_limiter: ratelimit::RateLimiter,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let cfg = config::GatewayConfig::from_env();
|
||||
telemetry::init("api-gateway", &cfg.base.otel_endpoint, &cfg.base.log_level);
|
||||
|
||||
let redis_conn = cxcloud_common::redis_streams::connect(&cfg.base.redis_url).await?;
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
let rate_limiter = ratelimit::new(cfg.rate_limit_rps, cfg.rate_limit_burst);
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
config: cfg.clone(),
|
||||
redis: redis_conn,
|
||||
http_client,
|
||||
rate_limiter,
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/ready", get(ready))
|
||||
.route("/metrics", get(metrics))
|
||||
.route("/webhook/{*path}", any(proxy::proxy_handler))
|
||||
.route("/agent/{*path}", any(proxy::proxy_handler))
|
||||
.route("/bot/{*path}", any(proxy::proxy_handler))
|
||||
.route("/policies/{*path}", any(proxy::proxy_handler))
|
||||
.route("/deliver/{*path}", any(proxy::proxy_handler))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
auth::auth_middleware,
|
||||
))
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
ratelimit::rate_limit_middleware,
|
||||
))
|
||||
.with_state(state);
|
||||
|
||||
let port = env_port("API_GATEWAY_PORT", 8080);
|
||||
let addr = format!("0.0.0.0:{port}");
|
||||
info!(port, "API Gateway starting");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
telemetry::shutdown();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health(State(state): State<Arc<AppState>>) -> Json<HealthResponse> {
|
||||
let redis_ok = cxcloud_common::redis_streams::ping(&state.redis).await;
|
||||
let resp = HealthResponse::healthy("api-gateway")
|
||||
.with_dependency("redis", if redis_ok { "healthy" } else { "unhealthy" })
|
||||
.compute_status();
|
||||
Json(resp)
|
||||
}
|
||||
|
||||
async fn ready(State(state): State<Arc<AppState>>) -> StatusCode {
|
||||
if cxcloud_common::redis_streams::ping(&state.redis).await {
|
||||
StatusCode::OK
|
||||
} else {
|
||||
StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
async fn metrics() -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"service": "api-gateway",
|
||||
"requests_total": 0,
|
||||
"auth_failures": 0,
|
||||
"rate_limited": 0,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
/// Reverse proxy handler: routes requests to upstream services based on path prefix.
|
||||
pub async fn proxy_handler(
|
||||
State(state): State<Arc<AppState>>,
|
||||
req: Request<Body>,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let path = req.uri().path();
|
||||
let prefix = path
|
||||
.trim_start_matches('/')
|
||||
.split('/')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
|
||||
let upstream_base = state
|
||||
.config
|
||||
.upstream_for_prefix(prefix)
|
||||
.ok_or(StatusCode::NOT_FOUND)?;
|
||||
|
||||
// Strip the prefix and forward the rest
|
||||
let remaining = path
|
||||
.trim_start_matches('/')
|
||||
.strip_prefix(prefix)
|
||||
.unwrap_or("");
|
||||
|
||||
let upstream_url = format!("{upstream_base}{remaining}");
|
||||
info!(upstream_url, method = %req.method(), "Proxying request");
|
||||
|
||||
let method = req.method().clone();
|
||||
let headers = req.headers().clone();
|
||||
|
||||
// Read request body
|
||||
let body_bytes = axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024)
|
||||
.await
|
||||
.map_err(|_| StatusCode::BAD_REQUEST)?;
|
||||
|
||||
// Build upstream request
|
||||
let mut upstream_req = state.http_client.request(method, &upstream_url);
|
||||
|
||||
// Forward relevant headers
|
||||
for (name, value) in headers.iter() {
|
||||
if !matches!(
|
||||
name.as_str(),
|
||||
"host" | "connection" | "transfer-encoding" | "x-api-key"
|
||||
) {
|
||||
upstream_req = upstream_req.header(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
if !body_bytes.is_empty() {
|
||||
upstream_req = upstream_req.body(body_bytes);
|
||||
}
|
||||
|
||||
// Send to upstream
|
||||
let upstream_resp = upstream_req.send().await.map_err(|e| {
|
||||
error!(error = %e, upstream_url, "Upstream request failed");
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?;
|
||||
|
||||
// Convert upstream response back to axum response
|
||||
let status = StatusCode::from_u16(upstream_resp.status().as_u16())
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
let mut response_builder = Response::builder().status(status);
|
||||
|
||||
for (name, value) in upstream_resp.headers() {
|
||||
response_builder = response_builder.header(name, value);
|
||||
}
|
||||
|
||||
let body = upstream_resp.bytes().await.map_err(|e| {
|
||||
error!(error = %e, "Failed to read upstream response body");
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?;
|
||||
|
||||
response_builder
|
||||
.body(Body::from(body))
|
||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
use std::num::NonZeroU32;
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use governor::{Quota, RateLimiter as GovLimiter, clock::DefaultClock, state::{InMemoryState, NotKeyed}};
|
||||
|
||||
pub type RateLimiter = Arc<GovLimiter<NotKeyed, InMemoryState, DefaultClock>>;
|
||||
|
||||
pub fn new(rps: u32, burst: u32) -> RateLimiter {
|
||||
let quota = Quota::per_second(NonZeroU32::new(rps).unwrap_or(NonZeroU32::new(10).unwrap()))
|
||||
.allow_burst(NonZeroU32::new(burst).unwrap_or(NonZeroU32::new(20).unwrap()));
|
||||
Arc::new(GovLimiter::direct(quota))
|
||||
}
|
||||
|
||||
pub async fn rate_limit_middleware(
|
||||
State(state): State<Arc<crate::AppState>>,
|
||||
req: Request<Body>,
|
||||
next: Next,
|
||||
) -> Result<Response, StatusCode> {
|
||||
let path = req.uri().path();
|
||||
|
||||
// Skip rate limiting for health/ready/metrics
|
||||
if matches!(path, "/health" | "/ready" | "/metrics") {
|
||||
return Ok(next.run(req).await);
|
||||
}
|
||||
|
||||
match state.rate_limiter.check() {
|
||||
Ok(_) => Ok(next.run(req).await),
|
||||
Err(_) => {
|
||||
tracing::warn!("Rate limit exceeded");
|
||||
Err(StatusCode::TOO_MANY_REQUESTS)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "927a31cbf65f78c3ef6b729631b2fc35335afe06",
|
||||
"dirty": true
|
||||
},
|
||||
"path_in_vcs": "services/cxcloud-rs/crates/bot"
|
||||
}
|
||||
Generated
+2767
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,90 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2021"
|
||||
name = "cxcloud-bot"
|
||||
version = "0.1.0"
|
||||
build = false
|
||||
publish = ["cxai"]
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
readme = false
|
||||
|
||||
[[bin]]
|
||||
name = "bot-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies.anyhow]
|
||||
version = "1"
|
||||
|
||||
[dependencies.axum]
|
||||
version = "0.7"
|
||||
features = ["macros"]
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxcloud-common]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.cxcloud-proto]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.prost]
|
||||
version = "0.13"
|
||||
|
||||
[dependencies.prost-types]
|
||||
version = "0.13"
|
||||
|
||||
[dependencies.reqwest]
|
||||
version = "0.12"
|
||||
features = [
|
||||
"json",
|
||||
"rustls-tls",
|
||||
]
|
||||
default-features = false
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tonic]
|
||||
version = "0.12"
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.url]
|
||||
version = "2"
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v7",
|
||||
"serde",
|
||||
]
|
||||
Generated
+27
@@ -0,0 +1,27 @@
|
||||
[package]
|
||||
name = "cxcloud-bot"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "bot-service"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
cxcloud-common = { workspace = true }
|
||||
cxcloud-proto = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
tonic = { workspace = true }
|
||||
prost = { workspace = true }
|
||||
prost-types = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
url = "2"
|
||||
@@ -0,0 +1,107 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Per-tool circuit breaker. Opens after `threshold` consecutive failures.
|
||||
pub struct CircuitBreaker {
|
||||
breakers: Mutex<HashMap<String, BreakerState>>,
|
||||
threshold: u32,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BreakerState {
|
||||
failures: u32,
|
||||
open: bool,
|
||||
}
|
||||
|
||||
impl CircuitBreaker {
|
||||
pub fn new(threshold: u32) -> Self {
|
||||
Self {
|
||||
breakers: Mutex::new(HashMap::new()),
|
||||
threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool's circuit is open (should not be called).
|
||||
pub fn is_open(&self, tool: &str) -> bool {
|
||||
self.breakers
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get(tool)
|
||||
.map(|b| b.open)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Record a success, resetting the failure counter.
|
||||
pub fn record_success(&self, tool: &str) {
|
||||
let mut breakers = self.breakers.lock().unwrap();
|
||||
breakers.insert(
|
||||
tool.to_string(),
|
||||
BreakerState {
|
||||
failures: 0,
|
||||
open: false,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
/// Record a failure. Opens the circuit if threshold is reached.
|
||||
pub fn record_failure(&self, tool: &str) {
|
||||
let mut breakers = self.breakers.lock().unwrap();
|
||||
let state = breakers.entry(tool.to_string()).or_insert(BreakerState {
|
||||
failures: 0,
|
||||
open: false,
|
||||
});
|
||||
state.failures += 1;
|
||||
if state.failures >= self.threshold {
|
||||
state.open = true;
|
||||
tracing::warn!(tool, failures = state.failures, "Circuit breaker opened");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_tool_is_closed() {
|
||||
let cb = CircuitBreaker::new(3);
|
||||
assert!(!cb.is_open("unknown-tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn opens_after_threshold_failures() {
|
||||
let cb = CircuitBreaker::new(3);
|
||||
cb.record_failure("flaky");
|
||||
assert!(!cb.is_open("flaky"));
|
||||
cb.record_failure("flaky");
|
||||
assert!(!cb.is_open("flaky"));
|
||||
cb.record_failure("flaky");
|
||||
assert!(cb.is_open("flaky"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn success_resets_failures() {
|
||||
let cb = CircuitBreaker::new(3);
|
||||
cb.record_failure("tool");
|
||||
cb.record_failure("tool");
|
||||
cb.record_success("tool");
|
||||
cb.record_failure("tool");
|
||||
assert!(!cb.is_open("tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn independent_tools_are_independent() {
|
||||
let cb = CircuitBreaker::new(2);
|
||||
cb.record_failure("a");
|
||||
cb.record_failure("a");
|
||||
assert!(cb.is_open("a"));
|
||||
assert!(!cb.is_open("b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn threshold_of_one_opens_immediately() {
|
||||
let cb = CircuitBreaker::new(1);
|
||||
cb.record_failure("tool");
|
||||
assert!(cb.is_open("tool"));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use tracing::{error, info};
|
||||
|
||||
/// Execute tool calls, respecting dependency order.
|
||||
pub async fn execute_tool_calls(
|
||||
registry: &ToolRegistry,
|
||||
tool_calls: &[serde_json::Value],
|
||||
) -> Vec<serde_json::Value> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for call in tool_calls {
|
||||
let tool_name = call
|
||||
.get("tool_name")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let params = call
|
||||
.get("parameters")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::Value::Null);
|
||||
|
||||
// Check circuit breaker
|
||||
if registry.circuit_breaker.is_open(tool_name) {
|
||||
error!(tool = tool_name, "Circuit breaker open, skipping");
|
||||
results.push(serde_json::json!({
|
||||
"tool_call_id": call.get("id").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"tool_name": tool_name,
|
||||
"success": false,
|
||||
"error_message": "Circuit breaker open",
|
||||
}));
|
||||
continue;
|
||||
}
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let result = registry.execute(tool_name, ¶ms).await;
|
||||
let duration_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
match result {
|
||||
Ok(output) => {
|
||||
registry.circuit_breaker.record_success(tool_name);
|
||||
info!(tool = tool_name, duration_ms, "Tool execution succeeded");
|
||||
results.push(serde_json::json!({
|
||||
"tool_call_id": call.get("id").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"tool_name": tool_name,
|
||||
"success": true,
|
||||
"output": output,
|
||||
"duration_ms": duration_ms,
|
||||
}));
|
||||
}
|
||||
Err(e) => {
|
||||
registry.circuit_breaker.record_failure(tool_name);
|
||||
error!(tool = tool_name, error = %e, "Tool execution failed");
|
||||
results.push(serde_json::json!({
|
||||
"tool_call_id": call.get("id").and_then(|v| v.as_str()).unwrap_or(""),
|
||||
"tool_name": tool_name,
|
||||
"success": false,
|
||||
"error_message": e.to_string(),
|
||||
"duration_ms": duration_ms,
|
||||
}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
@@ -0,0 +1,101 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use tonic::{Request, Response, Status};
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_proto::bot::{
|
||||
bot_service_server::{BotService, BotServiceServer},
|
||||
ExecutionRequest, ExecutionResponse, ListToolsRequest, ListToolsResponse, ToolInfo,
|
||||
ToolResult,
|
||||
};
|
||||
|
||||
use crate::AppState;
|
||||
|
||||
pub struct BotServiceImpl {
|
||||
state: Arc<AppState>,
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl BotService for BotServiceImpl {
|
||||
async fn execute(
|
||||
&self,
|
||||
request: Request<ExecutionRequest>,
|
||||
) -> Result<Response<ExecutionResponse>, Status> {
|
||||
let req = request.into_inner();
|
||||
info!(plan_id = %req.plan_id, tools = req.tool_calls.len(), "Execute request via gRPC");
|
||||
|
||||
let tool_calls: Vec<serde_json::Value> = req
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
serde_json::json!({
|
||||
"id": tc.id,
|
||||
"tool_name": tc.tool_name,
|
||||
"parameters": tc.parameters.as_ref().map(|p| format!("{:?}", p)).unwrap_or_default(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = crate::executor::execute_tool_calls(&self.state.registry, &tool_calls).await;
|
||||
|
||||
let tool_results: Vec<ToolResult> = results
|
||||
.iter()
|
||||
.map(|r| ToolResult {
|
||||
tool_call_id: r.get("tool_call_id").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
tool_name: r.get("tool_name").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
success: r.get("success").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
output: None,
|
||||
error_message: r.get("error_message").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
error_code: String::new(),
|
||||
retries_used: 0,
|
||||
duration_ms: r.get("duration_ms").and_then(|v| v.as_i64()).unwrap_or(0),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let all_success = tool_results.iter().all(|r| r.success);
|
||||
|
||||
Ok(Response::new(ExecutionResponse {
|
||||
execution_id: uuid::Uuid::now_v7().to_string(),
|
||||
results: tool_results,
|
||||
status: if all_success { 1 } else { 2 },
|
||||
started_at: None,
|
||||
completed_at: None,
|
||||
duration_ms: 0,
|
||||
}))
|
||||
}
|
||||
|
||||
async fn list_tools(
|
||||
&self,
|
||||
_request: Request<ListToolsRequest>,
|
||||
) -> Result<Response<ListToolsResponse>, Status> {
|
||||
let tools: Vec<ToolInfo> = self
|
||||
.state
|
||||
.registry
|
||||
.list_tools()
|
||||
.iter()
|
||||
.map(|t| ToolInfo {
|
||||
name: t.get("name").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
description: t.get("description").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
category: t.get("category").and_then(|v| v.as_str()).unwrap_or("").to_string(),
|
||||
available: t.get("available").and_then(|v| v.as_bool()).unwrap_or(false),
|
||||
schema: None,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Response::new(ListToolsResponse { tools }))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn serve(state: Arc<AppState>, port: u16) -> anyhow::Result<()> {
|
||||
let addr = format!("0.0.0.0:{port}").parse()?;
|
||||
let service = BotServiceImpl { state };
|
||||
|
||||
info!(port, "Bot gRPC server starting");
|
||||
|
||||
tonic::transport::Server::builder()
|
||||
.add_service(BotServiceServer::new(service))
|
||||
.serve(addr)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,119 @@
|
||||
mod circuit_breaker;
|
||||
mod executor;
|
||||
mod grpc_server;
|
||||
pub mod tools;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
|
||||
use cxcloud_common::{
|
||||
config::{env_or, env_port},
|
||||
health::HealthResponse,
|
||||
telemetry,
|
||||
};
|
||||
|
||||
use tools::registry::ToolRegistry;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub registry: Arc<ToolRegistry>,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let otel_endpoint = env_or("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317");
|
||||
let log_level = env_or("LOG_LEVEL", "info");
|
||||
telemetry::init("bot-service", &otel_endpoint, &log_level);
|
||||
|
||||
let http_client = reqwest::Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build()?;
|
||||
|
||||
let registry = ToolRegistry::new(http_client.clone());
|
||||
|
||||
let state = Arc::new(AppState {
|
||||
registry: Arc::new(registry),
|
||||
http_client,
|
||||
});
|
||||
|
||||
// Spawn gRPC server
|
||||
let grpc_port = env_port("BOT_GRPC_PORT", 50053);
|
||||
let grpc_state = state.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = grpc_server::serve(grpc_state, grpc_port).await {
|
||||
tracing::error!(error = %e, "gRPC server failed");
|
||||
}
|
||||
});
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health))
|
||||
.route("/ready", get(ready))
|
||||
.route("/metrics", get(metrics))
|
||||
.route("/execute", post(execute))
|
||||
.with_state(state);
|
||||
|
||||
let port = env_port("BOT_HTTP_PORT", 8002);
|
||||
info!(port, grpc_port, "Bot service starting");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}")).await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
telemetry::shutdown();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn health() -> Json<HealthResponse> {
|
||||
Json(HealthResponse::healthy("bot-service"))
|
||||
}
|
||||
|
||||
async fn ready(State(state): State<Arc<AppState>>) -> StatusCode {
|
||||
if state.registry.tool_count() > 0 {
|
||||
StatusCode::OK
|
||||
} else {
|
||||
StatusCode::SERVICE_UNAVAILABLE
|
||||
}
|
||||
}
|
||||
|
||||
async fn metrics(State(state): State<Arc<AppState>>) -> Json<serde_json::Value> {
|
||||
Json(serde_json::json!({
|
||||
"service": "bot-service",
|
||||
"registered_tools": state.registry.tool_count(),
|
||||
"executions": 0,
|
||||
}))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[allow(dead_code)]
|
||||
struct ExecutePayload {
|
||||
plan_id: String,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<serde_json::Value>,
|
||||
#[serde(default)]
|
||||
plan: Option<String>,
|
||||
}
|
||||
|
||||
async fn execute(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(body): Json<ExecutePayload>,
|
||||
) -> Json<serde_json::Value> {
|
||||
let results = executor::execute_tool_calls(&state.registry, &body.tool_calls).await;
|
||||
|
||||
Json(serde_json::json!({
|
||||
"execution_id": uuid::Uuid::now_v7().to_string(),
|
||||
"plan_id": body.plan_id,
|
||||
"results": results,
|
||||
"status": if results.iter().all(|r| r.get("success").and_then(|v| v.as_bool()).unwrap_or(false)) {
|
||||
"SUCCESS"
|
||||
} else {
|
||||
"PARTIAL_FAILURE"
|
||||
},
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,73 @@
|
||||
use anyhow::{bail, Result};
|
||||
use std::path::{Path, PathBuf};
|
||||
use tracing::info;
|
||||
|
||||
const DATA_DIR: &str = "/data";
|
||||
const MAX_READ_SIZE: u64 = 10 * 1024 * 1024; // 10 MB
|
||||
|
||||
/// Read a file from the sandboxed /data directory.
|
||||
pub async fn read(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let path_str = params
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("path parameter is required"))?;
|
||||
|
||||
let path = sanitize_path(path_str)?;
|
||||
info!(path = %path.display(), "Reading file");
|
||||
|
||||
let metadata = tokio::fs::metadata(&path).await?;
|
||||
if metadata.len() > MAX_READ_SIZE {
|
||||
bail!("File too large: {} bytes (max {})", metadata.len(), MAX_READ_SIZE);
|
||||
}
|
||||
|
||||
let content = tokio::fs::read_to_string(&path).await?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"path": path_str,
|
||||
"size": metadata.len(),
|
||||
"content": content,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Write a file to the sandboxed /data directory.
|
||||
pub async fn write(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let path_str = params
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("path parameter is required"))?;
|
||||
|
||||
let content = params
|
||||
.get("content")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("content parameter is required"))?;
|
||||
|
||||
let path = sanitize_path(path_str)?;
|
||||
info!(path = %path.display(), size = content.len(), "Writing file");
|
||||
|
||||
// Create parent directories if needed
|
||||
if let Some(parent) = path.parent() {
|
||||
tokio::fs::create_dir_all(parent).await?;
|
||||
}
|
||||
|
||||
tokio::fs::write(&path, content).await?;
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"path": path_str,
|
||||
"size": content.len(),
|
||||
"written": true,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Sanitize and validate that the path stays within /data.
|
||||
fn sanitize_path(input: &str) -> Result<PathBuf> {
|
||||
let base = Path::new(DATA_DIR);
|
||||
let full = base.join(input.trim_start_matches('/'));
|
||||
let canonical_base = base.to_path_buf();
|
||||
|
||||
// Prevent path traversal
|
||||
if !full.starts_with(&canonical_base) {
|
||||
bail!("Path traversal detected: {input}");
|
||||
}
|
||||
|
||||
Ok(full)
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
use anyhow::{bail, Result};
|
||||
use std::net::IpAddr;
|
||||
use tracing::info;
|
||||
|
||||
const MAX_RESPONSE_SIZE: usize = 1024 * 1024; // 1 MB
|
||||
|
||||
/// Execute an HTTP request tool call with SSRF protection.
|
||||
pub async fn execute(client: &reqwest::Client, params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let url = params
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("url parameter is required"))?;
|
||||
|
||||
let method = params
|
||||
.get("method")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("GET")
|
||||
.to_uppercase();
|
||||
|
||||
// SSRF protection: block private/localhost addresses
|
||||
check_ssrf(url)?;
|
||||
|
||||
info!(method, url, "Executing HTTP request");
|
||||
|
||||
let mut req = match method.as_str() {
|
||||
"GET" => client.get(url),
|
||||
"POST" => client.post(url),
|
||||
"PUT" => client.put(url),
|
||||
"PATCH" => client.patch(url),
|
||||
"DELETE" => client.delete(url),
|
||||
"HEAD" => client.head(url),
|
||||
_ => bail!("Unsupported HTTP method: {method}"),
|
||||
};
|
||||
|
||||
// Add headers
|
||||
if let Some(headers) = params.get("headers").and_then(|v| v.as_object()) {
|
||||
for (key, value) in headers {
|
||||
if let Some(val) = value.as_str() {
|
||||
req = req.header(key.as_str(), val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add body
|
||||
if let Some(body) = params.get("body") {
|
||||
req = req.json(body);
|
||||
}
|
||||
|
||||
let resp = req.send().await?;
|
||||
let status = resp.status().as_u16();
|
||||
let headers: std::collections::HashMap<String, String> = resp
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let body = resp.text().await?;
|
||||
let body_truncated = if body.len() > MAX_RESPONSE_SIZE {
|
||||
format!("{}... [truncated at {} bytes]", &body[..MAX_RESPONSE_SIZE], body.len())
|
||||
} else {
|
||||
body
|
||||
};
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"status": status,
|
||||
"headers": headers,
|
||||
"body": body_truncated,
|
||||
}))
|
||||
}
|
||||
|
||||
/// Check for SSRF: block requests to private/localhost addresses.
|
||||
fn check_ssrf(url: &str) -> Result<()> {
|
||||
let parsed = url::Url::parse(url).map_err(|_| anyhow::anyhow!("Invalid URL: {url}"))?;
|
||||
|
||||
if let Some(host) = parsed.host_str() {
|
||||
// Block localhost variants
|
||||
if host == "localhost" || host == "127.0.0.1" || host == "::1" || host == "0.0.0.0" {
|
||||
bail!("SSRF blocked: requests to localhost are not allowed");
|
||||
}
|
||||
|
||||
// Block private IP ranges
|
||||
if let Ok(ip) = host.parse::<IpAddr>() {
|
||||
match ip {
|
||||
IpAddr::V4(ipv4) => {
|
||||
if ipv4.is_private()
|
||||
|| ipv4.is_loopback()
|
||||
|| ipv4.is_link_local()
|
||||
|| ipv4.octets()[0] == 169 && ipv4.octets()[1] == 254
|
||||
{
|
||||
bail!("SSRF blocked: requests to private IP ranges are not allowed");
|
||||
}
|
||||
}
|
||||
IpAddr::V6(ipv6) => {
|
||||
if ipv6.is_loopback() {
|
||||
bail!("SSRF blocked: requests to loopback addresses are not allowed");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
pub mod file_ops;
|
||||
pub mod http_request;
|
||||
pub mod notification;
|
||||
pub mod registry;
|
||||
pub mod shell;
|
||||
@@ -0,0 +1,24 @@
|
||||
use anyhow::Result;
|
||||
use tracing::info;
|
||||
|
||||
/// Send a notification.
|
||||
pub async fn execute(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let channel = params
|
||||
.get("channel")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("default");
|
||||
|
||||
let message = params
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("message parameter is required"))?;
|
||||
|
||||
info!(channel, message, "Sending notification");
|
||||
|
||||
// In a real implementation, this would dispatch to Slack, email, webhook, etc.
|
||||
Ok(serde_json::json!({
|
||||
"channel": channel,
|
||||
"sent": true,
|
||||
"message": message,
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
use crate::circuit_breaker::CircuitBreaker;
|
||||
use anyhow::Result;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
pub http_client: reqwest::Client,
|
||||
pub circuit_breaker: CircuitBreaker,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new(http_client: reqwest::Client) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
circuit_breaker: CircuitBreaker::new(3),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_count(&self) -> usize {
|
||||
5 // http_request, file_read, file_write, notification, shell_exec
|
||||
}
|
||||
|
||||
pub fn list_tools(&self) -> Vec<serde_json::Value> {
|
||||
vec![
|
||||
serde_json::json!({"name": "http_request", "description": "Make HTTP requests", "category": "http", "available": true}),
|
||||
serde_json::json!({"name": "file_read", "description": "Read files from /data", "category": "file", "available": true}),
|
||||
serde_json::json!({"name": "file_write", "description": "Write files to /data", "category": "file", "available": true}),
|
||||
serde_json::json!({"name": "notification", "description": "Send notifications", "category": "notification", "available": true}),
|
||||
serde_json::json!({"name": "shell_exec", "description": "Execute shell commands", "category": "system", "available": true}),
|
||||
]
|
||||
}
|
||||
|
||||
/// Dispatch execution to the appropriate tool handler.
|
||||
pub async fn execute(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
params: &serde_json::Value,
|
||||
) -> Result<serde_json::Value> {
|
||||
match tool_name {
|
||||
"http_request" => super::http_request::execute(&self.http_client, params).await,
|
||||
"file_read" => super::file_ops::read(params).await,
|
||||
"file_write" => super::file_ops::write(params).await,
|
||||
"notification" => super::notification::execute(params).await,
|
||||
"shell_exec" => super::shell::execute(params).await,
|
||||
_ => anyhow::bail!("Unknown tool: {tool_name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
use anyhow::Result;
|
||||
use tracing::info;
|
||||
|
||||
/// Execute a shell command.
|
||||
pub async fn execute(params: &serde_json::Value) -> Result<serde_json::Value> {
|
||||
let command = params
|
||||
.get("command")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("command parameter is required"))?;
|
||||
|
||||
let timeout_secs = params
|
||||
.get("timeout_seconds")
|
||||
.and_then(|v| v.as_u64())
|
||||
.unwrap_or(30);
|
||||
|
||||
info!(command, timeout_secs, "Executing shell command");
|
||||
|
||||
let output = tokio::time::timeout(
|
||||
std::time::Duration::from_secs(timeout_secs),
|
||||
tokio::process::Command::new("sh")
|
||||
.arg("-c")
|
||||
.arg(command)
|
||||
.output(),
|
||||
)
|
||||
.await
|
||||
.map_err(|_| anyhow::anyhow!("Command timed out after {timeout_secs}s"))??;
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
|
||||
let exit_code = output.status.code().unwrap_or(-1);
|
||||
|
||||
Ok(serde_json::json!({
|
||||
"exit_code": exit_code,
|
||||
"stdout": stdout,
|
||||
"stderr": stderr,
|
||||
"success": output.status.success(),
|
||||
}))
|
||||
}
|
||||
@@ -0,0 +1,7 @@
|
||||
{
|
||||
"git": {
|
||||
"sha1": "927a31cbf65f78c3ef6b729631b2fc35335afe06",
|
||||
"dirty": true
|
||||
},
|
||||
"path_in_vcs": "services/cxcloud-rs/crates/common"
|
||||
}
|
||||
Generated
+2282
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,92 @@
|
||||
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||
#
|
||||
# When uploading crates to the registry Cargo will automatically
|
||||
# "normalize" Cargo.toml files for maximal compatibility
|
||||
# with all versions of Cargo and also rewrite `path` dependencies
|
||||
# to registry (e.g., crates.io) dependencies.
|
||||
#
|
||||
# If you are reading this file be aware that the original Cargo.toml
|
||||
# will likely look very different (and much more reasonable).
|
||||
# See Cargo.toml.orig for the original contents.
|
||||
|
||||
[package]
|
||||
edition = "2021"
|
||||
name = "cxcloud-common"
|
||||
version = "0.1.0"
|
||||
build = false
|
||||
publish = ["cxai"]
|
||||
autolib = false
|
||||
autobins = false
|
||||
autoexamples = false
|
||||
autotests = false
|
||||
autobenches = false
|
||||
readme = false
|
||||
|
||||
[lib]
|
||||
name = "cxcloud_common"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[dependencies.anyhow]
|
||||
version = "1"
|
||||
|
||||
[dependencies.chrono]
|
||||
version = "0.4"
|
||||
features = ["serde"]
|
||||
|
||||
[dependencies.cxcloud-proto]
|
||||
version = "0.1.0"
|
||||
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||
|
||||
[dependencies.deadpool-redis]
|
||||
version = "0.18"
|
||||
|
||||
[dependencies.opentelemetry]
|
||||
version = "0.24"
|
||||
|
||||
[dependencies.opentelemetry-otlp]
|
||||
version = "0.17"
|
||||
|
||||
[dependencies.opentelemetry_sdk]
|
||||
version = "0.24"
|
||||
features = ["rt-tokio"]
|
||||
|
||||
[dependencies.redis]
|
||||
version = "0.27"
|
||||
features = [
|
||||
"tokio-comp",
|
||||
"connection-manager",
|
||||
]
|
||||
|
||||
[dependencies.serde]
|
||||
version = "1"
|
||||
features = ["derive"]
|
||||
|
||||
[dependencies.serde_json]
|
||||
version = "1"
|
||||
|
||||
[dependencies.thiserror]
|
||||
version = "2"
|
||||
|
||||
[dependencies.tokio]
|
||||
version = "1"
|
||||
features = ["full"]
|
||||
|
||||
[dependencies.tracing]
|
||||
version = "0.1"
|
||||
|
||||
[dependencies.tracing-opentelemetry]
|
||||
version = "0.25"
|
||||
|
||||
[dependencies.tracing-subscriber]
|
||||
version = "0.3"
|
||||
features = [
|
||||
"env-filter",
|
||||
"json",
|
||||
]
|
||||
|
||||
[dependencies.uuid]
|
||||
version = "1"
|
||||
features = [
|
||||
"v7",
|
||||
"serde",
|
||||
]
|
||||
Generated
+23
@@ -0,0 +1,23 @@
|
||||
[package]
|
||||
name = "cxcloud-common"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
publish.workspace = true
|
||||
|
||||
[dependencies]
|
||||
cxcloud-proto = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
redis = { workspace = true }
|
||||
deadpool-redis = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
opentelemetry = { workspace = true }
|
||||
opentelemetry_sdk = { workspace = true }
|
||||
opentelemetry-otlp = { workspace = true }
|
||||
tracing-opentelemetry = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
@@ -0,0 +1,127 @@
|
||||
use std::env;
|
||||
|
||||
/// Read an environment variable with a default fallback.
|
||||
pub fn env_or(key: &str, default: &str) -> String {
|
||||
env::var(key).unwrap_or_else(|_| default.to_string())
|
||||
}
|
||||
|
||||
/// Read a required environment variable. Panics if missing.
|
||||
pub fn env_required(key: &str) -> String {
|
||||
env::var(key).unwrap_or_else(|_| panic!("{key} environment variable is required"))
|
||||
}
|
||||
|
||||
/// Read an environment variable as a u16 port number.
|
||||
pub fn env_port(key: &str, default: u16) -> u16 {
|
||||
env::var(key)
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Read an environment variable as a u64.
|
||||
pub fn env_u64(key: &str, default: u64) -> u64 {
|
||||
env::var(key)
|
||||
.ok()
|
||||
.and_then(|v| v.parse().ok())
|
||||
.unwrap_or(default)
|
||||
}
|
||||
|
||||
/// Common service configuration shared by all services.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServiceConfig {
|
||||
pub redis_url: String,
|
||||
pub otel_endpoint: String,
|
||||
pub otel_service_name: String,
|
||||
pub log_level: String,
|
||||
}
|
||||
|
||||
impl ServiceConfig {
|
||||
pub fn from_env(service_name: &str) -> Self {
|
||||
Self {
|
||||
redis_url: env_or("REDIS_URL", "redis://localhost:6379"),
|
||||
otel_endpoint: env_or("OTEL_EXPORTER_OTLP_ENDPOINT", "http://localhost:4317"),
|
||||
otel_service_name: env_or("OTEL_SERVICE_NAME", service_name),
|
||||
log_level: env_or("LOG_LEVEL", "info"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn env_or_returns_default_when_unset() {
|
||||
// Use a key unlikely to be set
|
||||
let val = env_or("CXCLOUD_TEST_UNSET_VAR_12345", "fallback");
|
||||
assert_eq!(val, "fallback");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_or_returns_value_when_set() {
|
||||
env::set_var("CXCLOUD_TEST_ENV_OR", "custom");
|
||||
let val = env_or("CXCLOUD_TEST_ENV_OR", "default");
|
||||
assert_eq!(val, "custom");
|
||||
env::remove_var("CXCLOUD_TEST_ENV_OR");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_port_parses_valid_port() {
|
||||
env::set_var("CXCLOUD_TEST_PORT", "9090");
|
||||
let port = env_port("CXCLOUD_TEST_PORT", 3000);
|
||||
assert_eq!(port, 9090);
|
||||
env::remove_var("CXCLOUD_TEST_PORT");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_port_returns_default_for_invalid() {
|
||||
env::set_var("CXCLOUD_TEST_PORT_BAD", "not-a-number");
|
||||
let port = env_port("CXCLOUD_TEST_PORT_BAD", 3000);
|
||||
assert_eq!(port, 3000);
|
||||
env::remove_var("CXCLOUD_TEST_PORT_BAD");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_port_returns_default_when_unset() {
|
||||
let port = env_port("CXCLOUD_TEST_PORT_UNSET_99999", 8080);
|
||||
assert_eq!(port, 8080);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_u64_parses_valid_value() {
|
||||
env::set_var("CXCLOUD_TEST_U64", "42");
|
||||
let val = env_u64("CXCLOUD_TEST_U64", 10);
|
||||
assert_eq!(val, 42);
|
||||
env::remove_var("CXCLOUD_TEST_U64");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_u64_returns_default_when_unset() {
|
||||
let val = env_u64("CXCLOUD_TEST_U64_UNSET_99999", 100);
|
||||
assert_eq!(val, 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "environment variable is required")]
|
||||
fn env_required_panics_when_missing() {
|
||||
env_required("CXCLOUD_TEST_REQUIRED_UNSET_99999");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_required_returns_value_when_set() {
|
||||
env::set_var("CXCLOUD_TEST_REQUIRED", "present");
|
||||
let val = env_required("CXCLOUD_TEST_REQUIRED");
|
||||
assert_eq!(val, "present");
|
||||
env::remove_var("CXCLOUD_TEST_REQUIRED");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn service_config_from_env_defaults() {
|
||||
let cfg = ServiceConfig::from_env("test-svc");
|
||||
assert!(!cfg.redis_url.is_empty());
|
||||
assert!(!cfg.otel_endpoint.is_empty());
|
||||
assert!(!cfg.log_level.is_empty());
|
||||
// service name should default if OTEL_SERVICE_NAME not set
|
||||
assert!(!cfg.otel_service_name.is_empty());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum CxError {
|
||||
#[error("Redis error: {0}")]
|
||||
Redis(#[from] redis::RedisError),
|
||||
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(#[from] serde_json::Error),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
#[error("Service unavailable: {0}")]
|
||||
Unavailable(String),
|
||||
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
#[error("Unauthorized: {0}")]
|
||||
Unauthorized(String),
|
||||
|
||||
#[error("Rate limited")]
|
||||
RateLimited,
|
||||
|
||||
#[error("Upstream error: {status} {message}")]
|
||||
Upstream { status: u16, message: String },
|
||||
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl CxError {
|
||||
pub fn status_code(&self) -> u16 {
|
||||
match self {
|
||||
Self::Unauthorized(_) => 401,
|
||||
Self::RateLimited => 429,
|
||||
Self::NotFound(_) => 404,
|
||||
Self::Unavailable(_) => 503,
|
||||
Self::Upstream { status, .. } => *status,
|
||||
_ => 500,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn status_codes_are_correct() {
|
||||
assert_eq!(CxError::Unauthorized("x".into()).status_code(), 401);
|
||||
assert_eq!(CxError::RateLimited.status_code(), 429);
|
||||
assert_eq!(CxError::NotFound("x".into()).status_code(), 404);
|
||||
assert_eq!(CxError::Unavailable("x".into()).status_code(), 503);
|
||||
assert_eq!(
|
||||
CxError::Upstream { status: 502, message: "bad".into() }.status_code(),
|
||||
502
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn internal_errors_default_to_500() {
|
||||
assert_eq!(CxError::Internal("oops".into()).status_code(), 500);
|
||||
assert_eq!(CxError::Config("bad config".into()).status_code(), 500);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn error_display_messages() {
|
||||
let e = CxError::NotFound("item-123".into());
|
||||
assert_eq!(e.to_string(), "Not found: item-123");
|
||||
|
||||
let e = CxError::RateLimited;
|
||||
assert_eq!(e.to_string(), "Rate limited");
|
||||
|
||||
let e = CxError::Upstream { status: 503, message: "timeout".into() };
|
||||
assert_eq!(e.to_string(), "Upstream error: 503 timeout");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn redis_error_converts() {
|
||||
let redis_err = redis::RedisError::from((redis::ErrorKind::IoError, "connection refused"));
|
||||
let cx_err = CxError::from(redis_err);
|
||||
assert_eq!(cx_err.status_code(), 500);
|
||||
assert!(cx_err.to_string().contains("Redis error"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn json_error_converts() {
|
||||
let json_err = serde_json::from_str::<serde_json::Value>("{{bad}}").unwrap_err();
|
||||
let cx_err = CxError::from(json_err);
|
||||
assert_eq!(cx_err.status_code(), 500);
|
||||
assert!(cx_err.to_string().contains("Serialization error"));
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user