use std::collections::BTreeMap;
use std::collections::HashMap;
use std::time::Duration;
use anyhow::anyhow;
use chrono::DateTime;
use chrono::Utc;
use jsonrpsee::types::SubscriptionId;
use serde::Serialize;
use serde::Serializer;
use tokio::select;
use tokio::signal::unix::signal;
use tokio::signal::unix::SignalKind;
use crate::infra::tracing::info_task_spawn;
use crate::log_and_err;
use crate::GlobalState;
/// Expression-form `if`/`else`.
///
/// `if_else!(cond, a, b)` expands to `if cond { a } else { b }`, so only the selected
/// branch is evaluated and both branches must have the same type.
#[macro_export]
macro_rules! if_else {
    ($cond:expr, $then:expr, $otherwise:expr) => {
        if $cond {
            $then
        } else {
            $otherwise
        }
    };
}
/// Boolean negation as a named function, e.g. `if not(done) { ... }`.
#[inline(always)]
pub fn not(value: bool) -> bool {
    std::ops::Not::not(value)
}
/// Returns the name of `T` without its leading module path, e.g. `"String"` for
/// `alloc::string::String`.
///
/// Only the outermost path is stripped. Generic arguments, tuple elements and array
/// element types are kept verbatim, so `Vec<String>` yields
/// `"Vec<alloc::string::String>"` rather than a fragment of the inner type.
///
/// The underlying [`std::any::type_name`] output is not guaranteed to be stable
/// across compiler versions, so use the result for diagnostics only.
pub fn type_basename<T>() -> &'static str {
    let name: &'static str = std::any::type_name::<T>();

    // Paths nested inside `<...>`, `(...)` or `[...]` must not be treated as the type's own
    // path, so the `::` separator is searched for only before the first such delimiter.
    let path_end = name
        .find(|c: char| matches!(c, '<' | '(' | '['))
        .unwrap_or(name.len());

    match name[..path_end].rfind("::") {
        Some(separator) => &name[separator + 2..],
        None => name,
    }
}
/// Implements `From<Source>` for the tuple newtype `$type`, once per listed source type.
///
/// Each source value is converted into the wrapped field with `Into`:
///
/// ```ignore
/// gen_newtype_from!(self = MyNewtype, other = u8, u16, u32);
/// ```
#[macro_export]
macro_rules! gen_newtype_from {
    (self = $type:ty, other = $($source:ty),+) => {
        $(
            impl From<$source> for $type {
                fn from(source: $source) -> Self {
                    let inner = source.into();
                    Self(inner)
                }
            }
        )+
    };
}
/// Implements `TryFrom<Source>` for the tuple newtype `$type`, once per listed source type.
///
/// Each source value is converted into the wrapped field with `TryInto`. A conversion
/// failure becomes an `anyhow::Error` whose message is the `Debug` output of the
/// original error.
///
/// ```ignore
/// gen_newtype_try_from!(self = MyNewtype, other = i64, u128);
/// ```
#[macro_export]
macro_rules! gen_newtype_try_from {
    (self = $type:ty, other = $($source:ty),+) => {
        $(
            impl TryFrom<$source> for $type {
                type Error = anyhow::Error;
                fn try_from(source: $source) -> Result<Self, Self::Error> {
                    match source.try_into() {
                        Ok(inner) => Ok(Self(inner)),
                        Err(err) => Err(anyhow::anyhow!("{:?}", err)),
                    }
                }
            }
        )+
    };
}
/// Extension trait for rendering a value as a `String` in a format chosen by this crate.
///
/// Implemented here for `std::time::Duration` (humantime formatting) and for
/// jsonrpsee's `SubscriptionId` (the bare inner value).
pub trait DisplayExt {
    /// Renders `self` as an owned `String`.
    fn to_string_ext(&self) -> String;
}
impl DisplayExt for std::time::Duration {
    /// Formats the duration with `humantime`, e.g. `1m 30s` or `1s 500ms`.
    fn to_string_ext(&self) -> String {
        let readable = humantime::Duration::from(*self);
        readable.to_string()
    }
}
impl DisplayExt for SubscriptionId<'_> {
    /// Renders only the id's inner value, with no variant wrapper.
    ///
    /// Numeric ids use their `Display` form; string ids are copied verbatim.
    fn to_string_ext(&self) -> String {
        match self {
            Self::Num(number) => number.to_string(),
            Self::Str(text) => text.to_string(),
        }
    }
}
/// Extra combinators for `Option`.
pub trait OptionExt<T> {
    /// Converts the contained value, if any, into `U` through its `From<T>` impl.
    ///
    /// This is shorthand for `opt.map(U::from)`.
    fn map_into<U: From<T>>(self) -> Option<U>;
}

