vendor: update cargo-cxai-ml-0.1.0
This commit is contained in:
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"git": {
|
||||||
|
"sha1": "d878deb8441897ecdd416011b49d2d2f6112e867"
|
||||||
|
},
|
||||||
|
"path_in_vcs": "crates/cxai-ml"
|
||||||
|
}
|
||||||
Generated
+1942
File diff suppressed because it is too large
Load Diff
+74
@@ -0,0 +1,74 @@
|
|||||||
|
# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
|
||||||
|
#
|
||||||
|
# When uploading crates to the registry Cargo will automatically
|
||||||
|
# "normalize" Cargo.toml files for maximal compatibility
|
||||||
|
# with all versions of Cargo and also rewrite `path` dependencies
|
||||||
|
# to registry (e.g., crates.io) dependencies.
|
||||||
|
#
|
||||||
|
# If you are reading this file be aware that the original Cargo.toml
|
||||||
|
# will likely look very different (and much more reasonable).
|
||||||
|
# See Cargo.toml.orig for the original contents.
|
||||||
|
|
||||||
|
[package]
|
||||||
|
edition = "2024"
|
||||||
|
name = "cxai-ml"
|
||||||
|
version = "0.1.0"
|
||||||
|
authors = ["CxAI-LLM <agent@cxai-studio.com>"]
|
||||||
|
build = false
|
||||||
|
autolib = false
|
||||||
|
autobins = false
|
||||||
|
autoexamples = false
|
||||||
|
autotests = false
|
||||||
|
autobenches = false
|
||||||
|
description = "ML inference client — model catalog, prediction, anomaly detection, regime classification"
|
||||||
|
homepage = "https://cxllm.io"
|
||||||
|
readme = false
|
||||||
|
license = "MIT"
|
||||||
|
repository = "https://git.cxllm-studio.com/CxAI-LLM/CxAI.Rust"
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "cxai_ml"
|
||||||
|
path = "src/lib.rs"
|
||||||
|
|
||||||
|
[dependencies.chrono]
|
||||||
|
version = "0.4"
|
||||||
|
features = ["serde"]
|
||||||
|
|
||||||
|
[dependencies.cxai-sdk]
|
||||||
|
version = "0.1.0"
|
||||||
|
registry-index = "sparse+https://git.cxllm-studio.com/api/packages/CxAI-LLM/cargo/"
|
||||||
|
|
||||||
|
[dependencies.reqwest]
|
||||||
|
version = "0.12"
|
||||||
|
features = [
|
||||||
|
"json",
|
||||||
|
"rustls-tls",
|
||||||
|
"stream",
|
||||||
|
]
|
||||||
|
default-features = false
|
||||||
|
|
||||||
|
[dependencies.serde]
|
||||||
|
version = "1"
|
||||||
|
features = ["derive"]
|
||||||
|
|
||||||
|
[dependencies.serde_json]
|
||||||
|
version = "1"
|
||||||
|
|
||||||
|
[dependencies.thiserror]
|
||||||
|
version = "2"
|
||||||
|
|
||||||
|
[dependencies.tokio]
|
||||||
|
version = "1"
|
||||||
|
features = ["full"]
|
||||||
|
|
||||||
|
[dependencies.tracing]
|
||||||
|
version = "0.1"
|
||||||
|
|
||||||
|
[dev-dependencies.tokio]
|
||||||
|
version = "1"
|
||||||
|
features = [
|
||||||
|
"full",
|
||||||
|
"test-util",
|
||||||
|
"macros",
|
||||||
|
]
|
||||||
Generated
+22
@@ -0,0 +1,22 @@
|
|||||||
|
[package]
|
||||||
|
name = "cxai-ml"
|
||||||
|
description = "ML inference client — model catalog, prediction, anomaly detection, regime classification"
|
||||||
|
version.workspace = true
|
||||||
|
edition.workspace = true
|
||||||
|
license.workspace = true
|
||||||
|
repository.workspace = true
|
||||||
|
homepage.workspace = true
|
||||||
|
authors.workspace = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
cxai-sdk = { workspace = true }
|
||||||
|
tokio = { workspace = true }
|
||||||
|
reqwest = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
|
serde_json = { workspace = true }
|
||||||
|
thiserror = { workspace = true }
|
||||||
|
tracing = { workspace = true }
|
||||||
|
chrono = { workspace = true }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
tokio = { workspace = true, features = ["test-util", "macros"] }
|
||||||
+315
@@ -0,0 +1,315 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// The 26-model catalog across 8 domains, matching the .NET ModelCatalogService.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ModelInfo {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub domain: String,
|
||||||
|
pub model_type: String,
|
||||||
|
pub description: String,
|
||||||
|
pub status: ModelStatus,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum ModelStatus {
|
||||||
|
Active,
|
||||||
|
Training,
|
||||||
|
Deprecated,
|
||||||
|
Offline,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DomainInfo {
|
||||||
|
pub name: String,
|
||||||
|
pub model_count: usize,
|
||||||
|
pub description: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct ModelCatalog;
|
||||||
|
|
||||||
|
impl ModelCatalog {
|
||||||
|
/// Built-in model definitions across all domains.
|
||||||
|
pub fn all_models() -> Vec<ModelInfo> {
|
||||||
|
vec![
|
||||||
|
// Price Forecasting
|
||||||
|
model(
|
||||||
|
"dam-xgb",
|
||||||
|
"DAM XGBoost",
|
||||||
|
"price_forecasting",
|
||||||
|
"xgboost",
|
||||||
|
"Day-ahead market price prediction",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"dam-lstm",
|
||||||
|
"DAM LSTM",
|
||||||
|
"price_forecasting",
|
||||||
|
"lstm",
|
||||||
|
"DAM price via LSTM neural network",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"dam-transformer",
|
||||||
|
"DAM Transformer",
|
||||||
|
"price_forecasting",
|
||||||
|
"transformer",
|
||||||
|
"DAM price via transformer model",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"rt-xgb",
|
||||||
|
"RT XGBoost",
|
||||||
|
"price_forecasting",
|
||||||
|
"xgboost",
|
||||||
|
"Real-time price prediction",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"rt-lstm",
|
||||||
|
"RT LSTM",
|
||||||
|
"price_forecasting",
|
||||||
|
"lstm",
|
||||||
|
"RT price via LSTM",
|
||||||
|
),
|
||||||
|
// Demand Forecasting
|
||||||
|
model(
|
||||||
|
"demand-prophet",
|
||||||
|
"Demand Prophet",
|
||||||
|
"demand_forecasting",
|
||||||
|
"prophet",
|
||||||
|
"System demand forecasting",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"demand-xgb",
|
||||||
|
"Demand XGBoost",
|
||||||
|
"demand_forecasting",
|
||||||
|
"xgboost",
|
||||||
|
"Demand via gradient boosting",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"demand-nn",
|
||||||
|
"Demand Neural",
|
||||||
|
"demand_forecasting",
|
||||||
|
"neural_net",
|
||||||
|
"Demand via neural network",
|
||||||
|
),
|
||||||
|
// Renewable Forecasting
|
||||||
|
model(
|
||||||
|
"wind-xgb",
|
||||||
|
"Wind XGBoost",
|
||||||
|
"renewable_forecasting",
|
||||||
|
"xgboost",
|
||||||
|
"Wind generation prediction",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"wind-lstm",
|
||||||
|
"Wind LSTM",
|
||||||
|
"renewable_forecasting",
|
||||||
|
"lstm",
|
||||||
|
"Wind via LSTM",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"solar-xgb",
|
||||||
|
"Solar XGBoost",
|
||||||
|
"renewable_forecasting",
|
||||||
|
"xgboost",
|
||||||
|
"Solar generation prediction",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"solar-lstm",
|
||||||
|
"Solar LSTM",
|
||||||
|
"renewable_forecasting",
|
||||||
|
"lstm",
|
||||||
|
"Solar via LSTM",
|
||||||
|
),
|
||||||
|
// Anomaly Detection
|
||||||
|
model(
|
||||||
|
"anomaly-iforest",
|
||||||
|
"IsolationForest",
|
||||||
|
"anomaly_detection",
|
||||||
|
"isolation_forest",
|
||||||
|
"Market anomaly detection",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"anomaly-autoencoder",
|
||||||
|
"Autoencoder",
|
||||||
|
"anomaly_detection",
|
||||||
|
"autoencoder",
|
||||||
|
"Deep anomaly detection",
|
||||||
|
),
|
||||||
|
// Regime Classification
|
||||||
|
model(
|
||||||
|
"regime-rf",
|
||||||
|
"Regime RF",
|
||||||
|
"regime_classification",
|
||||||
|
"random_forest",
|
||||||
|
"4-class price regime",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"regime-xgb",
|
||||||
|
"Regime XGBoost",
|
||||||
|
"regime_classification",
|
||||||
|
"xgboost",
|
||||||
|
"Regime via gradient boosting",
|
||||||
|
),
|
||||||
|
// Battery Optimization
|
||||||
|
model(
|
||||||
|
"bess-dispatch",
|
||||||
|
"BESS Dispatch",
|
||||||
|
"battery_optimization",
|
||||||
|
"linear_program",
|
||||||
|
"Battery dispatch optimization",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"bess-degradation",
|
||||||
|
"BESS Degradation",
|
||||||
|
"battery_optimization",
|
||||||
|
"regression",
|
||||||
|
"Battery degradation model",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"bess-revenue",
|
||||||
|
"BESS Revenue",
|
||||||
|
"battery_optimization",
|
||||||
|
"mixed_integer",
|
||||||
|
"Revenue maximization",
|
||||||
|
),
|
||||||
|
// Congestion Analysis
|
||||||
|
model(
|
||||||
|
"congestion-rf",
|
||||||
|
"Congestion RF",
|
||||||
|
"congestion_analysis",
|
||||||
|
"random_forest",
|
||||||
|
"Congestion prediction",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"congestion-gnn",
|
||||||
|
"Congestion GNN",
|
||||||
|
"congestion_analysis",
|
||||||
|
"graph_neural_net",
|
||||||
|
"Grid topology model",
|
||||||
|
),
|
||||||
|
// Risk Analysis
|
||||||
|
model(
|
||||||
|
"var-historical",
|
||||||
|
"VaR Historical",
|
||||||
|
"risk_analysis",
|
||||||
|
"statistical",
|
||||||
|
"Historical value-at-risk",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"var-monte-carlo",
|
||||||
|
"VaR Monte Carlo",
|
||||||
|
"risk_analysis",
|
||||||
|
"simulation",
|
||||||
|
"Monte Carlo VaR",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"cvar-parametric",
|
||||||
|
"CVaR Parametric",
|
||||||
|
"risk_analysis",
|
||||||
|
"statistical",
|
||||||
|
"Conditional VaR",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"volatility-garch",
|
||||||
|
"GARCH Volatility",
|
||||||
|
"risk_analysis",
|
||||||
|
"garch",
|
||||||
|
"Price volatility model",
|
||||||
|
),
|
||||||
|
model(
|
||||||
|
"correlation-dcc",
|
||||||
|
"DCC Correlation",
|
||||||
|
"risk_analysis",
|
||||||
|
"dcc_garch",
|
||||||
|
"Dynamic conditional correlation",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List all 8 model domains.
|
||||||
|
pub fn domains() -> Vec<DomainInfo> {
|
||||||
|
vec![
|
||||||
|
DomainInfo {
|
||||||
|
name: "price_forecasting".into(),
|
||||||
|
model_count: 5,
|
||||||
|
description: "DAM and RT price prediction".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "demand_forecasting".into(),
|
||||||
|
model_count: 3,
|
||||||
|
description: "System demand forecasting".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "renewable_forecasting".into(),
|
||||||
|
model_count: 4,
|
||||||
|
description: "Wind and solar generation".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "anomaly_detection".into(),
|
||||||
|
model_count: 2,
|
||||||
|
description: "Market anomaly detection".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "regime_classification".into(),
|
||||||
|
model_count: 2,
|
||||||
|
description: "Price regime classification".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "battery_optimization".into(),
|
||||||
|
model_count: 3,
|
||||||
|
description: "BESS dispatch and revenue".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "congestion_analysis".into(),
|
||||||
|
model_count: 2,
|
||||||
|
description: "Grid congestion prediction".into(),
|
||||||
|
},
|
||||||
|
DomainInfo {
|
||||||
|
name: "risk_analysis".into(),
|
||||||
|
model_count: 5,
|
||||||
|
description: "VaR, CVaR, and volatility".into(),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get models for a specific domain.
|
||||||
|
pub fn by_domain(domain: &str) -> Vec<ModelInfo> {
|
||||||
|
Self::all_models()
|
||||||
|
.into_iter()
|
||||||
|
.filter(|m| m.domain == domain)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(id: &str, name: &str, domain: &str, model_type: &str, desc: &str) -> ModelInfo {
|
||||||
|
ModelInfo {
|
||||||
|
id: id.into(),
|
||||||
|
name: name.into(),
|
||||||
|
domain: domain.into(),
|
||||||
|
model_type: model_type.into(),
|
||||||
|
description: desc.into(),
|
||||||
|
status: ModelStatus::Active,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn catalog_has_26_models() {
|
||||||
|
assert_eq!(ModelCatalog::all_models().len(), 26);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn catalog_has_8_domains() {
|
||||||
|
assert_eq!(ModelCatalog::domains().len(), 8);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn filter_by_domain() {
|
||||||
|
let price_models = ModelCatalog::by_domain("price_forecasting");
|
||||||
|
assert_eq!(price_models.len(), 5);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,110 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
/// High-level ML inference client wrapping the platform ML service.
|
||||||
|
pub struct MlInferenceClient {
|
||||||
|
platform: cxai_sdk::CxPlatform,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PredictionRequest {
|
||||||
|
pub model: String,
|
||||||
|
pub features: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PredictionResult {
|
||||||
|
pub model: String,
|
||||||
|
pub prediction: serde_json::Value,
|
||||||
|
pub confidence: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct AnomalyResult {
|
||||||
|
pub is_anomaly: bool,
|
||||||
|
pub score: f64,
|
||||||
|
pub features: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct RegimeResult {
|
||||||
|
pub regime: String,
|
||||||
|
pub class_id: u8,
|
||||||
|
pub probabilities: Vec<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ForecastResult {
|
||||||
|
pub settlement_point: String,
|
||||||
|
pub hours_ahead: u32,
|
||||||
|
pub predictions: Vec<HourlyForecast>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct HourlyForecast {
|
||||||
|
pub hour: u8,
|
||||||
|
pub price: f64,
|
||||||
|
pub lower_bound: f64,
|
||||||
|
pub upper_bound: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MlInferenceClient {
|
||||||
|
pub fn new(platform: cxai_sdk::CxPlatform) -> Self {
|
||||||
|
Self { platform }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Single model prediction.
|
||||||
|
pub async fn predict(
|
||||||
|
&self,
|
||||||
|
request: &PredictionRequest,
|
||||||
|
) -> cxai_sdk::CxResult<PredictionResult> {
|
||||||
|
debug!(model = %request.model, "ML predict");
|
||||||
|
let raw = self
|
||||||
|
.platform
|
||||||
|
.ml()
|
||||||
|
.predict(&request.model, request.features.clone())
|
||||||
|
.await?;
|
||||||
|
Ok(serde_json::from_value(raw)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect anomalies in feature data.
|
||||||
|
pub async fn detect_anomalies(
|
||||||
|
&self,
|
||||||
|
features: serde_json::Value,
|
||||||
|
) -> cxai_sdk::CxResult<AnomalyResult> {
|
||||||
|
debug!("ML detect anomalies");
|
||||||
|
let raw = self.platform.ml().detect_anomalies(features).await?;
|
||||||
|
Ok(serde_json::from_value(raw)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Classify price regime.
|
||||||
|
pub async fn classify_regime(
|
||||||
|
&self,
|
||||||
|
features: serde_json::Value,
|
||||||
|
) -> cxai_sdk::CxResult<RegimeResult> {
|
||||||
|
debug!("ML classify regime");
|
||||||
|
let raw = self.platform.ml().classify_regime(features).await?;
|
||||||
|
Ok(serde_json::from_value(raw)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Price forecast for a settlement point.
|
||||||
|
pub async fn price_forecast(
|
||||||
|
&self,
|
||||||
|
settlement_point: &str,
|
||||||
|
hours_ahead: u32,
|
||||||
|
) -> cxai_sdk::CxResult<ForecastResult> {
|
||||||
|
debug!(settlement_point, hours_ahead, "ML price forecast");
|
||||||
|
let raw = self
|
||||||
|
.platform
|
||||||
|
.ml()
|
||||||
|
.price_forecast(settlement_point, hours_ahead)
|
||||||
|
.await?;
|
||||||
|
Ok(serde_json::from_value(raw)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
pub mod catalog;
|
||||||
|
pub mod inference;
|
||||||
|
|
||||||
|
pub use catalog::ModelCatalog;
|
||||||
|
pub use inference::MlInferenceClient;
|
||||||
Reference in New Issue
Block a user