Browse Source

允许使用table.field

myuan 2 years ago
parent
commit
884136f843
6 changed files with 74 additions and 61 deletions
  1. 23 15
      run_sql_parser_test.py
  2. 5 27
      src/checker.cpp
  3. 7 5
      src/parser.y
  4. 32 4
      src/utils.cpp
  5. 5 1
      src/utils.h
  6. 2 9
      tests_config.py

+ 23 - 15
run_sql_parser_test.py

@@ -9,14 +9,16 @@ import os
 import tempfile
 import tests_config
 import importlib
+
 importlib.reload(tests_config)
 
-sql_parser_tests, sql_checker_tests = tests_config.sql_parser_tests, tests_config.sql_checker_tests
+sql_parser_tests, sql_checker_tests = (
+    tests_config.sql_parser_tests,
+    tests_config.sql_checker_tests,
+)
 
 
-async def run_and_output(
-    *args: str, timeout=10
-) -> tuple[bytes, bytes]:
+async def run_and_output(*args: str, timeout=10) -> tuple[bytes, bytes]:
     p = await subprocess.create_subprocess_exec(
         *args,
         stdout=subprocess.PIPE,
@@ -25,9 +27,10 @@ async def run_and_output(
     stdout, stderr = await asyncio.wait_for(p.communicate(), timeout=timeout)
     return stdout, stderr
 
+
 async def rebuild() -> bool:
-    print(datetime.now(), colored('rebuild...', "grey"))
-    stdout, _ = await run_and_output('xmake')
+    print(datetime.now(), colored("rebuild...", "grey"))
+    stdout, _ = await run_and_output("xmake")
     if b"error" in stdout:
         print(stdout.decode("utf-8"))
         print(datetime.now(), "-" * 40)
@@ -35,8 +38,9 @@ async def rebuild() -> bool:
     else:
         return True
 
+
 async def assert_sql(sql: str, expected: dict):
-    stdout, stderr = await run_and_output('xmake', 'run', "sql-parser", sql)
+    stdout, stderr = await run_and_output("xmake", "run", "sql-parser", sql)
 
     if b"error" in stdout:
         print(stdout.decode("utf-8"))
@@ -82,20 +86,22 @@ async def on_parser_modified():
 
 async def assert_checks():
     for sql, res in sql_checker_tests:
-        stdout, stderr = await run_and_output(
-            'xmake', 'run', "sql-checker", 
-            "-s", sql
-        )
+        stdout, stderr = await run_and_output("xmake", "run", "sql-checker", "-s", sql)
         print(sql, res)
         if res is True:
-            assert b'error' not in stdout, stdout.decode("utf-8")
-            assert b'error' not in stderr, stderr.decode('utf-8')
+            assert b"error" not in stdout, (
+                stdout.decode("utf-8") + "\n" + stderr.decode("utf-8")
+            )
+            assert b"error" not in stderr, (
+                stdout.decode("utf-8") + "\n" + stderr.decode("utf-8")
+            )
         elif isinstance(res, str):
-            res = res.encode('utf-8') 
+            res = res.encode("utf-8")
             assert res in stderr, stderr.decode("utf-8")
         else:
             assert False, f"{res} 不是合适的结果"
 
+
 async def on_checker_modified():
     print(datetime.now(), colored("run checker tests...", "yellow"))
     try:
@@ -118,7 +124,9 @@ async def watch_parser():
 
 
 async def watch_checker():
-    async for changes in awatch("./src/checker.cpp", "./src/checker.h", "./src/utils.h", "./src/utils.cpp"):
+    async for changes in awatch(
+        "./src/checker.cpp", "./src/checker.h", "./src/utils.h", "./src/utils.cpp"
+    ):
         if await rebuild():
             await on_checker_modified()
 

+ 5 - 27
src/checker.cpp

@@ -26,25 +26,6 @@ void create_table(const json& j) {
     tables.save();
 }
 
-optional<TableCol> get_select_column_define(const vector<string> table_names,
-                                            const json& select_col) {
-    auto type = type_of(select_col);
-    if (type == "identifier") {
-        return tables.find_col_in_tables(table_names, select_col["value"]);
-    } else if (type == "int" || type == "string" || type == "float") {
-        return TableCol{.column_name = select_col["value"].dump(),
-                        .type = "const",
-                        .data_type = type,
-                        .primary_key = false};
-    } else if (type == "select_all_column") {
-        return TableCol{.column_name = "*",
-                        .type = "select_all_column",
-                        .data_type = "",
-                        .primary_key = false};
-    }
-    return nullopt;
-}
-
 void check_where_clause(const vector<string> table_names, const json& j) {
     if (j.empty()) {
         return;
@@ -57,10 +38,10 @@ void check_where_clause(const vector<string> table_names, const json& j) {
         auto curr_branch = j[curr_branch_name];
         if (left.contains("value")) {
             // 说明到叶节点了
-            auto col_def = get_select_column_define(table_names, curr_branch);
+            auto col_def = tables.get_select_column_define(table_names, curr_branch);
             if (!col_def.has_value()) {
                 throw std::runtime_error(
-                    fmt::format("column `{}` not exists in `{}`\n", 
+                    fmt::format("column `{}` not exists in `{}`\n",
                                 select_column_get_name(curr_branch),
                                 fmt::join(table_names, ", ")));
             }
@@ -72,12 +53,9 @@ void check_where_clause(const vector<string> table_names, const json& j) {
 
 void check_select_table(const vector<string> table_names,
                         const json& select_cols) {
-    for (auto& select_col : select_cols) {
-        if (select_col["type"] == "select_all_column") {
-            continue;
-        }
-        auto col_def =
-            get_select_column_define(table_names, select_col["target"]);
+    for (let& select_col : select_cols) {
+        let col_def =
+            tables.get_select_column_define(table_names, select_col["target"]);
 
         if (!col_def.has_value()) {
             throw std::runtime_error(

+ 7 - 5
src/parser.y

@@ -386,16 +386,12 @@ table_name: IDENTIFIER {$$=$1;}
 ;
 
 table_name_list: table_name {
-		MEET_VAR(table_name, $1);
 		cJSON* node = cJSON_CreateArray();
 		cJSON_AddItemToArray(node, cJSON_CreateString($1));
 		$$=node;
 	}
 	| table_name_list ',' table_name {
-		MEET_VAR(table_name, $3);
-		MEET_VAR(table_name_list, $1);
-
-		cJSON_AddItemToArray($1, $3);
+		cJSON_AddItemToArray($1, cJSON_CreateString($3));
 		$$=$1;
 	}
 ;
@@ -404,6 +400,7 @@ select_stmt: SELECT select_items FROM table_name_list op_join op_where_expr {
 		cJSON* node = cJSON_CreateObject();
 		cJSON_AddStringToObject(node, "type", "select_stmt");
 		cJSON_AddItemToObject(node, "select_cols", $2);
+		MEET_JSON(table_names, $4);
 		cJSON_AddItemToObject(node, "table_names", $4);
 		if ($5 != NULL) {
 			cJSON_AddItemToObject(node, "join_options", $5);
@@ -440,6 +437,11 @@ select_item: single_expr {
 	| '*' {
 		cJSON* node = cJSON_CreateObject();
 		cJSON_AddStringToObject(node, "type", "select_all_column");
+
+		cJSON* node_select_all = cJSON_CreateObject();
+		cJSON_AddStringToObject(node_select_all, "type", "select_all_column");
+		cJSON_AddStringToObject(node_select_all, "value", "select_all_column");
+		cJSON_AddItemToObject(node, "target", node_select_all);
 		$$=node;
 	}
 ;

+ 32 - 4
src/utils.cpp

@@ -11,6 +11,7 @@
 
 using json = nlohmann::json;
 using std::string;
+using std::vector;
 
 SQLParserRes parse_sql(const std::string& sql) {
     SQLParserRes res;
@@ -81,11 +82,10 @@ TableCols ExistTables::get(const std::vector<std::string>& table_names) {
 }
 
 optional<TableCol> ExistTables::find_col_in_tables(
-        const std::vector<std::string>& table_names,
-        const std::string col_name) {
+    const std::vector<std::string>& table_names, const std::string col_name) {
     for (let table_name : table_names) {
         let cols = this->get(table_name);
-        for(let col : cols) {
+        for (let col : cols) {
             if (col.column_name == col_name) {
                 return col;
             }
@@ -106,7 +106,35 @@ void ExistTables::save() {
     f << j.dump(2);
 }
 
-string select_column_get_name(const json &j) {
+optional<TableCol> ExistTables::get_select_column_define(
+    const vector<string> table_names, const json& select_col) {
+    auto type = type_of(select_col);
+    if (type == "identifier") {
+        return this->find_col_in_tables(table_names, select_col["value"]);
+    } else if (type == "int" || type == "string" || type == "float") {
+        return TableCol{.column_name = select_col["value"].dump(),
+                        .type = "const",
+                        .data_type = type,
+                        .primary_key = false};
+    } else if (type == "select_all_column") {
+        return TableCol{.column_name = "*",
+                        .type = "select_all_column",
+                        .data_type = "",
+                        .primary_key = false};
+    } else if (type == "table_field") {
+        let curr_table_name = select_col["table"];
+        let curr_field_name = select_col["field"];
+
+        if (std::find(table_names.begin(), table_names.end(),
+                      curr_table_name) == table_names.end()) {
+            return nullopt;
+        }
+        return this->find_col_in_tables({curr_table_name}, curr_field_name);
+    }
+    return nullopt;
+}
+
+string select_column_get_name(const json& j) {
     let type = j["type"];
     if (type == "identifier") {
         return j["value"];

+ 5 - 1
src/utils.h

@@ -47,6 +47,10 @@ class ExistTables {
         const std::string col_name);
     void remove(const std::string& table_name);
     void save();
+
+    optional<TableCol> get_select_column_define(
+        const std::vector<std::string> table_names,
+        const nlohmann::json& select_col);
 };
 
 inline auto table_name_of(const nlohmann::json& j) {
@@ -58,4 +62,4 @@ inline auto table_names_of(const nlohmann::json& j) {
 }
 inline auto type_of(const nlohmann::json& j) { return j.value("type", "none"); }
 
-std::string select_column_get_name(const nlohmann::json &j);
+std::string select_column_get_name(const nlohmann::json& j);

+ 2 - 9
tests_config.py

@@ -275,14 +275,7 @@ sql_parser_tests = [
     ),
     (
         "select * from t2;",
-        [
-            {
-                "type": "select_stmt",
-                "select_cols": [{"type": "select_all_column"}],
-                "table_names": ["t2"],
-                "where": {},
-            }
-        ],
+        [{'type': 'select_stmt', 'select_cols': [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}], 'table_names': ['t2'], 'where': {}}]
     ),
     (
         "select c2 as t from t2 where col1>2;",
@@ -491,7 +484,7 @@ sql_parser_tests = [
         [
             {
                 "type": "select_stmt",
-                "select_cols": [{"type": "select_all_column"}],
+                "select_cols": [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}],
                 "table_names": ["tb1", "tb2"],
                 "where": {},
             }