stratus/
ext.rs

1//! Standard library extensions.
2
3use std::collections::BTreeMap;
4use std::collections::HashMap;
5use std::time::Duration;
6
7use alloy_primitives::U64;
8use alloy_primitives::U256;
9use anyhow::anyhow;
10use chrono::DateTime;
11use chrono::Utc;
12use jsonrpsee::types::SubscriptionId;
13use serde::Serialize;
14use serde::Serializer;
15use tokio::select;
16use tokio::signal::unix::SignalKind;
17use tokio::signal::unix::signal;
18use tokio::sync::watch::error::RecvError;
19
20use crate::GlobalState;
21use crate::infra::tracing::info_task_spawn;
22use crate::log_and_err;
23
24// -----------------------------------------------------------------------------
25// Language constructs
26// -----------------------------------------------------------------------------
27
/// Ternary operator from [ternop](https://docs.rs/ternop/1.0.1/ternop/), but renamed.
///
/// Expands to a plain `if`/`else` expression, so only the selected branch is evaluated.
#[macro_export]
macro_rules! if_else {
    ($cond: expr, $when_true: expr, $when_false: expr) => {
        if $cond {
            $when_true
        } else {
            $when_false
        }
    };
}
35
/// Boolean negation written as `not(something)` instead of `!something`, for readability.
#[inline(always)]
pub fn not(value: bool) -> bool {
    std::ops::Not::not(value)
}
41
/// Extracts only the basename of a Rust type instead of the full qualification.
///
/// The module path is stripped from the outermost type only. Generic arguments are kept as-is,
/// so `alloc::vec::Vec<alloc::string::String>` becomes `Vec<alloc::string::String>` rather than
/// `String>`. Compound types whose name does not start with a path are returned unchanged:
/// tuples, arrays/slices and function pointers.
///
/// The output of [`std::any::type_name`] is not guaranteed to be stable across compiler
/// versions. Use the result for diagnostics and labels only.
pub fn type_basename<T>() -> &'static str {
    let name: &'static str = std::any::type_name::<T>();

    // Only the leading path (before generic args / tuple / array / fn syntax) may be stripped;
    // searching the whole string would cut inside the generic arguments.
    let path_end = name.find(|c: char| matches!(c, '<' | '(' | '[')).unwrap_or(name.len());

    match name[..path_end].rfind("::") {
        Some(separator) => &name[separator + 2..],
        None => name,
    }
}
47
48// -----------------------------------------------------------------------------
49// From / TryFrom
50// -----------------------------------------------------------------------------
51
52pub trait RuintExt {
53    fn as_u64(&self) -> u64;
54}
55
56impl RuintExt for U256 {
57    fn as_u64(&self) -> u64 {
58        self.as_limbs()[0]
59    }
60}
61
62impl RuintExt for U64 {
63    fn as_u64(&self) -> u64 {
64        self.as_limbs()[0]
65    }
66}
67
/// Generates [`From`] implementations for a [newtype](https://doc.rust-lang.org/rust-by-example/generics/new_types.html),
/// one per listed source type. Each conversion delegates to the wrapped type's own [`From`]
/// (through [`Into`]).
///
/// ```ignore
/// gen_newtype_from!(self = Wrapper, other = u8, u32);
/// ```
#[macro_export]
macro_rules! gen_newtype_from {
    (self = $target:ty, other = $($input:ty),+) => {
        $(
            impl From<$input> for $target {
                fn from(value: $input) -> Self {
                    let inner = ::core::convert::Into::into(value);
                    Self(inner)
                }
            }
        )+
    };
}
81
/// Generates [`TryFrom`] implementations for a [newtype](https://doc.rust-lang.org/rust-by-example/generics/new_types.html),
/// one per listed source type. Each conversion delegates to the wrapped type's own [`TryFrom`].
///
/// Conversion errors are turned into an `anyhow::Error` carrying their `Debug` representation.
#[macro_export]
macro_rules! gen_newtype_try_from {
    (self = $target:ty, other = $($input:ty),+) => {
        $(
            impl TryFrom<$input> for $target {
                type Error = anyhow::Error;

                fn try_from(value: $input) -> Result<Self, Self::Error> {
                    value
                        .try_into()
                        .map(Self)
                        .map_err(|err| anyhow::anyhow!("{:?}", err))
                }
            }
        )+
    };
}
96
97// -----------------------------------------------------------------------------
98// Display
99// -----------------------------------------------------------------------------
100
/// Provides a `to_string`-like conversion for types that do not implement [`std::fmt::Display`].
pub trait DisplayExt {
    /// Renders `self` as a human-readable [`String`].
    fn to_string_ext(&self) -> String;
}
106
107impl DisplayExt for std::time::Duration {
108    fn to_string_ext(&self) -> String {
109        humantime::Duration::from(*self).to_string()
110    }
111}
112
113impl DisplayExt for SubscriptionId<'_> {
114    fn to_string_ext(&self) -> String {
115        match self {
116            SubscriptionId::Num(value) => value.to_string(),
117            SubscriptionId::Str(value) => value.to_string(),
118        }
119    }
120}
121
122// -----------------------------------------------------------------------------
123// Option
124// -----------------------------------------------------------------------------
125
/// Extensions for `Option<T>`.
pub trait OptionExt<T> {
    /// Converts the wrapped value, if any, into the inferred type `U` via [`From`].
    fn map_into<U: From<T>>(self) -> Option<U>;
}

