vendor: update cargo-cxai-ml-0.1.0

This commit is contained in:
cx-git-agent
2026-04-26 16:48:13 +00:00
committed by GitHub
parent 04fe2c4828
commit 60fc0bea5b
8 changed files with 2474 additions and 3 deletions
+6
View File
@@ -0,0 +1,6 @@
{
"git": {
"sha1": "d878deb8441897ecdd416011b49d2d2f6112e867"
},
"path_in_vcs": "crates/cxai-ml"
}
Generated
+1942
View File
File diff suppressed because it is too large Load Diff
+74
View File
@@ -0,0 +1,74 @@
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
#
# When uploading crates to the registry Cargo will automatically
# "normalize" Cargo.toml files for maximal compatibility
# with all versions of Cargo and also rewrite `path` dependencies
# to registry (e.g., crates.io) dependencies.
#
# If you are reading this file be aware that the original Cargo.toml
# will likely look very different (and much more reasonable).
# See Cargo.toml.orig for the original contents.
[package]
edition = "2024"
name = "cxai-ml"
version = "0.1.0"
authors = ["CxAI-LLM <agent@cxai-studio.com>"]
build = false
autolib = false
autobins = false
autoexamples = false
autotests = false
autobenches = false
description = "ML inference client — model catalog, prediction, anomaly detection, regime classification"
homepage = "https://cxllm.io"
readme = false
license = "MIT"
repository = "https://git.cxllm-studio.com/CxAI-LLM/CxAI.Rust"
resolver = "2"
[lib]
name = "cxai_ml"
path = "src/lib.rs"
[dependencies.chrono]
version = "0.4"
features = ["serde"]
[dependencies.cxai-sdk]
version = "0.1.0"
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
[dependencies.reqwest]
version = "0.12"
features = [
"json",
"rustls-tls",
"stream",
]
default-features = false
[dependencies.serde]
version = "1"
features = ["derive"]
[dependencies.serde_json]
version = "1"
[dependencies.thiserror]
version = "2"
[dependencies.tokio]
version = "1"
features = ["full"]
[dependencies.tracing]
version = "0.1"
[dev-dependencies.tokio]
version = "1"
features = [
"full",
"test-util",
"macros",
]
+22
View File
@@ -0,0 +1,22 @@
[package]
name = "cxai-ml"
description = "ML inference client — model catalog, prediction, anomaly detection, regime classification"
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true
homepage.workspace = true
authors.workspace = true
[dependencies]
cxai-sdk = { workspace = true }
tokio = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
chrono = { workspace = true }
[dev-dependencies]
tokio = { workspace = true, features = ["test-util", "macros"] }
-3
View File
@@ -1,3 +0,0 @@
# cargo-cxai-ml-0.1.0
Cargo crate: cxai-ml-0.1.0
+315
View File
@@ -0,0 +1,315 @@
use serde::{Deserialize, Serialize};
/// The 26-model catalog across 8 domains, matching the .NET ModelCatalogService.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelInfo {
pub id: String,
pub name: String,
pub domain: String,
pub model_type: String,
pub description: String,
pub status: ModelStatus,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ModelStatus {
Active,
Training,
Deprecated,
Offline,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DomainInfo {
pub name: String,
pub model_count: usize,
pub description: String,
}
pub struct ModelCatalog;
impl ModelCatalog {
/// Built-in model definitions across all domains.
pub fn all_models() -> Vec<ModelInfo> {
vec![
// Price Forecasting
model(
"dam-xgb",
"DAM XGBoost",
"price_forecasting",
"xgboost",
"Day-ahead market price prediction",
),
model(
"dam-lstm",
"DAM LSTM",
"price_forecasting",
"lstm",
"DAM price via LSTM neural network",
),
model(
"dam-transformer",
"DAM Transformer",
"price_forecasting",
"transformer",
"DAM price via transformer model",
),
model(
"rt-xgb",
"RT XGBoost",
"price_forecasting",
"xgboost",
"Real-time price prediction",
),
model(
"rt-lstm",
"RT LSTM",
"price_forecasting",
"lstm",
"RT price via LSTM",
),
// Demand Forecasting
model(
"demand-prophet",
"Demand Prophet",
"demand_forecasting",
"prophet",
"System demand forecasting",
),
model(
"demand-xgb",
"Demand XGBoost",
"demand_forecasting",
"xgboost",
"Demand via gradient boosting",
),
model(
"demand-nn",
"Demand Neural",
"demand_forecasting",
"neural_net",
"Demand via neural network",
),
// Renewable Forecasting
model(
"wind-xgb",
"Wind XGBoost",
"renewable_forecasting",
"xgboost",
"Wind generation prediction",
),
model(
"wind-lstm",
"Wind LSTM",
"renewable_forecasting",
"lstm",
"Wind via LSTM",
),
model(
"solar-xgb",
"Solar XGBoost",
"renewable_forecasting",
"xgboost",
"Solar generation prediction",
),
model(
"solar-lstm",
"Solar LSTM",
"renewable_forecasting",
"lstm",
"Solar via LSTM",
),
// Anomaly Detection
model(
"anomaly-iforest",
"IsolationForest",
"anomaly_detection",
"isolation_forest",
"Market anomaly detection",
),
model(
"anomaly-autoencoder",
"Autoencoder",
"anomaly_detection",
"autoencoder",
"Deep anomaly detection",
),
// Regime Classification
model(
"regime-rf",
"Regime RF",
"regime_classification",
"random_forest",
"4-class price regime",
),
model(
"regime-xgb",
"Regime XGBoost",
"regime_classification",
"xgboost",
"Regime via gradient boosting",
),
// Battery Optimization
model(
"bess-dispatch",
"BESS Dispatch",
"battery_optimization",
"linear_program",
"Battery dispatch optimization",
),
model(
"bess-degradation",
"BESS Degradation",
"battery_optimization",
"regression",
"Battery degradation model",
),
model(
"bess-revenue",
"BESS Revenue",
"battery_optimization",
"mixed_integer",
"Revenue maximization",
),
// Congestion Analysis
model(
"congestion-rf",
"Congestion RF",
"congestion_analysis",
"random_forest",
"Congestion prediction",
),
model(
"congestion-gnn",
"Congestion GNN",
"congestion_analysis",
"graph_neural_net",
"Grid topology model",
),
// Risk Analysis
model(
"var-historical",
"VaR Historical",
"risk_analysis",
"statistical",
"Historical value-at-risk",
),
model(
"var-monte-carlo",
"VaR Monte Carlo",
"risk_analysis",
"simulation",
"Monte Carlo VaR",
),
model(
"cvar-parametric",
"CVaR Parametric",
"risk_analysis",
"statistical",
"Conditional VaR",
),
model(
"volatility-garch",
"GARCH Volatility",
"risk_analysis",
"garch",
"Price volatility model",
),
model(
"correlation-dcc",
"DCC Correlation",
"risk_analysis",
"dcc_garch",
"Dynamic conditional correlation",
),
]
}
/// List all 8 model domains.
pub fn domains() -> Vec<DomainInfo> {
vec![
DomainInfo {
name: "price_forecasting".into(),
model_count: 5,
description: "DAM and RT price prediction".into(),
},
DomainInfo {
name: "demand_forecasting".into(),
model_count: 3,
description: "System demand forecasting".into(),
},
DomainInfo {
name: "renewable_forecasting".into(),
model_count: 4,
description: "Wind and solar generation".into(),
},
DomainInfo {
name: "anomaly_detection".into(),
model_count: 2,
description: "Market anomaly detection".into(),
},
DomainInfo {
name: "regime_classification".into(),
model_count: 2,
description: "Price regime classification".into(),
},
DomainInfo {
name: "battery_optimization".into(),
model_count: 3,
description: "BESS dispatch and revenue".into(),
},
DomainInfo {
name: "congestion_analysis".into(),
model_count: 2,
description: "Grid congestion prediction".into(),
},
DomainInfo {
name: "risk_analysis".into(),
model_count: 5,
description: "VaR, CVaR, and volatility".into(),
},
]
}
/// Get models for a specific domain.
pub fn by_domain(domain: &str) -> Vec<ModelInfo> {
Self::all_models()
.into_iter()
.filter(|m| m.domain == domain)
.collect()
}
}
fn model(id: &str, name: &str, domain: &str, model_type: &str, desc: &str) -> ModelInfo {
ModelInfo {
id: id.into(),
name: name.into(),
domain: domain.into(),
model_type: model_type.into(),
description: desc.into(),
status: ModelStatus::Active,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn catalog_has_26_models() {
assert_eq!(ModelCatalog::all_models().len(), 26);
}
#[test]
fn catalog_has_8_domains() {
assert_eq!(ModelCatalog::domains().len(), 8);
}
#[test]
fn filter_by_domain() {
let price_models = ModelCatalog::by_domain("price_forecasting");
assert_eq!(price_models.len(), 5);
}
}
+110
View File
@@ -0,0 +1,110 @@
use serde::{Deserialize, Serialize};
use tracing::debug;
/// High-level ML inference client wrapping the platform ML service.
pub struct MlInferenceClient {
platform: cxai_sdk::CxPlatform,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PredictionRequest {
pub model: String,
pub features: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PredictionResult {
pub model: String,
pub prediction: serde_json::Value,
pub confidence: Option<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AnomalyResult {
pub is_anomaly: bool,
pub score: f64,
pub features: serde_json::Value,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RegimeResult {
pub regime: String,
pub class_id: u8,
pub probabilities: Vec<f64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ForecastResult {
pub settlement_point: String,
pub hours_ahead: u32,
pub predictions: Vec<HourlyForecast>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HourlyForecast {
pub hour: u8,
pub price: f64,
pub lower_bound: f64,
pub upper_bound: f64,
}
impl MlInferenceClient {
pub fn new(platform: cxai_sdk::CxPlatform) -> Self {
Self { platform }
}
/// Single model prediction.
pub async fn predict(
&self,
request: &PredictionRequest,
) -> cxai_sdk::CxResult<PredictionResult> {
debug!(model = %request.model, "ML predict");
let raw = self
.platform
.ml()
.predict(&request.model, request.features.clone())
.await?;
Ok(serde_json::from_value(raw)?)
}
/// Detect anomalies in feature data.
pub async fn detect_anomalies(
&self,
features: serde_json::Value,
) -> cxai_sdk::CxResult<AnomalyResult> {
debug!("ML detect anomalies");
let raw = self.platform.ml().detect_anomalies(features).await?;
Ok(serde_json::from_value(raw)?)
}
/// Classify price regime.
pub async fn classify_regime(
&self,
features: serde_json::Value,
) -> cxai_sdk::CxResult<RegimeResult> {
debug!("ML classify regime");
let raw = self.platform.ml().classify_regime(features).await?;
Ok(serde_json::from_value(raw)?)
}
/// Price forecast for a settlement point.
pub async fn price_forecast(
&self,
settlement_point: &str,
hours_ahead: u32,
) -> cxai_sdk::CxResult<ForecastResult> {
debug!(settlement_point, hours_ahead, "ML price forecast");
let raw = self
.platform
.ml()
.price_forecast(settlement_point, hours_ahead)
.await?;
Ok(serde_json::from_value(raw)?)
}
}
+5
View File
@@ -0,0 +1,5 @@
pub mod catalog;
pub mod inference;
pub use catalog::ModelCatalog;
pub use inference::MlInferenceClient;