1use std::collections::HashMap;
2use std::sync::Arc;
3
4use miette::{IntoDiagnostic, Result};
5use serde::Serialize;
6use smol_str::SmolStr;
7
8use crate::config::RateTier;
9use crate::db::pds_meta as db_pds;
10use crate::pds_meta::PdsMeta;
11use crate::state::AppState;
12
/// Serializable pairing of a PDS `host` with the name of a rate `tier`.
///
/// NOTE(review): not constructed anywhere in this file; presumably it mirrors
/// one entry of the host → tier-name map returned by `PdsControl::list_tiers`
/// — confirm against callers.
#[derive(Debug, Clone, Serialize)]
pub struct PdsTierAssignment {
    /// PDS hostname.
    pub host: String,
    /// Name of the rate tier assigned to `host`.
    pub tier: String,
}
19
/// Serializable copy of a configured [`RateTier`]'s limit fields.
///
/// Built from a `RateTier` via the `From` impl; used as the value type of
/// `PdsControl::list_rate_tiers`.
#[derive(Debug, Clone, Serialize)]
pub struct PdsTierDefinition {
    /// Copied from `RateTier::per_second_base`.
    pub per_second_base: u64,
    /// Copied from `RateTier::per_second_account_mul`; presumably scales the
    /// per-second limit by account count — confirm in `crate::config`.
    pub per_second_account_mul: f64,
    /// Copied from `RateTier::per_hour`.
    pub per_hour: u64,
    /// Copied from `RateTier::per_day`.
    pub per_day: u64,
}
28
29impl From<RateTier> for PdsTierDefinition {
30 fn from(t: RateTier) -> Self {
31 Self {
32 per_second_base: t.per_second_base,
33 per_second_account_mul: t.per_second_account_mul,
34 per_hour: t.per_hour,
35 per_day: t.per_day,
36 }
37 }
38}
39
/// Cloneable handle for reading and changing per-PDS ban status and rate-tier
/// assignments held in the shared [`AppState`].
///
/// Cloning only bumps the `Arc` refcount; all clones share the same state.
#[derive(Clone)]
pub struct PdsControl(pub(super) Arc<AppState>);
43
44impl PdsControl {
45 async fn update<F, G>(&self, db_op: F, mem_op: G) -> Result<()>
46 where
47 F: FnOnce(&mut fjall::OwnedWriteBatch, &fjall::Keyspace) + Send + 'static,
48 G: FnOnce(&mut PdsMeta),
49 {
50 let state = self.0.clone();
51 tokio::task::spawn_blocking(move || {
52 let mut batch = state.db.inner.batch();
53 db_op(&mut batch, &state.db.filter);
54 batch.commit().into_diagnostic()?;
55 state.db.persist()
56 })
57 .await
58 .into_diagnostic()??;
59
60 let mut snapshot = (**self.0.pds_meta.load()).clone();
61 mem_op(&mut snapshot);
62 self.0.pds_meta.store(Arc::new(snapshot));
63
64 Ok(())
65 }
66
67 pub async fn list_tiers(&self) -> HashMap<String, String> {
69 let snapshot = self.0.pds_meta.load();
70 snapshot
71 .tiers
72 .iter()
73 .map(|(host, tier)| (host.clone(), tier.to_string()))
74 .collect()
75 }
76
77 pub fn get_tier(&self, host: impl AsRef<str>) -> String {
79 let snapshot = self.0.pds_meta.load();
80 snapshot
81 .tiers
82 .get(host.as_ref())
83 .map(|t| t.to_string())
84 .unwrap_or_else(|| "default".to_string())
85 }
86
87 pub fn is_banned(&self, host: impl AsRef<str>) -> bool {
89 self.0.pds_meta.load().is_banned(host.as_ref())
90 }
91
92 pub async fn list_banned(&self) -> Vec<String> {
94 let snapshot = self.0.pds_meta.load();
95 snapshot.banned.iter().cloned().collect()
96 }
97
98 pub fn list_rate_tiers(&self) -> HashMap<String, PdsTierDefinition> {
100 self.0
101 .rate_tiers
102 .iter()
103 .map(|(name, tier)| (name.clone(), PdsTierDefinition::from(*tier)))
104 .collect()
105 }
106
107 pub async fn set_tier(&self, host: impl AsRef<str>, tier: String) -> Result<()> {
110 if !self.0.rate_tiers.contains_key(&tier) {
111 miette::bail!(
112 "unknown tier '{tier}'; known tiers: {:?}",
113 self.0.rate_tiers.keys().collect::<Vec<_>>()
114 );
115 }
116
117 let host = host.as_ref().to_string();
118 let host_clone = host.clone();
119 let tier_clone = tier.clone();
120 self.update(
121 move |batch, ks| db_pds::set_tier(batch, ks, &host_clone, &tier_clone),
122 move |meta| {
123 meta.tiers.insert(host, SmolStr::new(&tier));
124 },
125 )
126 .await
127 }
128
129 pub async fn remove_tier(&self, host: impl AsRef<str>) -> Result<()> {
131 let host = host.as_ref().to_string();
132 let host_clone = host.clone();
133 self.update(
134 move |batch, ks| db_pds::remove_tier(batch, ks, &host_clone),
135 move |meta| {
136 meta.tiers.remove(&host);
137 },
138 )
139 .await
140 }
141
142 pub async fn ban(&self, host: impl AsRef<str>) -> Result<()> {
144 let host = host.as_ref().to_string();
145 let host_clone = host.clone();
146 self.update(
147 move |batch, ks| db_pds::set_banned(batch, ks, &host_clone),
148 move |meta| {
149 meta.banned.insert(host);
150 },
151 )
152 .await
153 }
154
155 pub async fn unban(&self, host: impl AsRef<str>) -> Result<()> {
157 let host = host.as_ref().to_string();
158 let host_clone = host.clone();
159 self.update(
160 move |batch, ks| db_pds::remove_banned(batch, ks, &host_clone),
161 move |meta| {
162 meta.banned.remove(&host);
163 },
164 )
165 .await
166 }
167}