impl<T> OptionExt<T> for Option<T> {
    fn map_into<U: From<T>>(self) -> Option<U> {
        self.map(U::from)
    }
}
137
138// -----------------------------------------------------------------------------
139// Result
140// -----------------------------------------------------------------------------
141
/// Unwraps values whose failure would indicate a bug rather than a recoverable condition.
///
/// `E` identifies the error kind the implementation is for.
pub trait InfallibleExt<T, E> {
    /// Returns the inner value, treating failure as an invariant violation.
    fn expect_infallible(self) -> T;
}
146
147impl<T> InfallibleExt<T, serde_json::Error> for Result<T, serde_json::Error>
148where
149    T: Sized,
150{
151    #[allow(clippy::expect_used)]
152    fn expect_infallible(self) -> T {
153        if let Err(ref e) = self {
154            tracing::error!(reason = ?e, "expected infallible serde serialization/deserialization");
155        }
156        self.expect("serde serialization/deserialization")
157    }
158}
159
160impl InfallibleExt<DateTime<Utc>, ()> for Option<DateTime<Utc>> {
161    #[allow(clippy::expect_used)]
162    fn expect_infallible(self) -> DateTime<Utc> {
163        if self.is_none() {
164            tracing::error!("expected infallible datetime conversion");
165        }
166        self.expect("infallible datetime conversion")
167    }
168}
169
170// -----------------------------------------------------------------------------
171// Duration
172// -----------------------------------------------------------------------------
173
174/// Parses a duration specified using human-time notation or fallback to milliseconds.
175pub fn parse_duration(s: &str) -> anyhow::Result<Duration> {
176    // try millis
177    let millis: Result<u64, _> = s.parse();
178    if let Ok(millis) = millis {
179        return Ok(Duration::from_millis(millis));
180    }
181
182    // try humantime
183    if let Ok(parsed) = humantime::parse_duration(s) {
184        return Ok(parsed);
185    }
186
187    // error
188    Err(anyhow!("invalid duration format: {}", s))
189}
190
191// -----------------------------------------------------------------------------
192// Tokio
193// -----------------------------------------------------------------------------
194
/// Indicates why a sleep is happening.
///
/// The `Display` output (derived via `strum`) is the label in each variant's `to_string`
/// attribute: `interval`, `retry-backoff` or `sync-data`.
#[derive(Debug, strum::Display)]
pub enum SleepReason {
    /// Task is executed at predefined intervals.
    #[strum(to_string = "interval")]
    Interval,

    /// Task is awaiting a backoff before retrying the operation.
    #[strum(to_string = "retry-backoff")]
    RetryBackoff,

