1use crate::raw_statement::RawStatement;
4use crate::{Connection, PrepFlags, Result, Statement};
5use hashlink::LruCache;
6use std::cell::RefCell;
7use std::ops::{Deref, DerefMut};
8use std::sync::Arc;
9
10impl Connection {
11 #[inline]
38 pub fn prepare_cached(&self, sql: &str) -> Result<CachedStatement<'_>> {
39 self.cache.get(self, sql)
40 }
41
42 #[inline]
48 pub fn set_prepared_statement_cache_capacity(&self, capacity: usize) {
49 self.cache.set_capacity(capacity);
50 }
51
52 #[inline]
54 pub fn flush_prepared_statement_cache(&self) {
55 self.cache.flush();
56 }
57}
58
/// Per-connection cache of prepared statements, keyed by trimmed SQL text.
///
/// The map sits behind a `RefCell` so the cache can be updated through the
/// `&Connection` / `&self` receivers used by every method in this module.
/// Values are raw statement handles that have been detached from their
/// `Statement<'conn>` wrapper while parked in the cache.
#[derive(Debug)]
pub struct StatementCache(RefCell<LruCache<Arc<str>, RawStatement>>);

// SAFETY: `RefCell` keeps this type `!Sync`, so it is never accessed from
// two threads at once. The manual impl is presumably required because
// `RawStatement` wraps a raw statement pointer and is therefore not `Send`.
// NOTE(review): soundness relies on the cache only ever being moved together
// with the `Connection` that owns these statements (it is reached through the
// connection's `cache` field). Confirm that `Connection` is `Send`, and that
// `StatementCache` is not exposed in a way that lets it move independently.
unsafe impl Send for StatementCache {}
64
/// A prepared statement checked out of a connection's statement cache.
///
/// Derefs to [`Statement`]. On drop, the statement is returned to the cache
/// it came from. Call [`CachedStatement::discard`] to drop it without
/// returning it.
pub struct CachedStatement<'conn> {
    // `Some` from construction onward. It becomes `None` only when `discard`
    // clears it (consuming `self`) or when `Drop` takes it, so it is never
    // observed as `None` through `Deref`/`DerefMut`.
    stmt: Option<Statement<'conn>>,
    // The cache the statement is handed back to on drop.
    cache: &'conn StatementCache,
}
74
75impl<'conn> Deref for CachedStatement<'conn> {
76 type Target = Statement<'conn>;
77
78 #[inline]
79 fn deref(&self) -> &Statement<'conn> {
80 self.stmt.as_ref().unwrap()
81 }
82}
83
84impl<'conn> DerefMut for CachedStatement<'conn> {
85 #[inline]
86 fn deref_mut(&mut self) -> &mut Statement<'conn> {
87 self.stmt.as_mut().unwrap()
88 }
89}
90
91impl Drop for CachedStatement<'_> {
92 #[inline]
93 fn drop(&mut self) {
94 if let Some(stmt) = self.stmt.take() {
95 self.cache.cache_stmt(unsafe { stmt.into_raw() });
96 }
97 }
98}
99
100impl CachedStatement<'_> {
101 #[inline]
102 fn new<'conn>(stmt: Statement<'conn>, cache: &'conn StatementCache) -> CachedStatement<'conn> {
103 CachedStatement {
104 stmt: Some(stmt),
105 cache,
106 }
107 }
108
109 #[inline]
112 pub fn discard(mut self) {
113 self.stmt = None;
114 }
115}
116
117impl StatementCache {
118 #[inline]
120 pub fn with_capacity(capacity: usize) -> Self {
121 Self(RefCell::new(LruCache::new(capacity)))
122 }
123
124 #[inline]
125 fn set_capacity(&self, capacity: usize) {
126 self.0.borrow_mut().set_capacity(capacity);
127 }
128
129 fn get<'conn>(
137 &'conn self,
138 conn: &'conn Connection,
139 sql: &str,
140 ) -> Result<CachedStatement<'conn>> {
141 let trimmed = sql.trim();
142 let mut cache = self.0.borrow_mut();
143 let stmt = match cache.remove(trimmed) {
144 Some(raw_stmt) => Ok(Statement::new(conn, raw_stmt)),
145 None => conn.prepare_with_flags(trimmed, PrepFlags::SQLITE_PREPARE_PERSISTENT),
146 };
147 stmt.map(|mut stmt| {
148 stmt.stmt.set_statement_cache_key(trimmed);
149 CachedStatement::new(stmt, self)
150 })
151 }
152
153 fn cache_stmt(&self, mut stmt: RawStatement) {
155 if stmt.is_null() {
156 return;
157 }
158 let mut cache = self.0.borrow_mut();
159 stmt.clear_bindings();
160 if let Some(sql) = stmt.statement_cache_key() {
161 cache.insert(sql, stmt);
162 } else {
163 debug_assert!(
164 false,
165 "bug in statement cache code, statement returned to cache that without key"
166 );
167 }
168 }
169
170 #[inline]
171 fn flush(&self) {
172 let mut cache = self.0.borrow_mut();
173 cache.clear();
174 }
175}
176
#[cfg(test)]
mod test {
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::StatementCache;
    use crate::{Connection, Result};
    use fallible_iterator::FallibleIterator;

