1use std::ops::Deref;
4
5use crate::ffi;
6use crate::types::{ToSql, ToSqlOutput, ValueRef};
7use crate::{Connection, Result, Row};
8
/// Incrementally built SQL text, used in this module to assemble `PRAGMA`
/// statements.
///
/// `Default` yields the same empty buffer as [`Sql::new`].
#[derive(Debug, Default)]
pub struct Sql {
    /// SQL text accumulated so far.
    buf: String,
}
12
13impl Sql {
14 pub fn new() -> Self {
15 Self { buf: String::new() }
16 }
17
18 pub fn push_pragma(&mut self, schema_name: Option<&str>, pragma_name: &str) -> Result<()> {
19 self.push_keyword("PRAGMA")?;
20 self.push_space();
21 if let Some(schema_name) = schema_name {
22 self.push_schema_name(schema_name);
23 self.push_dot();
24 }
25 self.push_keyword(pragma_name)
26 }
27
28 pub fn push_keyword(&mut self, keyword: &str) -> Result<()> {
29 if !keyword.is_empty() && is_identifier(keyword) {
30 self.buf.push_str(keyword);
31 Ok(())
32 } else {
33 Err(err!(ffi::SQLITE_MISUSE, "Invalid keyword \"{keyword}\""))
34 }
35 }
36
37 pub fn push_schema_name(&mut self, schema_name: &str) {
38 self.push_identifier(schema_name);
39 }
40
41 pub fn push_identifier(&mut self, s: &str) {
42 if is_identifier(s) {
43 self.buf.push_str(s);
44 } else {
45 self.wrap_and_escape(s, '"');
46 }
47 }
48
49 pub fn push_value(&mut self, value: &dyn ToSql) -> Result<()> {
50 let value = value.to_sql()?;
51 let value = match value {
52 ToSqlOutput::Borrowed(v) => v,
53 ToSqlOutput::Owned(ref v) => ValueRef::from(v),
54 #[cfg(any(feature = "blob", feature = "functions", feature = "pointer"))]
55 _ => {
56 return Err(err!(ffi::SQLITE_MISUSE, "Unsupported value \"{value:?}\""));
57 }
58 };
59 match value {
60 ValueRef::Integer(i) => {
61 self.push_int(i);
62 }
63 ValueRef::Real(r) => {
64 self.push_real(r);
65 }
66 ValueRef::Text(s) => {
67 let s = std::str::from_utf8(s)?;
68 self.push_string_literal(s);
69 }
70 _ => {
71 return Err(err!(ffi::SQLITE_MISUSE, "Unsupported value \"{value:?}\""));
72 }
73 };
74 Ok(())
75 }
76
77 pub fn push_string_literal(&mut self, s: &str) {
78 self.wrap_and_escape(s, '\'');
79 }
80
81 pub fn push_int(&mut self, i: i64) {
82 self.buf.push_str(&i.to_string());
83 }
84
85 pub fn push_real(&mut self, f: f64) {
86 self.buf.push_str(&f.to_string());
87 }
88
89 pub fn push_space(&mut self) {
90 self.buf.push(' ');
91 }
92
93 pub fn push_dot(&mut self) {
94 self.buf.push('.');
95 }
96
97 pub fn push_equal_sign(&mut self) {
98 self.buf.push('=');
99 }
100
101 pub fn open_brace(&mut self) {
102 self.buf.push('(');
103 }
104
105 pub fn close_brace(&mut self) {
106 self.buf.push(')');
107 }
108
109 pub fn as_str(&self) -> &str {
110 &self.buf
111 }
112
113 fn wrap_and_escape(&mut self, s: &str, quote: char) {
114 self.buf.push(quote);
115 let chars = s.chars();
116 for ch in chars {
117 if ch == quote {
119 self.buf.push(ch);
120 }
121 self.buf.push(ch);
122 }
123 self.buf.push(quote);
124 }
125}
126
127impl Deref for Sql {
128 type Target = str;
129
130 fn deref(&self) -> &str {
131 self.as_str()
132 }
133}
134
135impl Connection {
136 pub fn pragma_query_value<T, F>(
144 &self,
145 schema_name: Option<&str>,
146 pragma_name: &str,
147 f: F,
148 ) -> Result<T>
149 where
150 F: FnOnce(&Row<'_>) -> Result<T>,
151 {
152 let mut query = Sql::new();
153 query.push_pragma(schema_name, pragma_name)?;
154 self.query_row(&query, [], f)
155 }
156
157 pub fn pragma_query<F>(
162 &self,
163 schema_name: Option<&str>,
164 pragma_name: &str,
165 mut f: F,
166 ) -> Result<()>
167 where
168 F: FnMut(&Row<'_>) -> Result<()>,
169 {
170 let mut query = Sql::new();
171 query.push_pragma(schema_name, pragma_name)?;
172 let mut stmt = self.prepare(&query)?;
173 let mut rows = stmt.query([])?;
174 while let Some(result_row) = rows.next()? {
175 let row = result_row;
176 f(row)?;
177 }
178 Ok(())
179 }
180
181 pub fn pragma<F, V>(
191 &self,
192 schema_name: Option<&str>,
193 pragma_name: &str,
194 pragma_value: V,
195 mut f: F,
196 ) -> Result<()>
197 where
198 F: FnMut(&Row<'_>) -> Result<()>,
199 V: ToSql,
200 {
201 let mut sql = Sql::new();
202 sql.push_pragma(schema_name, pragma_name)?;
203 sql.open_brace();
207 sql.push_value(&pragma_value)?;
208 sql.close_brace();
209 let mut stmt = self.prepare(&sql)?;
210 let mut rows = stmt.query([])?;
211 while let Some(result_row) = rows.next()? {
212 let row = result_row;
213 f(row)?;
214 }
215 Ok(())
216 }
217
218 pub fn pragma_update<V>(
223 &self,
224 schema_name: Option<&str>,
225 pragma_name: &str,
226 pragma_value: V,
227 ) -> Result<()>
228 where
229 V: ToSql,
230 {
231 let mut sql = Sql::new();
232 sql.push_pragma(schema_name, pragma_name)?;
233 sql.push_equal_sign();
237 sql.push_value(&pragma_value)?;
238 self.execute_batch(&sql)
239 }
240
241 pub fn pragma_update_and_check<F, T, V>(
245 &self,
246 schema_name: Option<&str>,
247 pragma_name: &str,
248 pragma_value: V,
249 f: F,
250 ) -> Result<T>
251 where
252 F: FnOnce(&Row<'_>) -> Result<T>,
253 V: ToSql,
254 {
255 let mut sql = Sql::new();
256 sql.push_pragma(schema_name, pragma_name)?;
257 sql.push_equal_sign();
261 sql.push_value(&pragma_value)?;
262 self.query_row(&sql, [], f)
263 }
264}
265
266fn is_identifier(s: &str) -> bool {
267 let chars = s.char_indices();
268 for (i, ch) in chars {
269 if i == 0 {
270 if !is_identifier_start(ch) {
271 return false;
272 }
273 } else if !is_identifier_continue(ch) {
274 return false;
275 }
276 }
277 true
278}
279
/// Returns `true` when `c` may begin an unquoted identifier: an ASCII letter,
/// an underscore, or any character above `'\x7F'`.
fn is_identifier_start(c: char) -> bool {
    matches!(c, 'A'..='Z' | 'a'..='z' | '_') || !c.is_ascii()
}
283
/// Returns `true` when `c` may appear after the first character of an unquoted
/// identifier: an ASCII letter or digit, `_`, `$`, or any character above
/// `'\x7F'`.
fn is_identifier_continue(c: char) -> bool {
    c.is_ascii_alphanumeric() || c == '_' || c == '$' || !c.is_ascii()
}
292
#[cfg(test)]
mod test {
    #[cfg(all(target_family = "wasm", target_os = "unknown"))]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    use super::Sql;
    use crate::pragma;
    use crate::{Connection, Result};

    /// `user_version` of a fresh in-memory database, read through `pragma_query_value`.
    #[test]
    fn pragma_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version = conn.pragma_query_value(None, "user_version", |row| row.get::<_, i32>(0))?;
        assert_eq!(0, version);
        Ok(())
    }

    /// Same value, read through the `pragma_user_version` table-valued function.
    #[test]
    fn pragma_func_query_value() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let version: i32 = conn.one_column("SELECT user_version FROM pragma_user_version", [])?;
        assert_eq!(0, version);
        Ok(())
    }

    /// `pragma_query` without a schema name invokes the callback with the row.
    #[test]
    fn pragma_query_no_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        // Start from a value the pragma cannot return so the callback is observable.
        let mut seen: i32 = -1;
        conn.pragma_query(None, "user_version", |row| {
            seen = row.get(0)?;
            Ok(())
        })?;
        assert_eq!(0, seen);
        Ok(())
    }

    /// `pragma_query` with an explicit schema name.
    #[test]
    fn pragma_query_with_schema() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut seen: i32 = -1;
        conn.pragma_query(Some("main"), "user_version", |row| {
            seen = row.get(0)?;
            Ok(())
        })?;
        assert_eq!(0, seen);
        Ok(())
    }

    /// `table_info` called with an argument lists the five columns of `sqlite_master`.
    #[test]
    fn pragma() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut names: Vec<String> = Vec::new();
        conn.pragma(None, "table_info", "sqlite_master", |row| {
            names.push(row.get(1)?);
            Ok(())
        })?;
        assert_eq!(5, names.len());
        Ok(())
    }

    /// Same listing through the `pragma_table_info` table-valued function.
    #[test]
    fn pragma_func() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        let mut stmt = conn.prepare("SELECT * FROM pragma_table_info(?1)")?;
        let mut names: Vec<String> = Vec::new();
        let mut rows = stmt.query(["sqlite_master"])?;

        while let Some(row) = rows.next()? {
            names.push(row.get(1)?);
        }
        assert_eq!(5, names.len());
        Ok(())
    }

    /// Updating an integer pragma succeeds.
    #[test]
    fn pragma_update() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.pragma_update(None, "user_version", 1)?;
        Ok(())
    }

    /// Updating `journal_mode` returns the resulting mode, for several ways of
    /// passing the value and typing the result.
    #[test]
    fn pragma_update_and_check() -> Result<()> {
        let conn = Connection::open_in_memory()?;

        // Result type fixed by the binding annotation.
        let journal_mode: String =
            conn.pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get(0))?;
        assert!(
            journal_mode == "off" || journal_mode == "memory",
            "mode: {journal_mode:?}"
        );

        // Result type fixed inside the closure.
        let mode = conn
            .pragma_update_and_check(None, "journal_mode", "OFF", |row| row.get::<_, String>(0))?;
        assert!(mode == "off" || mode == "memory", "mode: {mode:?}");

        // Value passed as a trait object.
        let param: &dyn crate::ToSql = &"OFF";
        let mode = conn
            .pragma_update_and_check(None, "journal_mode", param, |row| row.get::<_, String>(0))?;
        assert!(mode == "off" || mode == "memory", "mode: {mode:?}");
        Ok(())
    }

    /// Plain names are identifiers; names with a space or a semicolon are not.
    #[test]
    fn is_identifier() {
        for accepted in ["full", "r2d2"] {
            assert!(pragma::is_identifier(accepted));
        }
        for rejected in ["sp ce", "semi;colon"] {
            assert!(!pragma::is_identifier(rejected));
        }
    }

    /// A schema name containing a double quote is quoted and the quote doubled.
    #[test]
    fn double_quote() {
        let mut builder = Sql::new();
        builder.push_schema_name(r#"schema";--"#);
        assert_eq!(r#""schema"";--""#, builder.as_str());
    }

    /// A string literal containing a single quote has the quote doubled.
    #[test]
    fn wrap_and_escape() {
        let mut builder = Sql::new();
        builder.push_string_literal("value'; --");
        assert_eq!("'value''; --'", builder.as_str());
    }

    /// Updating a pragma with a text value succeeds.
    #[test]
    fn locking_mode() -> Result<()> {
        let conn = Connection::open_in_memory()?;
        conn.pragma_update(None, "locking_mode", "exclusive")
    }
}