1use std::{collections::HashMap, sync::Arc};
75
76use serde::{Deserialize, Serialize};
77use tokio::sync::RwLock;
78
79use crate::{AirError, Result, dev_log};
80
/// In-memory store of trace spans plus the sampling configuration that
/// governs them.
///
/// Both fields are wrapped in `Arc<RwLock<..>>`, so cloning a generator yields
/// a handle onto the same span map and the same configuration.
#[derive(Debug, Clone)]
pub struct TraceGenerator {
    /// Recorded spans, keyed by span ID.
    trace_spans:Arc<RwLock<HashMap<String, TraceSpan>>>,

    /// Sampling rates and retention limits applied by this generator.
    sampling_config:Arc<RwLock<SamplingConfig>>,
}
88
/// Tunable sampling and retention parameters for a [`TraceGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingConfig {
    /// Probability used by `should_sample` for non-critical operations;
    /// `validate` requires it to lie in `0.0..=1.0`.
    pub sample_rate:f64,

    /// Probability used by `should_sample` for critical operations;
    /// `validate` requires it to lie in `0.0..=1.0`.
    pub critical_sample_rate:f64,

    /// Upper bound on stored spans that share one trace ID; `create_span`
    /// rejects new spans once it is reached.
    pub max_spans_per_trace:usize,

    /// Default age threshold used by `cleanup_old_spans` when the caller
    /// passes no explicit value.
    pub trace_ttl_ms:u64,
}
104
105impl Default for SamplingConfig {
106 fn default() -> Self {
107 Self {
108 sample_rate:0.1, critical_sample_rate:1.0, max_spans_per_trace:1000,
111
112 trace_ttl_ms:3600000, }
114 }
115}
116
117impl SamplingConfig {
118 pub fn validate(&self) -> Result<()> {
120 if self.sample_rate < 0.0 || self.sample_rate > 1.0 {
121 return Err(crate::AirError::Internal("sample_rate must be between 0.0 and 1.0".to_string()));
122 }
123
124 if self.critical_sample_rate < 0.0 || self.critical_sample_rate > 1.0 {
125 return Err(crate::AirError::Internal(
126 "critical_sample_rate must be between 0.0 and 1.0".to_string(),
127 ));
128 }
129
130 if self.max_spans_per_trace == 0 {
131 return Err(crate::AirError::Internal(
132 "max_spans_per_trace must be greater than 0".to_string(),
133 ));
134 }
135
136 if self.trace_ttl_ms == 0 {
137 return Err(crate::AirError::Internal("trace_ttl_ms must be greater than 0".to_string()));
138 }
139
140 Ok(())
141 }
142}
143
/// One recorded unit of work inside a trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceSpan {
    /// Identifier of this span; also the key under which the generator
    /// stores it.
    pub span_id:String,

    /// Identifier of the trace this span belongs to.
    pub trace_id:String,

    /// Identifier of the parent span; `None` marks a root span.
    pub parent_span_id:Option<String>,

    /// Name of the operation the span describes.
    pub operation_name:String,

    /// Value of `crate::Utility::CurrentTimestamp()` when the span was
    /// created (presumably milliseconds, since `duration_ms` is derived from
    /// it by subtraction — confirm against `Utility`).
    pub start_time:u64,

    /// Timestamp recorded when the span was completed; `None` until then.
    pub end_time:Option<u64>,

    /// Current lifecycle state.
    pub status:SpanStatus,

    /// Key/value attributes attached to the span.
    pub attributes:HashMap<String, String>,

    /// Events appended to the span, in insertion order.
    pub events:Vec<SpanEvent>,

    /// Sanitized error message recorded when the span completed with an
    /// error.
    pub error:Option<String>,

    /// `end_time - start_time` (saturating), set on completion.
    pub duration_ms:Option<u64>,
}
169
/// Lifecycle state of a [`TraceSpan`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum SpanStatus {
    /// Initial state assigned when the span is created.
    Started,

    /// Set by `mark_span_active`.
    Active,

    /// Set by `complete_span` when no error is supplied.
    Completed,

    /// Set by `complete_span` when an error is supplied.
    Failed,

    /// The span was cancelled. Nothing in this file assigns this state.
    Cancelled,
}
183
/// A timestamped, named event attached to a span.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpanEvent {
    /// Value of `crate::Utility::CurrentTimestamp()` when the event was
    /// recorded.
    pub timestamp:u64,

    /// Event name.
    pub name:String,

    /// Key/value attributes describing the event.
    pub attributes:HashMap<String, String>,
}
193
/// Summary of a whole trace, derived from its stored spans.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceMetadata {
    /// Identifier of the summarised trace.
    pub trace_id:String,

    /// Identifier of the span chosen as the trace root (a span without a
    /// parent).
    pub root_span_id:String,

    /// Number of stored spans belonging to the trace.
    pub total_spans:usize,

    /// Operation name of the root span.
    pub root_operation:String,

    /// Start time of the root span.
    pub start_time:u64,

    /// Latest `end_time` among the trace's spans, if any span has one.
    pub end_time:Option<u64>,

    /// Largest single-span `duration_ms` within the trace, if any span has
    /// one.
    pub total_duration_ms:Option<u64>,

    /// Aggregated status of the trace.
    pub status:TraceStatus,
}
213
/// Aggregated state of a trace, as reported in [`TraceMetadata`].
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TraceStatus {
    /// At least one span has not reached a final state.
    InProgress,

    /// The trace finished without a failed span.
    Completed,

    /// At least one span failed.
    Failed,

    /// The trace was cancelled.
    Cancelled,
}
225
/// Identifiers carried along with a request so that work can be tied back to
/// its trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PropagationContext {
    /// Identifier of the trace being propagated.
    pub TraceId:String,

    /// Span identifier generated for this context.
    pub SpanId:String,

    /// Request identifier obtained from `crate::Utility::GenerateRequestId()`.
    pub CorrelationId:String,

    /// Identifier of the upstream span, when one is known.
    pub ParentSpanId:Option<String>,
}
237
238impl TraceGenerator {
239 pub fn new() -> Self {
241 Self {
242 trace_spans:Arc::new(RwLock::new(HashMap::new())),
243
244 sampling_config:Arc::new(RwLock::new(SamplingConfig::default())),
245 }
246 }
247
248 pub fn with_sampling(sampling_config:SamplingConfig) -> Result<Self> {
250 sampling_config
251 .validate()
252 .map_err(|e| AirError::Internal(format!("Invalid sampling config: {}", e)))?;
253
254 Ok(Self {
255 trace_spans:Arc::new(RwLock::new(HashMap::new())),
256 sampling_config:Arc::new(RwLock::new(sampling_config)),
257 })
258 }
259
260 pub fn generate_trace_id() -> String {
262 std::panic::catch_unwind(|| uuid::Uuid::new_v4().to_string()).unwrap_or_else(|e| {
263 dev_log!("air", "error: [Tracing] Panic in generate_trace_id, using fallback: {:?}", e);
264 format!("{:x}", rand::random::<u64>())
265 })
266 }
267
268 pub fn generate_span_id() -> String {
270 std::panic::catch_unwind(|| uuid::Uuid::new_v4().to_string()).unwrap_or_else(|e| {
271 dev_log!("air", "error: [Tracing] Panic in generate_span_id, using fallback: {:?}", e);
272 format!("{:x}", rand::random::<u64>())
273 })
274 }
275
276 pub async fn should_sample(&self, is_critical:bool) -> bool {
278 let config = self.sampling_config.read().await;
279
280 let rate = if is_critical { config.critical_sample_rate } else { config.sample_rate };
281
282 rand::random::<f64>() < rate
283 }
284
285 pub fn parse_trace_context(header:&str) -> Result<PropagationContext> {
287 let parts:Vec<&str> = header.split(';').collect();
288
289 let mut trace_id = String::new();
290
291 let mut parent_span_id = None;
292
293 for part in parts {
294 let kv:Vec<&str> = part.split('=').collect();
295
296 if kv.len() != 2 {
297 continue;
298 }
299
300 match kv[0].trim() {
301 "traceparent" => {
302 let trace_parent:Vec<&str> = kv[1].trim().split('-').collect();
303
304 if trace_parent.len() >= 2 {
305 trace_id = trace_parent[1].to_string();
306
307 if trace_parent.len() >= 3 {
308 parent_span_id = Some(trace_parent[2].to_string());
309 }
310 }
311 },
312
313 _ => {},
314 }
315 }
316
317 if trace_id.is_empty() {
318 return Err(AirError::Internal("Invalid trace context header".to_string()));
319 }
320
321 Ok(PropagationContext {
322 TraceId:trace_id,
323 SpanId:Self::generate_span_id(),
324 CorrelationId:crate::Utility::GenerateRequestId(),
325 ParentSpanId:parent_span_id,
326 })
327 }
328
329 pub async fn create_span(
331 &self,
332
333 trace_id:String,
334
335 operation_name:impl Into<String>,
336
337 parent_span_id:Option<String>,
338
339 attributes:Option<HashMap<String, String>>,
340 ) -> Result<TraceSpan> {
341 let span_id = Self::generate_span_id();
342
343 let operation_name = operation_name.into();
344
345 let span = TraceSpan {
346 span_id:span_id.clone(),
347
348 trace_id:trace_id.clone(),
349
350 parent_span_id:parent_span_id.clone(),
351
352 operation_name:operation_name.clone(),
353
354 start_time:crate::Utility::CurrentTimestamp(),
355
356 end_time:None,
357
358 status:SpanStatus::Started,
359
360 attributes:attributes.unwrap_or_default(),
361
362 events:Vec::new(),
363
364 error:None,
365
366 duration_ms:None,
367 };
368
369 let mut spans = self.trace_spans.write().await;
370
371 let trace_span_count = spans.values().filter(|s| s.trace_id == trace_id).count();
373
374 let config = self.sampling_config.read().await;
375
376 if trace_span_count >= config.max_spans_per_trace {
377 dev_log!(
378 "air",
379 "warn: [Tracing] Trace {} exceeds max spans ({}), dropping span {}",
380 trace_id,
381 config.max_spans_per_trace,
382 span_id
383 );
384
385 return Err(AirError::Internal("Max spans per trace exceeded".to_string()));
386 }
387
388 spans.insert(span_id.clone(), span.clone());
389
390 Ok(span)
391 }
392
393 pub async fn add_span_event(
395 &self,
396
397 span_id:&str,
398
399 event_name:impl Into<String>,
400
401 attributes:HashMap<String, String>,
402 ) -> Result<()> {
403 let event = SpanEvent {
404 timestamp:crate::Utility::CurrentTimestamp(),
405
406 name:event_name.into(),
407
408 attributes:self.sanitize_attributes(attributes),
409 };
410
411 let mut spans = self.trace_spans.write().await;
412
413 if let Some(span) = spans.get_mut(span_id) {
414 span.events.push(event);
415
416 Ok(())
417 } else {
418 Err(AirError::Internal(format!("Span not found: {}", span_id)))
419 }
420 }
421
422 pub async fn mark_span_active(&self, span_id:&str) -> Result<()> {
424 let mut spans = self.trace_spans.write().await;
425
426 if let Some(span) = spans.get_mut(span_id) {
427 span.status = SpanStatus::Active;
428
429 Ok(())
430 } else {
431 Err(AirError::Internal(format!("Span not found: {}", span_id)))
432 }
433 }
434
435 pub async fn complete_span(&self, span_id:&str, error:Option<String>) -> Result<u64> {
437 let Now = crate::Utility::CurrentTimestamp();
438
439 let mut spans = self.trace_spans.write().await;
440
441 if let Some(span) = spans.get_mut(span_id) {
442 span.end_time = Some(Now);
443
444 span.duration_ms = Some(Now.saturating_sub(span.start_time));
445
446 span.status = if error.is_some() { SpanStatus::Failed } else { SpanStatus::Completed };
447
448 span.error = error.map(|e| self.sanitize_error_message(&e));
449
450 Ok(span.duration_ms.unwrap_or(0))
451 } else {
452 Err(AirError::Internal(format!("Span not found: {}", span_id)))
453 }
454 }
455
456 pub async fn add_span_attribute(&self, span_id:&str, key:String, value:String) -> Result<()> {
458 self.add_span_attributes(span_id, HashMap::from([(key, value)])).await
459 }
460
461 pub async fn add_span_attributes(&self, span_id:&str, attributes:HashMap<String, String>) -> Result<()> {
463 let sanitized = self.sanitize_attributes(attributes);
464
465 let mut spans = self.trace_spans.write().await;
466
467 if let Some(span) = spans.get_mut(span_id) {
468 for (key, value) in sanitized {
469 span.attributes.insert(key, value);
470 }
471
472 Ok(())
473 } else {
474 Err(AirError::Internal(format!("Span not found: {}", span_id)))
475 }
476 }
477
478 pub async fn get_span(&self, span_id:&str) -> Result<TraceSpan> {
480 let spans = self.trace_spans.read().await;
481
482 spans
483 .get(span_id)
484 .cloned()
485 .ok_or_else(|| AirError::Internal(format!("Span not found: {}", span_id)))
486 }
487
488 pub async fn get_trace_spans(&self, trace_id:&str) -> Result<Vec<TraceSpan>> {
490 let spans = self.trace_spans.read().await;
491
492 Ok(spans.values().filter(|span| span.trace_id == trace_id).cloned().collect())
493 }
494
495 pub async fn get_trace_metadata(&self, trace_id:&str) -> Result<TraceMetadata> {
497 let trace_spans = self.get_trace_spans(trace_id).await?;
498
499 if trace_spans.is_empty() {
500 return Err(AirError::Internal(format!("Trace not found: {}", trace_id)));
501 }
502
503 let root_span = trace_spans
504 .iter()
505 .find(|s| s.parent_span_id.is_none())
506 .ok_or_else(|| AirError::Internal("No root span found".to_string()))?;
507
508 let total_duration_ms = trace_spans.iter().filter_map(|s| s.duration_ms).max();
509
510 let status = if trace_spans.iter().any(|s| s.status == SpanStatus::Failed) {
511 TraceStatus::Failed
512 } else if trace_spans
513 .iter()
514 .all(|s| s.status == SpanStatus::Completed || s.status == SpanStatus::Failed)
515 {
516 TraceStatus::Completed
517 } else {
518 TraceStatus::InProgress
519 };
520
521 let end_time = trace_spans.iter().filter_map(|s| s.end_time).max();
522
523 Ok(TraceMetadata {
524 trace_id:trace_id.to_string(),
525 root_span_id:root_span.span_id.clone(),
526 total_spans:trace_spans.len(),
527 root_operation:root_span.operation_name.clone(),
528 start_time:root_span.start_time,
529 end_time,
530 total_duration_ms,
531 status,
532 })
533 }
534
535 pub async fn export_trace(&self, trace_id:&str) -> Result<String> {
537 let spans = self.get_trace_spans(trace_id).await?;
538
539 let metadata = self.get_trace_metadata(trace_id).await?;
540
541 let export = serde_json::json!({
542 "metadata": metadata,
543 "spans": spans,
544 });
545
546 serde_json::to_string_pretty(&export)
547 .map_err(|e| AirError::Serialization(format!("Failed to export trace: {}", e)))
548 }
549
550 pub async fn cleanup_old_spans(&self, older_than_ms:Option<u64>) -> Result<usize> {
552 let Now = crate::Utility::CurrentTimestamp();
553
554 let ttl = older_than_ms.unwrap_or_else(|| {
555 tokio::task::block_in_place(|| {
556 tokio::runtime::Handle::current().block_on(async { self.sampling_config.read().await.trace_ttl_ms })
557 })
558 });
559
560 let mut spans = self.trace_spans.write().await;
561
562 let original_len = spans.len();
563
564 spans.retain(|_, span| span.end_time.map_or(true, |end| Now.saturating_sub(end) < ttl));
565
566 Ok(original_len.saturating_sub(spans.len()))
567 }
568
569 pub async fn get_statistics(&self) -> TraceStatistics {
571 let spans = self.trace_spans.read().await;
572
573 let total_traces = spans
574 .values()
575 .map(|s| s.trace_id.clone())
576 .collect::<std::collections::HashSet<_>>()
577 .len();
578
579 let completed_spans = spans.values().filter(|s| s.status == SpanStatus::Completed).count();
580
581 let failed_spans = spans.values().filter(|s| s.status == SpanStatus::Failed).count();
582
583 let in_progress_spans = spans
584 .values()
585 .filter(|s| s.status == SpanStatus::Started || s.status == SpanStatus::Active)
586 .count();
587
588 TraceStatistics {
589 total_traces:total_traces as u64,
590
591 total_spans:spans.len() as u64,
592
593 completed_spans:completed_spans as u64,
594
595 failed_spans:failed_spans as u64,
596
597 in_progress_spans:in_progress_spans as u64,
598 }
599 }
600
601 fn sanitize_attributes(&self, mut attributes:HashMap<String, String>) -> HashMap<String, String> {
603 let sensitive_keys = vec![
604 "password",
605 "token",
606 "secret",
607 "api_key",
608 "authorization",
609 "credential",
610 "auth",
611 "private_key",
612 "session_token",
613 ];
614
615 let attr_keys:Vec<String> = attributes.keys().cloned().collect();
617
618 for key in sensitive_keys {
619 let key_lower = key.to_lowercase();
620
621 for attr_key in &attr_keys {
622 if attr_key.to_lowercase().contains(&key_lower) {
623 attributes.insert(attr_key.clone(), "[REDACTED]".to_string());
624 }
625 }
626 }
627
628 attributes
629 }
630
631 fn sanitize_error_message(&self, message:&str) -> String {
633 let mut sanitized = message.to_string();
634
635 let patterns = vec![
636 (r"(?i)password[=:]\S+", "password=[REDACTED]"),
637 (r"(?i)token[=:]\S+", "token=[REDACTED]"),
638 (r"(?i)secret[=:]\S+", "secret=[REDACTED]"),
639 (r"(?i)(api|private)[_-]?key[=:]\S+", "api_key=[REDACTED]"),
640 (
641 r"(?i)authorization[=[:space:]]+Bearer[[:space:]]+\S+",
642 "Authorization: Bearer [REDACTED]",
643 ),
644 ];
645
646 for (pattern, replacement) in patterns {
647 if let Ok(re) = regex::Regex::new(pattern) {
648 sanitized = re.replace_all(&sanitized, replacement).to_string();
649 }
650 }
651
652 sanitized
653 }
654}
655
/// Snapshot of span counts held by a [`TraceGenerator`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TraceStatistics {
    /// Number of distinct trace IDs among the stored spans.
    pub total_traces:u64,

    /// Number of stored spans, regardless of state.
    pub total_spans:u64,

    /// Number of spans in state `Completed`.
    pub completed_spans:u64,

    /// Number of spans in state `Failed`.
    pub failed_spans:u64,

    /// Number of spans in state `Started` or `Active`.
    pub in_progress_spans:u64,
}
669
/// `TraceGenerator::default()` is equivalent to `TraceGenerator::new()`.
impl Default for TraceGenerator {
    fn default() -> Self { Self::new() }
}
673
/// Process-wide generator instance. It is set at most once, either by
/// `initialize_tracing` or lazily by `get_trace_generator`.
static TRACE_GENERATOR:std::sync::OnceLock<TraceGenerator> = std::sync::OnceLock::new();
676
677pub fn get_trace_generator() -> &'static TraceGenerator { TRACE_GENERATOR.get_or_init(TraceGenerator::new) }
679
680pub fn initialize_tracing(sampling_config:Option<SamplingConfig>) -> Result<()> {
682 let generator = if let Some(config) = sampling_config {
683 TraceGenerator::with_sampling(config)?
684 } else {
685 TraceGenerator::new()
686 };
687
688 let _old = TRACE_GENERATOR.set(generator);
689
690 dev_log!("air", "[Tracing] Trace generator initialized with tracing");
691
692 Ok(())
693}
694
thread_local! {

    // Propagation context of the current thread; `None` until
    // `set_propagation_context` is called on this thread.
    // NOTE(review): the value is per OS thread, so it presumably does not
    // follow an async task that resumes on another worker thread — confirm
    // this matches how callers use it.
    static PROPAGATION_CONTEXT: std::cell::RefCell<Option<PropagationContext>> = std::cell::RefCell::new(None);
}
699
700pub fn set_propagation_context(context:PropagationContext) {
702 PROPAGATION_CONTEXT.with(|ctx| {
703 *ctx.borrow_mut() = Some(context);
704 });
705}
706
707pub fn get_propagation_context() -> Option<PropagationContext> { PROPAGATION_CONTEXT.with(|ctx| ctx.borrow().clone()) }
709
710pub async fn create_propagation_context(TraceId:String, ParentSpanId:Option<String>) -> PropagationContext {
712 let SpanId = TraceGenerator::generate_span_id();
713
714 let CorrelationId = crate::Utility::GenerateRequestId();
715
716 PropagationContext { TraceId, SpanId, CorrelationId, ParentSpanId }
717}
718
719pub fn create_trace_context_header(context:&PropagationContext) -> String {
721 format!("traceparent=00-{}-{}-01", context.TraceId, context.SpanId)
722}