impl<T> OptionExt<T> for Option<T> {
    fn map_into<U: From<T>>(self) -> Option<U> {
        self.map(U::from)
    }
}
/// Unwraps a value whose failure case the caller considers unreachable.
///
/// `T` is the success type. `E` names the failure type of the implementing wrapper:
/// `serde_json::Error` for `Result`, and `()` for `Option<DateTime<Utc>>`. `E` does not
/// appear in the method signature.
pub trait InfallibleExt<T, E> {
    /// Returns the inner value.
    ///
    /// # Panics
    ///
    /// The impls in this file log an error-level `tracing` event and then panic
    /// when the value is `Err` or `None`.
    fn expect_infallible(self) -> T;
}
impl<T> InfallibleExt<T, serde_json::Error> for Result<T, serde_json::Error>
where
T: Sized,
{
#[allow(clippy::expect_used)]
fn expect_infallible(self) -> T {
if let Err(ref e) = self {
tracing::error!(reason = ?e, "expected infallible serde serialization/deserialization");
}
self.expect("serde serialization/deserialization")
}
}
impl InfallibleExt<DateTime<Utc>, ()> for Option<DateTime<Utc>> {
    /// Returns the contained datetime.
    ///
    /// On `None`, an error-level event is logged and then the method panics.
    #[allow(clippy::expect_used)]
    fn expect_infallible(self) -> DateTime<Utc> {
        self.or_else(|| {
            tracing::error!("expected infallible datetime conversion");
            None
        })
        .expect("infallible datetime conversion")
    }
}
pub fn parse_duration(s: &str) -> anyhow::Result<Duration> {
let millis: Result<u64, _> = s.parse();
if let Ok(millis) = millis {
return Ok(Duration::from_millis(millis));
}
if let Ok(parsed) = humantime::parse_duration(s) {
return Ok(parsed);
}
Err(anyhow!("invalid duration format: {}", s))
}
/// Reason passed to `traced_sleep`.
///
/// With the `tracing` feature enabled, it is recorded as the `reason` field of the sleep
/// span and event. `Display` is derived with `strum` and yields the kebab-case label set
/// by each variant's `to_string` attribute.
#[derive(Debug, strum::Display)]
pub enum SleepReason {
    /// Displays as `interval`.
    #[strum(to_string = "interval")]
    Interval,
    /// Displays as `retry-backoff`.
    #[strum(to_string = "retry-backoff")]
    RetryBackoff,
    /// Displays as `sync-data`.
    #[strum(to_string = "sync-data")]
    SyncData,
}
/// Sleeps for `duration` on the Tokio timer inside a `tokio::sleep` debug span.
///
/// Both the span and a debug `sleeping` event emitted at the start carry
/// `duration_ms` and `reason`.
#[cfg(feature = "tracing")]
#[inline(always)]
pub async fn traced_sleep(duration: Duration, reason: SleepReason) {
    use tracing::Instrument;

    let millis = duration.as_millis();
    let sleep_span = tracing::debug_span!("tokio::sleep", duration_ms = %millis, %reason);

    let sleep_future = async move {
        tracing::debug!(duration_ms = %millis, %reason, "sleeping");
        tokio::time::sleep(duration).await;
    };
    sleep_future.instrument(sleep_span).await;
}
/// Sleeps for `duration` on the Tokio timer.
///
/// Without the `tracing` feature the reason is accepted only for signature parity and
/// is otherwise unused.
#[cfg(not(feature = "tracing"))]
#[inline(always)]
pub async fn traced_sleep(duration: Duration, _reason: SleepReason) {
    let sleep = tokio::time::sleep(duration);
    sleep.await;
}
/// Reports the spawn through `info_task_spawn` and then spawns `task` on the current
/// Tokio runtime as a task named `name`.
///
/// This uses `tokio::task::Builder`, an unstable Tokio API that requires
/// `--cfg tokio_unstable`. `#[track_caller]` forwards the caller's location to the
/// builder's spawn.
///
/// # Panics
///
/// Panics if the builder returns an error. Like `tokio::spawn`, it presumably also
/// panics when no runtime is current (confirm against the Tokio version in use).
#[track_caller]
#[allow(clippy::expect_used)]
pub fn spawn_named<T>(name: &str, task: impl std::future::Future<Output = T> + Send + 'static) -> tokio::task::JoinHandle<T>
where
    T: Send + 'static,
{
    info_task_spawn(name);

    let builder = tokio::task::Builder::new().name(name);
    let spawned = builder.spawn(task);
    spawned.expect("spawning named async task should not fail")
}
/// Reports the spawn through `info_task_spawn` and then runs `task` on a new OS thread
/// named `name`.
///
/// Before calling `task`, the thread enters the caller's Tokio runtime, so Tokio APIs
/// that need a runtime context can be used from inside it.
///
/// # Panics
///
/// Panics if called outside a Tokio runtime context (`Handle::current`), or if the
/// OS thread cannot be spawned.
#[allow(clippy::expect_used)]
#[track_caller]
pub fn spawn_thread<T>(name: &str, task: impl FnOnce() -> T + Send + 'static) -> std::thread::JoinHandle<T>
where
    T: Send + 'static,
{
    info_task_spawn(name);

    // Captured on the calling thread; the new thread has no runtime context of its own.
    let handle = tokio::runtime::Handle::current();
    let thread_body = move || {
        let _entered = handle.enter();
        task()
    };

    std::thread::Builder::new()
        .name(name.to_owned())
        .spawn(thread_body)
        .expect("spawning background thread should not fail")
}
pub async fn spawn_signal_handler() -> anyhow::Result<()> {
const TASK_NAME: &str = "signal-handler";
let mut sigterm = match signal(SignalKind::terminate()) {
Ok(signal) => signal,
Err(e) => return log_and_err!(reason = e, "failed to init SIGTERM watcher"),
};
let mut sigint = match signal(SignalKind::interrupt()) {
Ok(signal) => signal,
Err(e) => return log_and_err!(reason = e, "failed to init SIGINT watcher"),
};
spawn_named("sys::signal_handler", async move {
select! {
_ = sigterm.recv() => {
GlobalState::shutdown_from(TASK_NAME, "received SIGTERM");
}
_ = sigint.recv() => {
GlobalState::shutdown_from(TASK_NAME, "received SIGINT");
}
}
});
Ok(())
}
/// Serializes `value` to a compact JSON string.
///
/// # Panics
///
/// Logs and then panics if serialization fails. Per the `serde_json` docs, this
/// happens when the `Serialize` impl returns an error or when a map has non-string keys.
pub fn to_json_string<V: serde::Serialize>(value: &V) -> String {
    let serialized = serde_json::to_string(value);
    serialized.expect_infallible()
}
/// Converts `value`, taking ownership of it, into a `serde_json::Value` tree.
///
/// # Panics
///
/// Logs and then panics if the conversion fails. Per the `serde_json` docs, this
/// happens when the `Serialize` impl returns an error or when a map has non-string keys.
pub fn to_json_value<V: serde::Serialize>(value: V) -> serde_json::Value {
    let converted = serde_json::to_value(value);
    converted.expect_infallible()
}
/// Deserializes a `T` from the JSON text `s`.
///
/// # Panics
///
/// Logs and then panics if `s` is not valid JSON for `T`.
///
/// NOTE(review): this panics on malformed input. Code that parses external or
/// untrusted JSON should call `serde_json::from_str` directly and handle the error.
pub fn from_json_str<T: serde::de::DeserializeOwned>(s: &str) -> T {
    let parsed = serde_json::from_str::<T>(s);
    parsed.expect_infallible()
}
/// Serializes a `HashMap` with its entries in ascending key order.
///
/// Copying into a `BTreeMap` of references makes the output independent of hash
/// iteration order. The signature matches what `#[serde(serialize_with = "...")]`
/// expects.
pub fn ordered_map<S, K: Ord + Serialize, V: Serialize>(value: &HashMap<K, V>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: Serializer,
{
    value.iter().collect::<BTreeMap<&K, &V>>().serialize(serializer)
}
/// Generates two JSON tests for `$type`, each using an instance produced by `fake::Faker`.
///
/// - `serde_debug_json_<type>`: the compact JSON encoding must be identical to the
///   `Debug` output.
/// - `serde_json_<type>`: encode, decode, re-encode, decode. Every decoded value must
///   equal the original, and the re-encoding must match the first encoding.
#[macro_export]
macro_rules! gen_test_serde {
    ($type:ty) => {
        paste::paste! {
            #[test]
            pub fn [<serde_debug_json_ $type:snake>]() {
                let sample = <fake::Faker as fake::Fake>::fake::<$type>(&fake::Faker);
                let as_json = serde_json::to_string(&sample).expect(concat!("failed to serialize in test for ", stringify!($type)));
                let as_debug = format!("{:?}", sample);
                assert_eq!(as_json, as_debug);
            }

            #[test]
            pub fn [<serde_json_ $type:snake>]() {
                let sample = <fake::Faker as fake::Fake>::fake::<$type>(&fake::Faker);

                // First round trip.
                let first_json = serde_json::to_string(&sample).unwrap();
                let first_value = serde_json::from_str::<$type>(&first_json).unwrap();
                assert_eq!(first_value, sample);

                // Second round trip: the encoding must be stable.
                let second_json = serde_json::to_string(&first_value).unwrap();
                assert_eq!(second_json, first_json);
                let second_value = serde_json::from_str::<$type>(&second_json).unwrap();
                assert_eq!(second_value, sample);
            }
        }
    };
}
/// Generates a JSON snapshot test named `test_<type>_json_snapshot` for `$type`.
///
/// The test builds a value with `$crate::utils::test_utils::fake_first::<$type>()` and
/// checks it against `tests/fixtures/primitives/<stringify!($type)>.json`. That path is
/// relative to the test's working directory.
/// - If the snapshot is missing and `DANGEROUS_UPDATE_SNAPSHOTS` is set (to any
///   valid-Unicode value), the pretty-printed JSON of the generated value is written
///   as the new snapshot.
/// - If the snapshot is missing and the variable is not set, the test fails.
/// - An existing snapshot is never rewritten by this macro.
///
/// The snapshot file is then deserialized into `$type`, and the result must equal the
/// generated value.
///
/// NOTE(review): this relies on `fake_first` producing the same value on every run.
/// Confirm.
#[macro_export]
macro_rules! gen_test_json {
    ($type:ty) => {
        paste::paste! {
            #[test]
            fn [<test_ $type:snake _json_snapshot>]() -> anyhow::Result<()> {
                use anyhow::bail;
                use std::path::Path;
                use std::{env, fs};
                let expected: $type = $crate::utils::test_utils::fake_first::<$type>();
                let expected_json = serde_json::to_string_pretty(&expected)?;
                let snapshot_parent_path = "tests/fixtures/primitives";
                let snapshot_name = format!("{}.json", stringify!($type));
                let snapshot_path = format!("{}/{}", snapshot_parent_path, snapshot_name);
                // A snapshot is only created when none exists. An existing file is always
                // compared and never overwritten.
                if !Path::new(&snapshot_path).exists() {
                    if env::var("DANGEROUS_UPDATE_SNAPSHOTS").is_ok() {
                        fs::create_dir_all(snapshot_parent_path)?;
                        fs::write(&snapshot_path, &expected_json)?;
                    } else {
                        bail!("snapshot file at '{snapshot_path}' doesn't exist and DANGEROUS_UPDATE_SNAPSHOTS is not set");
                    }
                }
                // The on-disk content (not `expected_json`) is what gets deserialized, so
                // drift between the stored file and the current `Deserialize` impl shows up.
                let snapshot_content = fs::read_to_string(&snapshot_path)?;
                let deserialized = match serde_json::from_str::<$type>(&snapshot_content) {
                    Ok(value) => value,
                    Err(e) => {
                        bail!("Failed to deserialize snapshot:\nError: {}\n\nExpected JSON:\n{}\n\nActual JSON from snapshot:\n{}",
                            e, expected_json, snapshot_content);
                    }
                };
                assert_eq!(
                    expected, deserialized,
                    "\nDeserialized value doesn't match expected:\n\nExpected JSON:\n{}\n\nDeserialized JSON:\n{}",
                    expected_json,
                    serde_json::to_string_pretty(&deserialized)?
                );
                Ok(())
            }
        }
    };
}
/// Generates a test checking that a `fake::Faker` instance of `$type` comes back
/// unchanged after a `bincode` serialize/deserialize round trip.
#[macro_export]
macro_rules! gen_test_bincode {
    ($type:ty) => {
        paste::paste! {
            #[test]
            pub fn [<bincode_ $type:snake>]() {
                let original = <fake::Faker as fake::Fake>::fake::<$type>(&fake::Faker);
                let encoded = bincode::serialize(&original).unwrap();
                let decoded = bincode::deserialize::<$type>(&encoded).unwrap();
                assert_eq!(decoded, original);
            }
        }
    };
}