use crate::{NodeCodec, StorageProof};
use codec::Encode;
use hash_db::Hasher;
use parking_lot::Mutex;
use std::{
collections::HashMap,
marker::PhantomData,
mem,
ops::DerefMut,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use trie_db::{RecordedForKey, TrieAccess};
/// Target under which every `tracing` event emitted by this file is logged.
const LOG_TARGET: &str = "trie-recorder";
/// The mutable state behind a [`Recorder`].
///
/// `H` is the hash type used both for trie node hashes and for storage roots.
struct RecorderInner<H> {
/// Per storage root: for every accessed key, how much of it has been recorded
/// (see [`RecordedForKey`]).
recorded_keys: HashMap<H, HashMap<Vec<u8>, RecordedForKey>>,
/// Every recorded trie node or value, as raw bytes, indexed by its hash.
accessed_nodes: HashMap<H, Vec<u8>>,
}
impl<H> Default for RecorderInner<H> {
fn default() -> Self {
Self { recorded_keys: Default::default(), accessed_nodes: Default::default() }
}
}
/// Records the trie nodes and values that are accessed while reading from a trie.
///
/// Both fields sit behind an [`Arc`], so every clone of a recorder shares the same
/// recorded state and the same size counter.
pub struct Recorder<H: Hasher> {
/// The recorded nodes and per-key bookkeeping, guarded by a mutex.
inner: Arc<Mutex<RecorderInner<H::Out>>>,
/// Running sum of `encoded_size()` over all entries that were newly added to
/// `inner.accessed_nodes`; readable without taking the mutex.
encoded_size_estimation: Arc<AtomicUsize>,
}
impl<H: Hasher> Default for Recorder<H> {
fn default() -> Self {
Self { inner: Default::default(), encoded_size_estimation: Arc::new(0.into()) }
}
}
impl<H: Hasher> Clone for Recorder<H> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
encoded_size_estimation: self.encoded_size_estimation.clone(),
}
}
}
impl<H: Hasher> Recorder<H> {
pub fn as_trie_recorder(
&self,
storage_root: H::Out,
) -> impl trie_db::TrieRecorder<H::Out> + '_ {
TrieRecorder::<H, _> {
inner: self.inner.lock(),
storage_root,
encoded_size_estimation: self.encoded_size_estimation.clone(),
_phantom: PhantomData,
}
}
pub fn drain_storage_proof(self) -> StorageProof {
let mut recorder = mem::take(&mut *self.inner.lock());
StorageProof::new(recorder.accessed_nodes.drain().map(|(_, v)| v))
}
pub fn to_storage_proof(&self) -> StorageProof {
let recorder = self.inner.lock();
StorageProof::new(recorder.accessed_nodes.values().cloned())
}
pub fn estimate_encoded_size(&self) -> usize {
self.encoded_size_estimation.load(Ordering::Relaxed)
}
pub fn reset(&self) {
mem::take(&mut *self.inner.lock());
self.encoded_size_estimation.store(0, Ordering::Relaxed);
}
}
/// The [`trie_db::TrieRecorder`] implementation handed out by [`Recorder::as_trie_recorder`].
struct TrieRecorder<H: Hasher, I> {
/// Mutable access to the recorder state; `Recorder::as_trie_recorder` passes the
/// mutex guard here.
inner: I,
/// The storage root under which key accesses are filed in `recorded_keys`.
storage_root: H::Out,
/// Counter shared with the owning [`Recorder`]; grown by the encoded size of
/// every newly recorded entry.
encoded_size_estimation: Arc<AtomicUsize>,
/// Ties the otherwise unused hasher type `H` to this struct.
_phantom: PhantomData<H>,
}
impl<H: Hasher, I: DerefMut<Target = RecorderInner<H::Out>>> trie_db::TrieRecorder<H::Out>
    for TrieRecorder<H, I>
{
    /// Records a single trie access.
    ///
    /// Node and value accesses store the raw bytes under their hash. Only the
    /// *first* recording of a hash adds to the size estimation. Value, hash and
    /// non-existing accesses also update the per-key status kept for
    /// `self.storage_root`.
    fn record<'b>(&mut self, access: TrieAccess<'b, H::Out>) {
        use std::collections::hash_map::Entry;

        let root = self.storage_root;
        // Deref the handle once; both maps are then reachable through `state`.
        let state = &mut *self.inner;
        // Bytes added by this call; stays zero when the hash is already known.
        let mut added_bytes = 0;

        match access {
            TrieAccess::NodeOwned { hash, node_owned } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    hash = ?hash,
                    "Recording node",
                );

                // The node is only encoded when it has not been recorded yet.
                if let Entry::Vacant(slot) = state.accessed_nodes.entry(hash) {
                    let encoded = node_owned.to_encoded::<NodeCodec<H>>();
                    added_bytes += encoded.encoded_size();
                    slot.insert(encoded);
                }
            },
            TrieAccess::EncodedNode { hash, encoded_node } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    hash = ?hash,
                    "Recording node",
                );

                if let Entry::Vacant(slot) = state.accessed_nodes.entry(hash) {
                    let encoded = encoded_node.into_owned();
                    added_bytes += encoded.encoded_size();
                    slot.insert(encoded);
                }
            },
            TrieAccess::Value { hash, value, full_key } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    hash = ?hash,
                    key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
                    "Recording value",
                );

                if let Entry::Vacant(slot) = state.accessed_nodes.entry(hash) {
                    let bytes = value.into_owned();
                    added_bytes += bytes.encoded_size();
                    slot.insert(bytes);
                }

                // Always ends up as `Value`, whatever was stored for the key before.
                state
                    .recorded_keys
                    .entry(root)
                    .or_default()
                    .insert(full_key.to_vec(), RecordedForKey::Value);
            },
            TrieAccess::Hash { full_key } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
                    "Recorded hash access for key",
                );

                // Keeps an existing status for the key; `Hash` is only stored when
                // the key has no entry yet.
                state
                    .recorded_keys
                    .entry(root)
                    .or_default()
                    .entry(full_key.to_vec())
                    .or_insert(RecordedForKey::Hash);
            },
            TrieAccess::NonExisting { full_key } => {
                tracing::trace!(
                    target: LOG_TARGET,
                    key = ?sp_core::hexdisplay::HexDisplay::from(&full_key),
                    "Recorded non-existing value access for key",
                );

                // A non-existing key is filed as `Value`, exactly like a value access.
                state
                    .recorded_keys
                    .entry(root)
                    .or_default()
                    .insert(full_key.to_vec(), RecordedForKey::Value);
            },
        }

        self.encoded_size_estimation.fetch_add(added_bytes, Ordering::Relaxed);
    }

    /// Returns what has been recorded for `key` under `self.storage_root`, or
    /// [`RecordedForKey::None`] when nothing is known about the root or the key.
    fn trie_nodes_recorded_for_key(&self, key: &[u8]) -> RecordedForKey {
        let status = self
            .inner
            .recorded_keys
            .get(&self.storage_root)
            .and_then(|keys| keys.get(key));

        match status {
            Some(recorded) => *recorded,
            None => RecordedForKey::None,
        }
    }
}
#[cfg(test)]
mod tests {
    use trie_db::{Trie, TrieDBBuilder, TrieDBMutBuilder, TrieHash, TrieMut};

