浏览代码

处理join下的字段检测

myuan 2 年之前
父节点
当前提交
a09d510036
共有 5 个文件被更改,包括 95 次插入10 次删除
  1. 3 2
      run_sql_parser_test.py
  2. 25 3
      src/checker.cpp
  3. 7 0
      src/utils.cpp
  4. 2 0
      src/utils.h
  5. 58 5
      tests_config.py

+ 3 - 2
run_sql_parser_test.py

@@ -107,8 +107,9 @@ async def on_checker_modified():
     try:
         await assert_checks()
     except Exception as e:
-        print(e)
-    print(datetime.now(), colored("all checker tests right!", "green"))
+        print(colored(e, "red"))
+    else:
+        print(datetime.now(), colored("all checker tests right!", "green"))
 
 
 async def restart():

+ 25 - 3
src/checker.cpp

@@ -36,9 +36,10 @@ void check_where_clause(const vector<string> table_names, const json& j) {
     auto right = j["right"];
     for (const auto curr_branch_name : {"left", "right"}) {
         auto curr_branch = j[curr_branch_name];
-        if (left.contains("value")) {
+        if (is_column(curr_branch)) {
             // 说明到叶节点了
-            auto col_def = tables.get_select_column_define(table_names, curr_branch);
+            auto col_def =
+                tables.get_select_column_define(table_names, curr_branch);
             if (!col_def.has_value()) {
                 throw std::runtime_error(
                     fmt::format("column `{}` not exists in `{}`\n",
@@ -66,6 +67,21 @@ void check_select_table(const vector<string> table_names,
     }
 }
 
+vector<string> check_join_table(const vector<string> table_names,
+                                const json& join_options) {
+    let join_table_name = join_options["join_with"]["value"];
+    auto new_table_names = vector(table_names);
+    new_table_names.push_back(join_table_name);
+
+    if (!tables.exists(join_table_name)) {
+        throw std::runtime_error(
+            fmt::format("table `{}` not exists\n", join_table_name));
+    }
+
+    check_where_clause(new_table_names, join_options["on"]);
+    return new_table_names;
+}
+
 void process_sql(json& j) {
     let type = j.value("type", "none");
 
@@ -79,14 +95,20 @@ void process_sql(json& j) {
         }
         return;
     } else if (type == "select_stmt") {
-        let table_names = table_names_of(j);
+        vector<string> table_names = table_names_of(j);
         for (let table_name : table_names) {
             if (!tables.exists(table_name)) {
                 fmt::print("error: table `{}` does not exist\n", table_name);
                 return;
             }
         }
+        if (j.contains("join_options")) {
+            auto new_tables = check_join_table(table_names, j["join_options"]);
+            table_names = new_tables;
+        }
+
         check_select_table(table_names, j["select_cols"]);
+
         check_where_clause(table_names, j["where"]);
     } else {
         throw std::runtime_error(fmt::format(

+ 7 - 0
src/utils.cpp

@@ -148,3 +148,10 @@ string select_column_get_name(const json& j) {
     throw std::runtime_error(
         fmt::format("无法将 json 视为 select_column: {}", j.dump(2)));
 }
+
+bool is_column(const json& j) {
+    let type = j.value("type", "none");
+    return type == "identifier" || type == "int" || type == "string" ||
+           type == "float" || type == "select_all_column" ||
+           type == "table_field";
+}

+ 2 - 0
src/utils.h

@@ -63,3 +63,5 @@ inline auto table_names_of(const nlohmann::json& j) {
 inline auto type_of(const nlohmann::json& j) { return j.value("type", "none"); }
 
 std::string select_column_get_name(const nlohmann::json& j);
+
+bool is_column(const nlohmann::json& j);

+ 58 - 5
tests_config.py

@@ -275,7 +275,22 @@ sql_parser_tests = [
     ),
     (
         "select * from t2;",
-        [{'type': 'select_stmt', 'select_cols': [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}], 'table_names': ['t2'], 'where': {}}]
+        [
+            {
+                "type": "select_stmt",
+                "select_cols": [
+                    {
+                        "type": "select_all_column",
+                        "target": {
+                            "type": "select_all_column",
+                            "value": "select_all_column",
+                        },
+                    }
+                ],
+                "table_names": ["t2"],
+                "where": {},
+            }
+        ],
     ),
     (
         "select c2 as t from t2 where col1>2;",
@@ -484,7 +499,15 @@ sql_parser_tests = [
         [
             {
                 "type": "select_stmt",
-                "select_cols": [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}],
+                "select_cols": [
+                    {
+                        "type": "select_all_column",
+                        "target": {
+                            "type": "select_all_column",
+                            "value": "select_all_column",
+                        },
+                    }
+                ],
                 "table_names": ["tb1", "tb2"],
                 "where": {},
             }
@@ -497,16 +520,46 @@ sql_checker_tests = [
     ("create table person(name string, age int, classId int);", True),
     ("select age from person;", True),
     ("select * from person;", True),
-    ("select gender from person;", 'column `gender` not exists in `person`'),
+    ("select gender from person;", "column `gender` not exists in `person`"),
     ("select 123 from person;", True),
     ("drop table class;", True),
     ("create table class (id int, grade int, faculty string);", True),
     ("select * from class where grade = 2 and faculty = 'Computer Science';", True),
     (
         "select * from class where grade = 2 and count=33;",
-        'column `count` not exists in `class`',
+        "column `count` not exists in `class`",
     ),
     ("select age, class.grade from class, person;", True),
-    ("select age, person.grade from class, person;", 'column `person.grade` not exists in `class, person`'),
+    (
+        "select age, person.grade from class, person;",
+        "column `person.grade` not exists in `class, person`",
+    ),
+    (
+        """SELECT person.name, grade, faculty
+            FROM person
+            WHERE name = '张三' and classId IN (
+                SELECT *
+                FROM class
+                WHERE class.id = classId
+            );
+        """,
+        'column `grade` not exists in `person`',
+    ),
+    (
+        """select person.name 
+        from person join class 
+        on class.id=person.classId and person.grade=2
+        where age=22;
+        """,
+        "column `person.grade` not exists in `person, class`",
+    ),
+        (
+        """select person.name 
+        from person join class 
+        on class.id=person.classId and class.grade=2
+        where age=22;
+        """,
+        True
+    ),
 
 ]