    // Test-only introspection helpers for the map behind the cache.
    impl StatementCache {
        fn clear(&self) {
            self.0.borrow_mut().clear();
        }

        fn len(&self) -> usize {
            self.0.borrow().len()
        }

        fn capacity(&self) -> usize {
            self.0.borrow().capacity()
        }
    }

    /// Query shared by most tests. A fresh in-memory database reports 0.
    const SCHEMA_VERSION_SQL: &str = "PRAGMA schema_version";

    /// Checks `sql` out of the cache, checks that the cache is empty while
    /// the statement is checked out, and runs it expecting the value 0.
    /// The statement goes back to the cache when this returns.
    fn check_out_and_run(db: &Connection, sql: &str) -> Result<()> {
        let mut checked_out = db.prepare_cached(sql)?;
        assert_eq!(0, db.cache.len());
        assert_eq!(0, checked_out.query_row([], |r| r.get::<_, i64>(0))?);
        Ok(())
    }

    #[test]
    fn test_cache() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        let initial_capacity = cache.capacity();
        assert_eq!(0, cache.len());
        assert!(initial_capacity > 0);

        // The second round reuses the statement returned by the first.
        for _ in 0..2 {
            check_out_and_run(&db, SCHEMA_VERSION_SQL)?;
            assert_eq!(1, cache.len());
        }

        cache.clear();
        assert_eq!(0, cache.len());
        assert_eq!(initial_capacity, cache.capacity());
        Ok(())
    }

    #[test]
    fn test_set_capacity() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;

        check_out_and_run(&db, SCHEMA_VERSION_SQL)?;
        assert_eq!(1, cache.len());

        // A zero capacity empties the cache and keeps returned statements out.
        db.set_prepared_statement_cache_capacity(0);
        assert_eq!(0, cache.len());
        check_out_and_run(&db, SCHEMA_VERSION_SQL)?;
        assert_eq!(0, cache.len());

        // Growing the capacity again lets returned statements stay.
        db.set_prepared_statement_cache_capacity(8);
        check_out_and_run(&db, SCHEMA_VERSION_SQL)?;
        assert_eq!(1, cache.len());
        Ok(())
    }

    #[test]
    fn test_discard() -> Result<()> {
        let db = Connection::open_in_memory()?;

        let mut checked_out = db.prepare_cached(SCHEMA_VERSION_SQL)?;
        assert_eq!(0, db.cache.len());
        assert_eq!(0, checked_out.query_row([], |r| r.get::<_, i64>(0))?);
        // A discarded statement must not find its way back into the cache.
        checked_out.discard();

        assert_eq!(0, db.cache.len());
        Ok(())
    }

    #[test]
    fn test_ddl() -> Result<()> {
        let db = Connection::open_in_memory()?;
        db.execute_batch(
            r"
            CREATE TABLE foo (x INT);
            INSERT INTO foo VALUES (1);
            ",
        )?;

        let select_all = "SELECT * FROM foo";

        {
            let mut checked_out = db.prepare_cached(select_all)?;
            assert_eq!(Ok(Some(1i32)), checked_out.query([])?.map(|r| r.get(0)).next());
        }

        // After a schema change, the cached statement must see the new column.
        db.execute_batch(
            r"
            ALTER TABLE foo ADD COLUMN y INT;
            UPDATE foo SET y = 2;
            ",
        )?;

        {
            let mut checked_out = db.prepare_cached(select_all)?;
            assert_eq!(
                Ok(Some((1i32, 2i32))),
                checked_out.query([])?.map(|r| Ok((r.get(0)?, r.get(1)?))).next()
            );
        }
        Ok(())
    }

    #[test]
    fn test_connection_close() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        // The temporary returns its statement to the cache immediately.
        // Closing must still succeed while the statement is cached.
        conn.prepare_cached("SELECT * FROM sqlite_master;")?;

        conn.close().expect("connection not closed");
        Ok(())
    }

    #[test]
    fn test_cache_key() -> Result<()> {
        let db = Connection::open_in_memory()?;
        let cache = &db.cache;
        assert_eq!(0, cache.len());

        // Surrounding whitespace must not produce a second cache entry.
        let padded_sql = "PRAGMA schema_version; ";
        for _ in 0..2 {
            check_out_and_run(&db, padded_sql)?;
            assert_eq!(1, cache.len());
        }
        Ok(())
    }

    #[test]
    fn test_empty_stmt() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.prepare_cached("")?;
        Ok(())
    }
}