    /// Task is awaiting an external system or component to produce or synchronize data.
    #[strum(to_string = "sync-data")]
    SyncData,
}
210
/// Sleeps the current task for `duration` and records why it is sleeping.
///
/// The sleep runs inside a `tokio::sleep` debug span carrying `duration_ms` and `reason`.
/// A debug event is emitted when the sleep starts.
#[cfg(feature = "tracing")]
#[inline(always)]
pub async fn traced_sleep(duration: Duration, reason: SleepReason) {
    use tracing::Instrument;

    let duration_ms = duration.as_millis();
    let span = tracing::debug_span!("tokio::sleep", %duration_ms, %reason);

    let sleep_with_log = async {
        tracing::debug!(%duration_ms, %reason, "sleeping");
        tokio::time::sleep(duration).await;
    };
    sleep_with_log.instrument(span).await;
}
225
226#[cfg(not(feature = "tracing"))]
227#[inline(always)]
228pub async fn traced_sleep(duration: Duration, _: SleepReason) {
229    tokio::time::sleep(duration).await;
230}
231
232/// Spawns an async Tokio task with a name to be displayed in tokio-console.
233#[track_caller]
234#[allow(clippy::expect_used)]
235pub fn spawn<T>(name: &str, task: impl std::future::Future<Output = T> + Send + 'static) -> tokio::task::JoinHandle<T>
236where
237    T: Send + 'static,
238{
239    info_task_spawn(name);
240    tokio::task::spawn(task)
241}
242
243/// Spawns a thread with the given name. Thread has access to Tokio current runtime.
244#[allow(clippy::expect_used)]
245#[track_caller]
246pub fn spawn_thread<T>(name: &str, task: impl FnOnce() -> T + Send + 'static) -> std::thread::JoinHandle<T>
247where
248    T: Send + 'static,
249{
250    info_task_spawn(name);
251
252    let runtime = tokio::runtime::Handle::current();
253    std::thread::Builder::new()
254        .name(name.into())
255        .spawn(move || {
256            let _runtime_guard = runtime.enter();
257            task()
258        })
259        .expect("spawning background thread should not fail")
260}
261
262/// Spawns a handler that listens to system signals.
263pub async fn spawn_signal_handler() -> anyhow::Result<()> {
264    const TASK_NAME: &str = "signal-handler";
265
266    let mut sigterm = match signal(SignalKind::terminate()) {
267        Ok(signal) => signal,
268        Err(e) => return log_and_err!(reason = e, "failed to init SIGTERM watcher"),
269    };
270    let mut sigint = match signal(SignalKind::interrupt()) {
271        Ok(signal) => signal,
272        Err(e) => return log_and_err!(reason = e, "failed to init SIGINT watcher"),
273    };
274
275    spawn("sys::signal_handler", async move {
276        select! {
277            _ = sigterm.recv() => {
278                GlobalState::shutdown_from(TASK_NAME, "received SIGTERM");
279            }
280            _ = sigint.recv() => {
281                GlobalState::shutdown_from(TASK_NAME, "received SIGINT");
282            }
283        }
284    });
285
286    Ok(())
287}
288
289// -----------------------------------------------------------------------------
290// serde_json
291// -----------------------------------------------------------------------------
292
293/// Serializes any serializable value to non-formatted [`String`] without having to check for errors.
294pub fn to_json_string<V: serde::Serialize>(value: &V) -> String {
295    serde_json::to_string(value).expect_infallible()
296}
297
298/// Serializes any serializable value to [`serde_json::Value`] without having to check for errors.
299pub fn to_json_value<V: serde::Serialize>(value: V) -> serde_json::Value {
300    serde_json::to_value(value).expect_infallible()
301}
302
303/// Deserializes any deserializable value from [`&str`] without having to check for errors.
304pub fn from_json_str<T: serde::de::DeserializeOwned>(s: &str) -> T {
305    serde_json::from_str::<T>(s).expect_infallible()
306}
307
308pub fn ordered_map<S, K: Ord + Serialize, V: Serialize>(value: &HashMap<K, V>, serializer: S) -> Result<S::Ok, S::Error>
309where
310    S: Serializer,
311{
312    let ordered: BTreeMap<_, _> = value.iter().collect();
313    ordered.serialize(serializer)
314}
315
316// -----------------------------------------------------------------------------
317// Tests
318// -----------------------------------------------------------------------------
319
/// Generates unit tests for `$type` that exercise its [`Serialize`](serde::Serialize) and
/// [`Deserialize`](serde::Deserialize) implementations using a randomly generated (`fake`) value:
///
/// - `serde_debug_json_<type>`: the JSON encoding must equal the `Debug` output.
/// - `serde_json_<type>`: encode → decode → re-encode → re-decode must round-trip unchanged.
#[macro_export]
macro_rules! gen_test_serde {
    ($type:ty) => {
        paste::paste! {
            #[test]
            pub fn [<serde_debug_json_ $type:snake>]() {
                let sample = <fake::Faker as fake::Fake>::fake::<$type>(&fake::Faker);
                let as_json = serde_json::to_string(&sample).expect(concat!("failed to serialize in test for ", stringify!($type)));
                let as_debug = format!("{:?}", sample);
                assert_eq!(as_json, as_debug);
            }

            #[test]
            pub fn [<serde_json_ $type:snake>]() {
                let sample = <fake::Faker as fake::Fake>::fake::<$type>(&fake::Faker);

                // first pass: value -> JSON -> value
                let json = serde_json::to_string(&sample).unwrap();
                let parsed = serde_json::from_str::<$type>(&json).unwrap();
                assert_eq!(parsed, sample);

                // second pass: the decoded value must encode and decode identically
                let json_again = serde_json::to_string(&parsed).unwrap();
                assert_eq!(json_again, json);
                let parsed_again = serde_json::from_str::<$type>(&json_again).unwrap();
                assert_eq!(parsed_again, sample);
            }
        }
    };
}
354
/// Generates a unit test that verifies JSON serialization/deserialization compatibility using snapshots.
///
/// Steps performed by the generated test:
///
/// 1. Build a value with `$crate::utils::test_utils::fake_first` and pretty-print it as JSON.
/// 2. Deserialize the snapshot stored at `tests/fixtures/primitives/<stringify!($type)>.json`.
/// 3. Require the deserialized value to equal the generated one.
///
/// A missing snapshot fails the test unless `DANGEROUS_UPDATE_SNAPSHOTS` is set. In that case the
/// missing snapshot is written first.
///
/// NOTE(review): the file name comes from `stringify!($type)`, so path-qualified or generic types
/// would produce names containing `::` / `<` / `>`. Confirm that only plain type names are passed.
#[macro_export]
macro_rules! gen_test_json {
    ($type:ty) => {
        paste::paste! {
            #[test]
            fn [<test_ $type:snake _json_snapshot>]() -> anyhow::Result<()> {
                use anyhow::bail;
                use std::path::Path;
                use std::{env, fs};

                let expected: $type = $crate::utils::test_utils::fake_first::<$type>();
                let expected_json = serde_json::to_string_pretty(&expected)?;
                let snapshot_parent_path = "tests/fixtures/primitives";
                let snapshot_name = format!("{}.json", stringify!($type));
                let snapshot_path = format!("{}/{}", snapshot_parent_path, snapshot_name);

                // WARNING: If you need to update snapshots (DANGEROUS_UPDATE_SNAPSHOTS=1), you have likely
                // broken backwards compatibility! Make sure this is intentional.
                // NOTE: the variable only *creates* snapshots that are missing. An existing snapshot
                // is never overwritten here, so delete it first if it must be regenerated. Any value
                // of the variable (even `0`) counts as set, because only `is_ok()` is checked.
                if !Path::new(&snapshot_path).exists() {
                    if env::var("DANGEROUS_UPDATE_SNAPSHOTS").is_ok() {
                        fs::create_dir_all(snapshot_parent_path)?;
                        fs::write(&snapshot_path, &expected_json)?;
                    } else {
                        bail!("snapshot file at '{snapshot_path}' doesn't exist and DANGEROUS_UPDATE_SNAPSHOTS is not set");
                    }
                }

                // Read and attempt to deserialize the snapshot
                let snapshot_content = fs::read_to_string(&snapshot_path)?;
                let deserialized = match serde_json::from_str::<$type>(&snapshot_content) {
                    Ok(value) => value,
                    Err(e) => {
                        bail!("Failed to deserialize snapshot:\nError: {}\n\nExpected JSON:\n{}\n\nActual JSON from snapshot:\n{}",
                            e, expected_json, snapshot_content);
                    }
                };

                // Compare the values (not the JSON text), so formatting-only differences in the
                // snapshot file do not fail the test
                assert_eq!(
                    expected, deserialized,
                    "\nDeserialized value doesn't match expected:\n\nExpected JSON:\n{}\n\nDeserialized JSON:\n{}",
                    expected_json,
                    serde_json::to_string_pretty(&deserialized)?
                );

                Ok(())
            }
        }
    };
}
406
/// Generates a unit test checking that a random `$type` value survives a bincode round-trip
/// (encode then decode) using [`rocks_bincode_config`](crate::rocks_bincode_config).
#[macro_export]
macro_rules! gen_test_bincode {
    ($type:ty) => {
        paste::paste! {
            #[test]
            pub fn [<bincode_ $type:snake>]() {
                let sample = <fake::Faker as fake::Fake>::fake::<$type>(&fake::Faker);
                let bytes = bincode::encode_to_vec(&sample, $crate::rocks_bincode_config()).unwrap();
                let (roundtripped, _): ($type, _) = bincode::decode_from_slice(&bytes, $crate::rocks_bincode_config()).unwrap();
                assert_eq!(roundtripped, sample);
            }
        }
    };
}
423
424/// Custom bincode configuration for RocksDB that preserves lexicographical ordering.
425pub fn rocks_bincode_config() -> impl bincode::config::Config {
426    bincode::config::standard().with_big_endian()
427}
428
429pub trait WatchReceiverExt<T> {
430    #[allow(async_fn_in_trait)]
431    async fn wait_for_change(&mut self, f: impl Fn(&T) -> bool) -> Result<(), RecvError>;
432}
433
434impl<T> WatchReceiverExt<T> for tokio::sync::watch::Receiver<T> {
435    async fn wait_for_change(&mut self, f: impl Fn(&T) -> bool) -> Result<(), RecvError> {
436        loop {
437            self.changed().await?;
438            if f(&self.borrow()) {
439                return Ok(());
440            }
441        }
442    }
443}