hydrant/control/
pds.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4use miette::{IntoDiagnostic, Result};
5use serde::Serialize;
6use smol_str::SmolStr;
7
8use crate::config::RateTier;
9use crate::db::pds_meta as db_pds;
10use crate::pds_meta::PdsMeta;
11use crate::state::AppState;
12
13/// a single PDS-to-tier assignment.
14#[derive(Debug, Clone, Serialize)]
15pub struct PdsTierAssignment {
16    pub host: String,
17    pub tier: String,
18}
19
20/// a rate tier definition, as returned by the API.
21#[derive(Debug, Clone, Serialize)]
22pub struct PdsTierDefinition {
23    pub per_second_base: u64,
24    pub per_second_account_mul: f64,
25    pub per_hour: u64,
26    pub per_day: u64,
27}
28
29impl From<RateTier> for PdsTierDefinition {
30    fn from(t: RateTier) -> Self {
31        Self {
32            per_second_base: t.per_second_base,
33            per_second_account_mul: t.per_second_account_mul,
34            per_hour: t.per_hour,
35            per_day: t.per_day,
36        }
37    }
38}
39
40/// runtime control over pds related behaviour (eg. ratelimits).
41#[derive(Clone)]
42pub struct PdsControl(pub(super) Arc<AppState>);
43
44impl PdsControl {
45    async fn update<F, G>(&self, db_op: F, mem_op: G) -> Result<()>
46    where
47        F: FnOnce(&mut fjall::OwnedWriteBatch, &fjall::Keyspace) + Send + 'static,
48        G: FnOnce(&mut PdsMeta),
49    {
50        let state = self.0.clone();
51        tokio::task::spawn_blocking(move || {
52            let mut batch = state.db.inner.batch();
53            db_op(&mut batch, &state.db.filter);
54            batch.commit().into_diagnostic()?;
55            state.db.persist()
56        })
57        .await
58        .into_diagnostic()??;
59
60        let mut snapshot = (**self.0.pds_meta.load()).clone();
61        mem_op(&mut snapshot);
62        self.0.pds_meta.store(Arc::new(snapshot));
63
64        Ok(())
65    }
66
67    /// list all current per-PDS tier assignments.
68    pub async fn list_tiers(&self) -> HashMap<String, String> {
69        let snapshot = self.0.pds_meta.load();
70        snapshot
71            .tiers
72            .iter()
73            .map(|(host, tier)| (host.clone(), tier.to_string()))
74            .collect()
75    }
76
77    /// returns the assigned tier for `host`, or "default" if none is assigned.
78    pub fn get_tier(&self, host: impl AsRef<str>) -> String {
79        let snapshot = self.0.pds_meta.load();
80        snapshot
81            .tiers
82            .get(host.as_ref())
83            .map(|t| t.to_string())
84            .unwrap_or_else(|| "default".to_string())
85    }
86
87    /// returns true if `host` is currently banned.
88    pub fn is_banned(&self, host: impl AsRef<str>) -> bool {
89        self.0.pds_meta.load().is_banned(host.as_ref())
90    }
91
92    /// list all currently banned PDS hosts.
93    pub async fn list_banned(&self) -> Vec<String> {
94        let snapshot = self.0.pds_meta.load();
95        snapshot.banned.iter().cloned().collect()
96    }
97
98    /// list all configured rate tier definitions.
99    pub fn list_rate_tiers(&self) -> HashMap<String, PdsTierDefinition> {
100        self.0
101            .rate_tiers
102            .iter()
103            .map(|(name, tier)| (name.clone(), PdsTierDefinition::from(*tier)))
104            .collect()
105    }
106
107    /// assign `host` to `tier`, persisting the change to the database.
108    /// returns an error if `tier` is not a known tier name.
109    pub async fn set_tier(&self, host: impl AsRef<str>, tier: String) -> Result<()> {
110        if !self.0.rate_tiers.contains_key(&tier) {
111            miette::bail!(
112                "unknown tier '{tier}'; known tiers: {:?}",
113                self.0.rate_tiers.keys().collect::<Vec<_>>()
114            );
115        }
116
117        let host = host.as_ref().to_string();
118        let host_clone = host.clone();
119        let tier_clone = tier.clone();
120        self.update(
121            move |batch, ks| db_pds::set_tier(batch, ks, &host_clone, &tier_clone),
122            move |meta| {
123                meta.tiers.insert(host, SmolStr::new(&tier));
124            },
125        )
126        .await
127    }
128
129    /// remove any explicit tier assignment for `host`, reverting it to the default tier.
130    pub async fn remove_tier(&self, host: impl AsRef<str>) -> Result<()> {
131        let host = host.as_ref().to_string();
132        let host_clone = host.clone();
133        self.update(
134            move |batch, ks| db_pds::remove_tier(batch, ks, &host_clone),
135            move |meta| {
136                meta.tiers.remove(&host);
137            },
138        )
139        .await
140    }
141
142    /// ban `host`, persisting the change to the database.
143    pub async fn ban(&self, host: impl AsRef<str>) -> Result<()> {
144        let host = host.as_ref().to_string();
145        let host_clone = host.clone();
146        self.update(
147            move |batch, ks| db_pds::set_banned(batch, ks, &host_clone),
148            move |meta| {
149                meta.banned.insert(host);
150            },
151        )
152        .await
153    }
154
155    /// unban `host`, removing it from the database.
156    pub async fn unban(&self, host: impl AsRef<str>) -> Result<()> {
157        let host = host.as_ref().to_string();
158        let host_clone = host.clone();
159        self.update(
160            move |batch, ks| db_pds::remove_banned(batch, ks, &host_clone),
161            move |meta| {
162                meta.banned.remove(&host);
163            },
164        )
165        .await
166    }
167}