4 次代碼提交 4dd66e21f8 ... 884136f843

作者 SHA1 備註 提交日期
  myuan 884136f843 允许使用table.field 2 年之前
  myuan a3c9021347 update 2 年之前
  myuan e319e4ed95 优化错误输出 2 年之前
  myuan 7b2ba25f97 将select相关表检查改为vector 2 年之前
共有 6 個文件被更改,包括 185 次插入和 106 次删除
  1. 23 15
      run_sql_parser_test.py
  2. 34 62
      src/checker.cpp
  3. 7 5
      src/parser.y
  4. 74 0
      src/utils.cpp
  5. 19 2
      src/utils.h
  6. 28 22
      tests_config.py

+ 23 - 15
run_sql_parser_test.py

@@ -9,14 +9,16 @@ import os
 import tempfile
 import tests_config
 import importlib
+
 importlib.reload(tests_config)
 
-sql_parser_tests, sql_checker_tests = tests_config.sql_parser_tests, tests_config.sql_checker_tests
+sql_parser_tests, sql_checker_tests = (
+    tests_config.sql_parser_tests,
+    tests_config.sql_checker_tests,
+)
 
 
-async def run_and_output(
-    *args: str, timeout=10
-) -> tuple[bytes, bytes]:
+async def run_and_output(*args: str, timeout=10) -> tuple[bytes, bytes]:
     p = await subprocess.create_subprocess_exec(
         *args,
         stdout=subprocess.PIPE,
@@ -25,9 +27,10 @@ async def run_and_output(
     stdout, stderr = await asyncio.wait_for(p.communicate(), timeout=timeout)
     return stdout, stderr
 
+
 async def rebuild() -> bool:
-    print(datetime.now(), colored('rebuild...', "grey"))
-    stdout, _ = await run_and_output('xmake')
+    print(datetime.now(), colored("rebuild...", "grey"))
+    stdout, _ = await run_and_output("xmake")
     if b"error" in stdout:
         print(stdout.decode("utf-8"))
         print(datetime.now(), "-" * 40)
@@ -35,8 +38,9 @@ async def rebuild() -> bool:
     else:
         return True
 
+
 async def assert_sql(sql: str, expected: dict):
-    stdout, stderr = await run_and_output('xmake', 'run', "sql-parser", sql)
+    stdout, stderr = await run_and_output("xmake", "run", "sql-parser", sql)
 
     if b"error" in stdout:
         print(stdout.decode("utf-8"))
@@ -82,20 +86,22 @@ async def on_parser_modified():
 
 async def assert_checks():
     for sql, res in sql_checker_tests:
-        stdout, stderr = await run_and_output(
-            'xmake', 'run', "sql-checker", 
-            "-s", sql
-        )
+        stdout, stderr = await run_and_output("xmake", "run", "sql-checker", "-s", sql)
         print(sql, res)
         if res is True:
-            assert b'error' not in stdout, stdout.decode("utf-8")
-            assert b'error' not in stderr, stderr.decode('utf-8')
+            assert b"error" not in stdout, (
+                stdout.decode("utf-8") + "\n" + stderr.decode("utf-8")
+            )
+            assert b"error" not in stderr, (
+                stdout.decode("utf-8") + "\n" + stderr.decode("utf-8")
+            )
         elif isinstance(res, str):
-            res = res.encode('utf-8') 
+            res = res.encode("utf-8")
             assert res in stderr, stderr.decode("utf-8")
         else:
             assert False, f"{res} 不是合适的结果"
 
+
 async def on_checker_modified():
     print(datetime.now(), colored("run checker tests...", "yellow"))
     try:
@@ -118,7 +124,9 @@ async def watch_parser():
 
 
 async def watch_checker():
-    async for changes in awatch("./src/checker.cpp", "./src/checker.h", "./src/utils.h", "./src/utils.cpp"):
+    async for changes in awatch(
+        "./src/checker.cpp", "./src/checker.h", "./src/utils.h", "./src/utils.cpp"
+    ):
         if await rebuild():
             await on_checker_modified()
 

+ 34 - 62
src/checker.cpp

@@ -1,4 +1,5 @@
 #include <fmt/core.h>
+#include <fmt/ranges.h>
 #include <stdio.h>
 
 #include <CLI/CLI.hpp>
@@ -10,6 +11,7 @@ using json = nlohmann::json;
 using std::nullopt;
 using std::optional;
 using std::string;
+using std::vector;
 
 auto tables = ExistTables("./tables.json");
 
@@ -24,33 +26,7 @@ void create_table(const json& j) {
     tables.save();
 }
 
-optional<TableCol> get_select_column_define(const string table_name,
-                                           const json& select_col) {
-    auto type = type_of(select_col);
-    if (type == "identifier") {
-        auto cols = tables[table_name];
-        string colname = select_col["value"];
-        for (auto& col : cols) {
-            if (col.column_name == colname) {
-                return col;
-            }
-        }
-        return nullopt;
-    } else if (type == "int" || type == "string" || type == "float") {
-        return TableCol{.column_name = select_col["value"].dump(),
-                        .data_type = type,
-                        .type = "const",
-                        .primary_key = false};
-    } else if (type == "select_all_column") {
-        return TableCol{.column_name = "*",
-                        .data_type = "",
-                        .type = "select_all_column",
-                        .primary_key = false};
-    }
-    return nullopt;
-}
-
-void check_where_clause(const string table_name, const json& j) {
+void check_where_clause(const vector<string> table_names, const json& j) {
     if (j.empty()) {
         return;
     }
@@ -58,34 +34,34 @@ void check_where_clause(const string table_name, const json& j) {
     auto type = type_of(j);
     auto left = j["left"];
     auto right = j["right"];
-    for(const auto curr_branch_name : {"left", "right"}) {
+    for (const auto curr_branch_name : {"left", "right"}) {
         auto curr_branch = j[curr_branch_name];
         if (left.contains("value")) {
             // 说明到叶节点了
-            auto col_def = get_select_column_define(table_name, curr_branch);
+            auto col_def = tables.get_select_column_define(table_names, curr_branch);
             if (!col_def.has_value()) {
-                throw std::runtime_error(fmt::format(
-                    "column `{}` not exists in `{}`\n",
-                    curr_branch["value"].dump(2), table_name));
+                throw std::runtime_error(
+                    fmt::format("column `{}` not exists in `{}`\n",
+                                select_column_get_name(curr_branch),
+                                fmt::join(table_names, ", ")));
             }
         } else {
-            check_where_clause(table_name, curr_branch);
+            check_where_clause(table_names, curr_branch);
         }
     }
 }
 
-void check_select_table(const string table_name, const json& select_cols) {
-    for (auto& select_col : select_cols) {
-        if (select_col["type"] == "select_all_column") {
-            continue;
-        }
-        auto col_def =
-            get_select_column_define(table_name, select_col["target"]);
+void check_select_table(const vector<string> table_names,
+                        const json& select_cols) {
+    for (let& select_col : select_cols) {
+        let col_def =
+            tables.get_select_column_define(table_names, select_col["target"]);
 
         if (!col_def.has_value()) {
             throw std::runtime_error(
                 fmt::format("column `{}` not exists in `{}`\n",
-                            select_col["target"]["value"].dump(2), table_name));
+                            select_column_get_name(select_col["target"]),
+                            fmt::join(table_names, ", ")));
         }
     }
 }
@@ -96,29 +72,25 @@ void process_sql(json& j) {
     if (type == "create_stmt") {
         // 创建表
         create_table(j);
-    } else {
-        if (type == "drop_stmt") {
-            if (tables.exists(j["table_name"])) {
-                tables.remove(j["table_name"]);
-                tables.save();
-            }
-            return;
+    } else if (type == "drop_stmt") {
+        if (tables.exists(j["table_name"])) {
+            tables.remove(j["table_name"]);
+            tables.save();
         }
-
-        // 其余都会要求表存在
-        let table_name = table_name_of(j);
-        if (!tables.exists(table_name)) {
-            fmt::print("error: table `{}` does not exist\n", table_name);
-            return;
-        }
-
-        if (type == "select_stmt") {
-            check_select_table(table_name, j["select_cols"]);
-            check_where_clause(table_name, j["where"]);
-        } else {
-            throw std::runtime_error(fmt::format(
-                "Unknown expression type `{}` total: \n{}", type, j.dump(2)));
+        return;
+    } else if (type == "select_stmt") {
+        let table_names = table_names_of(j);
+        for (let table_name : table_names) {
+            if (!tables.exists(table_name)) {
+                fmt::print("error: table `{}` does not exist\n", table_name);
+                return;
+            }
         }
+        check_select_table(table_names, j["select_cols"]);
+        check_where_clause(table_names, j["where"]);
+    } else {
+        throw std::runtime_error(fmt::format(
+            "Unknown expression type `{}` total: \n{}", type, j.dump(2)));
     }
 }
 

+ 7 - 5
src/parser.y

@@ -386,16 +386,12 @@ table_name: IDENTIFIER {$$=$1;}
 ;
 
 table_name_list: table_name {
-		MEET_VAR(table_name, $1);
 		cJSON* node = cJSON_CreateArray();
 		cJSON_AddItemToArray(node, cJSON_CreateString($1));
 		$$=node;
 	}
 	| table_name_list ',' table_name {
-		MEET_VAR(table_name, $3);
-		MEET_VAR(table_name_list, $1);
-
-		cJSON_AddItemToArray($1, $3);
+		cJSON_AddItemToArray($1, cJSON_CreateString($3));
 		$$=$1;
 	}
 ;
@@ -404,6 +400,7 @@ select_stmt: SELECT select_items FROM table_name_list op_join op_where_expr {
 		cJSON* node = cJSON_CreateObject();
 		cJSON_AddStringToObject(node, "type", "select_stmt");
 		cJSON_AddItemToObject(node, "select_cols", $2);
+		MEET_JSON(table_names, $4);
 		cJSON_AddItemToObject(node, "table_names", $4);
 		if ($5 != NULL) {
 			cJSON_AddItemToObject(node, "join_options", $5);
@@ -440,6 +437,11 @@ select_item: single_expr {
 	| '*' {
 		cJSON* node = cJSON_CreateObject();
 		cJSON_AddStringToObject(node, "type", "select_all_column");
+
+		cJSON* node_select_all = cJSON_CreateObject();
+		cJSON_AddStringToObject(node_select_all, "type", "select_all_column");
+		cJSON_AddStringToObject(node_select_all, "value", "select_all_column");
+		cJSON_AddItemToObject(node, "target", node_select_all);
 		$$=node;
 	}
 ;

+ 74 - 0
src/utils.cpp

@@ -2,6 +2,7 @@
 
 #include <fmt/core.h>
 
+#include <algorithm>
 #include <fstream>
 #include <map>
 #include <nlohmann/json.hpp>
@@ -9,6 +10,8 @@
 #include <vector>
 
 using json = nlohmann::json;
+using std::string;
+using std::vector;
 
 SQLParserRes parse_sql(const std::string& sql) {
     SQLParserRes res;
@@ -61,8 +64,36 @@ void ExistTables::set(const std::string& table_name, const TableCols& cols) {
     tables[table_name] = cols;
 }
 TableCols ExistTables::operator[](const std::string& table_name) {
+    return this->get(table_name);
+}
+// Subscript access for a list of tables; forwards to get(table_names).
+TableCols ExistTables::operator[](const std::vector<std::string>& table_names) {
+    return get(table_names);
+}
+TableCols ExistTables::get(const std::string& table_name) {
     return tables[table_name];
 }
+// Collects the column definitions of every table named in `table_names`,
+// concatenated in the order the tables are listed.
+TableCols ExistTables::get(const std::vector<std::string>& table_names) {
+    TableCols res = {};
+    for (const auto& table_name : table_names) {
+        let cols = this->get(table_name);
+        // Append instead of assign(): assign() replaces the whole vector, so
+        // only the last table's columns would survive the loop.
+        res.insert(res.end(), cols.begin(), cols.end());
+    }
+    return res;
+}
+
+// Searches the listed tables in order and returns the first column whose
+// name equals `col_name`; returns nullopt when none of them has it.
+optional<TableCol> ExistTables::find_col_in_tables(
+    const std::vector<std::string>& table_names, const std::string col_name) {
+    // Iterate by const reference: the by-value loops copied every table name
+    // and every TableCol on each iteration.
+    for (const auto& table_name : table_names) {
+        let cols = this->get(table_name);
+        for (const auto& col : cols) {
+            if (col.column_name == col_name) {
+                return col;
+            }
+        }
+    }
+    return nullopt;
+}
+
 void ExistTables::remove(const std::string& table_name) {
     tables.erase(table_name);
 }
@@ -74,3 +105,46 @@ void ExistTables::save() {
     std::ofstream f(table_file_name);
     f << j.dump(2);
 }
+
+// Resolves one select-column JSON node to a column definition, looking only
+// at the tables named in `table_names`. The node's "type" member selects the
+// handling; any type not listed below yields nullopt.
+optional<TableCol> ExistTables::get_select_column_define(
+    const vector<string> table_names, const json& select_col) {
+    auto type = type_of(select_col);
+    if (type == "identifier") {
+        // Bare column name: first match across all listed tables.
+        return this->find_col_in_tables(table_names, select_col["value"]);
+    } else if (type == "int" || type == "string" || type == "float") {
+        // Literal: synthesize a "const" column named after the dumped JSON
+        // value, so it always resolves.
+        return TableCol{.column_name = select_col["value"].dump(),
+                        .type = "const",
+                        .data_type = type,
+                        .primary_key = false};
+    } else if (type == "select_all_column") {
+        // `*`: synthetic column, no table lookup is performed.
+        return TableCol{.column_name = "*",
+                        .type = "select_all_column",
+                        .data_type = "",
+                        .primary_key = false};
+    } else if (type == "table_field") {
+        let curr_table_name = select_col["table"];
+        let curr_field_name = select_col["field"];
+
+        // The qualifying table must be one of the listed tables.
+        if (std::find(table_names.begin(), table_names.end(),
+                      curr_table_name) == table_names.end()) {
+            return nullopt;
+        }
+        // Look the field up in that single table only.
+        return this->find_col_in_tables({curr_table_name}, curr_field_name);
+    }
+    return nullopt;
+}
+
+// Builds a printable name for a select-column JSON node, chosen by the
+// node's "type" member. Throws std::runtime_error for any other type.
+// NOTE(review): `type` and the j[...] operands are json values handed
+// straight to fmt::format; how they render (e.g. whether strings keep their
+// quotes) depends on the json formatter in use — confirm.
+string select_column_get_name(const json& j) {
+    let type = j["type"];
+    if (type == "identifier") {
+        // Plain column: its "value" member.
+        return j["value"];
+    } else if (type == "select_all_column") {
+        return "select_all_column";
+    } else if (type == "int" || type == "string" || type == "float") {
+        // Literal: "<type>:<value>".
+        return fmt::format("{}:{}", type, j["value"]);
+    } else if (type == "table_field") {
+        // Qualified column: "<table>.<field>".
+        return fmt::format("{}.{}", j["table"], j["field"]);
+    }
+    throw std::runtime_error(
+        fmt::format("无法将 json 视为 select_column: {}", j.dump(2)));
+}

+ 19 - 2
src/utils.h

@@ -2,6 +2,7 @@
 
 #include <map>
 #include <nlohmann/json.hpp>
+#include <optional>
 #include <vector>
 
 #define let const auto
@@ -22,6 +23,8 @@ struct TableCol {
 };
 
 using TableCols = std::vector<TableCol>;
+using std::nullopt;
+using std::optional;
 
 SQLParserRes parse_sql(const std::string& sql);
 
@@ -35,14 +38,28 @@ class ExistTables {
     bool exists(const std::string& table_name);
     void set(const std::string& table_name, const TableCols& cols);
     TableCols operator[](const std::string& table_name);
+    TableCols operator[](const std::vector<std::string>& table_names);
+    TableCols get(const std::string& table_name);
+    TableCols get(const std::vector<std::string>& table_names);
+
+    optional<TableCol> find_col_in_tables(
+        const std::vector<std::string>& table_names,
+        const std::string col_name);
     void remove(const std::string& table_name);
     void save();
+
+    optional<TableCol> get_select_column_define(
+        const std::vector<std::string> table_names,
+        const nlohmann::json& select_col);
 };
 
 inline auto table_name_of(const nlohmann::json& j) {
     return j.value("table_name", "none");
 }
-inline auto type_of(const nlohmann::json& j) {
-    return j.value("type", "none");
+inline auto table_names_of(const nlohmann::json& j) {
+    // return j.value("table_names", json::array());
+    return j["table_names"];
 }
+// Returns the node's "type" member, or "none" when that key is absent.
+inline auto type_of(const nlohmann::json& j) { return j.value("type", "none"); }
 
+std::string select_column_get_name(const nlohmann::json& j);

+ 28 - 22
tests_config.py

@@ -275,14 +275,7 @@ sql_parser_tests = [
     ),
     (
         "select * from t2;",
-        [
-            {
-                "type": "select_stmt",
-                "select_cols": [{"type": "select_all_column"}],
-                "table_names": ["t2"],
-                "where": {},
-            }
-        ],
+        [{'type': 'select_stmt', 'select_cols': [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}], 'table_names': ['t2'], 'where': {}}]
     ),
     (
         "select c2 as t from t2 where col1>2;",
@@ -485,20 +478,33 @@ sql_parser_tests = [
             }
         ],
     ),
-    ('drop table t1;', [{"type": "drop_stmt", "table_name": "t1"}]),
+    ("drop table t1;", [{"type": "drop_stmt", "table_name": "t1"}]),
+    (
+        "select * from tb1, tb2;",
+        [
+            {
+                "type": "select_stmt",
+                "select_cols": [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}],
+                "table_names": ["tb1", "tb2"],
+                "where": {},
+            }
+        ],
+    ),
 ]
 
 sql_checker_tests = [
-    ('drop table person;', True),
-    ('create table person(name string, age int, classId int);', True),
-    ('select age from person;', True),
-    ('select * from person;', True),
-    ('select gender from person;', 'column `"gender"` not exists in `person`'),
-    ('select 123 from person;', True),
-
-    ('drop table class;', True),
-    ('create table class (id int, grade int, faculty string);', True),
-    ('select * from class where grade = 2 and faculty = \'Computer Science\';', True),
-    ('select * from class where grade = 2 and count=33;', 'column `"count"` not exists in `class`'),
-
-]
+    ("drop table person;", True),
+    ("create table person(name string, age int, classId int);", True),
+    ("select age from person;", True),
+    ("select * from person;", True),
+    ("select gender from person;", 'column `gender` not exists in `person`'),
+    ("select 123 from person;", True),
+    ("drop table class;", True),
+    ("create table class (id int, grade int, faculty string);", True),
+    ("select * from class where grade = 2 and faculty = 'Computer Science';", True),
+    (
+        "select * from class where grade = 2 and count=33;",
+        'column `count` not exists in `class`',
+    ),
+    ("select age, class.grade from class, person;", True),
+]