    type MemoryDB = crate::MemoryDB<sp_core::Blake2Hasher>;
    type Layout = crate::LayoutV1<sp_core::Blake2Hasher>;
    type Recorder = super::Recorder<sp_core::Blake2Hasher>;

    /// Key/value pairs every test trie is filled with.
    const TEST_DATA: &[(&[u8], &[u8])] =
        &[(b"key1", b"val1"), (b"key2", b"val2"), (b"key3", b"val3"), (b"key4", b"val4")];

    /// Builds an in-memory trie holding `TEST_DATA` and returns the database
    /// together with the resulting root.
    fn create_trie() -> (MemoryDB, TrieHash<Layout>) {
        let mut db = MemoryDB::default();
        let mut root = Default::default();

        let mut trie = TrieDBMutBuilder::<Layout>::new(&mut db, &mut root).build();
        for &(key, value) in TEST_DATA {
            trie.insert(key, value).expect("Inserts data");
        }
        // The trie mutably borrows `db` and `root`; it has to go before they are returned.
        drop(trie);

        (db, root)
    }

    /// Reads one key through a recorder, then checks that the drained proof alone
    /// is enough to read the same key again.
    #[test]
    fn recorder_works() {
        let (db, root) = create_trie();
        let (key, value) = TEST_DATA[0];
        let expected = value.to_vec();

        let recorder = Recorder::default();

        {
            // `trie_recorder` must be dropped before the recorder is drained below.
            let mut trie_recorder = recorder.as_trie_recorder(root);
            let trie = TrieDBBuilder::<Layout>::new(&db, &root)
                .with_recorder(&mut trie_recorder)
                .build();

            let read = trie.get(key).unwrap().unwrap();
            assert_eq!(expected, read);
        }

        let proof_db: MemoryDB = recorder.drain_storage_proof().into_memory_db();

        // Same lookup, this time backed only by the recorded nodes.
        let trie = TrieDBBuilder::<Layout>::new(&proof_db, &root).build();
        let read = trie.get(key).unwrap().unwrap();
        assert_eq!(expected, read);
    }
}