[go: up one dir, main page]

Skip to main content

sql_docs/
ast.rs

1//! Parse SQL text into an AST (`sqlparser`) for downstream comment attachment.
2//!
3//! This module does not interpret semantics; it only produces an AST + file metadata.
4
5use std::path::{Path, PathBuf};
6
7use sqlparser::{
8    ast::Statement,
9    dialect::GenericDialect,
10    parser::{Parser, ParserError},
11};
12
13use crate::source::SqlSource;
14
15/// A single SQL file plus all [`Statement`].
16#[derive(Debug)]
17pub struct ParsedSqlFile {
18    file: SqlSource,
19    statements: Vec<Statement>,
20}
21
22impl ParsedSqlFile {
23    /// Parses a [`SqlSource`] into `sqlparser` [`Statement`] nodes.
24    ///
25    /// This is the AST layer used by the `comments` module to attach leading
26    /// comment spans to statements/columns.
27    ///
28    /// # Parameters
29    /// - `file`: the [`SqlSource`] to parse
30    ///
31    /// # Errors
32    /// - Returns [`ParserError`] if parsing fails
33    pub fn parse(file: SqlSource) -> Result<Self, ParserError> {
34        let dialect = GenericDialect {};
35        let statements = Parser::parse_sql(&dialect, file.content())?;
36        Ok(Self { file, statements })
37    }
38
39    /// Getter method for returning the [`SqlSource`]
40    #[must_use]
41    pub const fn file(&self) -> &SqlSource {
42        &self.file
43    }
44
45    /// Getter method for returning the current object's file's path
46    #[must_use]
47    pub fn path(&self) -> Option<&Path> {
48        self.file.path()
49    }
50
51    /// Getter that returns an [`PathBuf`] for the path rather than `&Path`
52    #[must_use]
53    pub fn path_into_path_buf(&self) -> Option<PathBuf> {
54        self.file.path_into_path_buf()
55    }
56
57    /// Getter for the file's content
58    #[must_use]
59    pub fn content(&self) -> &str {
60        self.file.content()
61    }
62
63    /// Getter method for returning the vector of all statements [`Statement`]
64    #[must_use]
65    pub fn statements(&self) -> &[Statement] {
66        &self.statements
67    }
68}
69
70/// Struct to contain the vector of parsed SQL files
71#[derive(Debug)]
72pub struct ParsedSqlFileSet {
73    files: Vec<ParsedSqlFile>,
74}
75
76impl ParsedSqlFileSet {
77    /// Method that parses a set of all members in a [`SqlSource`]
78    ///
79    /// # Parameters
80    /// - `set` the set of [`SqlSource`]
81    ///
82    /// # Errors
83    /// - [`ParserError`] is returned for any errors parsing
84    pub fn parse_all(set: Vec<SqlSource>) -> Result<Self, ParserError> {
85        let files = set.into_iter().map(ParsedSqlFile::parse).collect::<Result<Vec<_>, _>>()?;
86
87        Ok(Self { files })
88    }
89
90    /// Getter method for returning the vector of all [`ParsedSqlFile`]
91    #[must_use]
92    pub fn files(&self) -> &[ParsedSqlFile] {
93        &self.files
94    }
95}
96
#[cfg(test)]
mod tests {
    use std::{
        env, fs,
        path::{Path, PathBuf},
    };

    use super::*;
    use crate::source::SqlSource;

    /// Scratch directory under the system temp dir that is removed on drop,
    /// so cleanup also happens when an assertion panics mid-test.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates a fresh, empty directory named after `name`; the process id
        /// is appended so concurrent test runs do not share directories.
        fn new(name: &str) -> std::io::Result<Self> {
            let path = env::temp_dir().join(format!("{name}_{}", std::process::id()));
            // Best-effort: clear leftovers from an earlier interrupted run.
            let _ = fs::remove_dir_all(&path);
            fs::create_dir_all(&path)?;
            Ok(Self(path))
        }

        fn path(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a failure here must not mask the test result.
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn parsed_sql_file_parses_single_statement() -> Result<(), Box<dyn std::error::Error>> {
        let scratch = ScratchDir::new("parsed_sql_file_single_stmt_test")?;
        let file_path = scratch.path().join("one.sql");
        let sql = "CREATE TABLE users (id INTEGER PRIMARY KEY);";
        fs::write(&file_path, sql)?;

        let sql_file = SqlSource::from_path(&file_path)?;
        let parsed = ParsedSqlFile::parse(sql_file)?;

        assert_eq!(parsed.path(), Some(file_path.as_path()));
        assert_eq!(parsed.content(), sql);
        assert_eq!(parsed.statements().len(), 1);
        Ok(())
    }

    #[test]
    fn parsed_sql_file_rejects_invalid_sql() -> Result<(), Box<dyn std::error::Error>> {
        let scratch = ScratchDir::new("parsed_sql_file_invalid_sql_test")?;
        let file_path = scratch.path().join("bad.sql");
        // Truncated statement: the table name is missing.
        fs::write(&file_path, "CREATE TABLE")?;

        let sql_file = SqlSource::from_path(&file_path)?;
        let result = ParsedSqlFile::parse(sql_file);

        assert!(result.is_err(), "expected a parse error, got: {result:?}");
        Ok(())
    }

    #[test]
    fn parsed_sql_file_set_parses_multiple_files() -> Result<(), Box<dyn std::error::Error>> {
        let scratch = ScratchDir::new("parsed_sql_file_set_multi_test")?;
        let base = scratch.path().to_path_buf();
        let sub = base.join("subdir");
        fs::create_dir_all(&sub)?;
        fs::write(base.join("one.sql"), "CREATE TABLE users (id INTEGER PRIMARY KEY);")?;
        fs::write(sub.join("two.sql"), "CREATE TABLE posts (id INTEGER PRIMARY KEY);")?;

        let set = SqlSource::sql_sources(&base, &[])?;
        let parsed_set = ParsedSqlFileSet::parse_all(set)?;
        let existing_files = parsed_set.files();

        assert_eq!(existing_files.len(), 2);
        for parsed in existing_files {
            let statements = parsed.statements();
            assert_eq!(statements.len(), 1);
            let stmt = &statements[0];
            assert!(
                matches!(stmt, Statement::CreateTable { .. }),
                "expected CreateTable, got: {stmt:?}"
            );
        }
        Ok(())
    }
}