5 Commits 884136f843 ... 71ef1afed7

Author SHA1 Message Date
  myuan 71ef1afed7 添加比较时类型检查 2 years ago
  myuan 53dca97270 添加操作符类型 2 years ago
  myuan a09d510036 处理join下的字段检测 2 years ago
  myuan 7bc042c23d del 2 years ago
  myuan ea45b0a03e del 2 years ago
6 changed files with 194 additions and 53 deletions
  1. 3 2
      run_sql_parser_test.py
  2. 45 10
      src/checker.cpp
  3. 10 36
      src/parser.y
  4. 21 0
      src/utils.cpp
  5. 11 0
      src/utils.h
  6. 104 5
      tests_config.py

+ 3 - 2
run_sql_parser_test.py

@@ -107,8 +107,9 @@ async def on_checker_modified():
     try:
         await assert_checks()
     except Exception as e:
-        print(e)
-    print(datetime.now(), colored("all checker tests right!", "green"))
+        print(colored(e, "red"))
+    else:
+        print(datetime.now(), colored("all checker tests right!", "green"))
 
 
 async def restart():

+ 45 - 10
src/checker.cpp

@@ -36,15 +36,36 @@ void check_where_clause(const vector<string> table_names, const json& j) {
     auto right = j["right"];
     for (const auto curr_branch_name : {"left", "right"}) {
         auto curr_branch = j[curr_branch_name];
-        if (left.contains("value")) {
+        if (is_column(curr_branch)) {
             // 说明到叶节点了
-            auto col_def = tables.get_select_column_define(table_names, curr_branch);
+            auto col_def =
+                tables.get_select_column_define(table_names, curr_branch);
             if (!col_def.has_value()) {
                 throw std::runtime_error(
                     fmt::format("column `{}` not exists in `{}`\n",
                                 select_column_get_name(curr_branch),
                                 fmt::join(table_names, ", ")));
+            } else {
+                if (curr_branch_name == "right" &&
+                    op_type_of(j) == OpType::OP_COMPARE) {
+                    // 如果左右分支都是普通列, 并且当前节点操作符是比较,
+                    // 那么应当判断类型是否一致 如果能走到right,
+                    // 说明left也是普通列
+                    let curr_left_col_def =
+                        tables.get_select_column_define(table_names, left);
+                    if (curr_left_col_def.value().data_type !=
+                        col_def.value().data_type) {
+                        throw std::runtime_error(
+                            fmt::format("column `{}` type is `{}`, but `{}` "
+                                        "type is `{}`, cannot `{}` \n",
+                                        select_column_get_name(left),
+                                        curr_left_col_def.value().data_type,
+                                        select_column_get_name(curr_branch),
+                                        col_def.value().data_type, j["type"]));
+                    }
+                }
             }
+
         } else {
             check_where_clause(table_names, curr_branch);
         }
@@ -66,6 +87,21 @@ void check_select_table(const vector<string> table_names,
     }
 }
 
+// Validates the JOIN part of a select statement.
+// Reads the joined table from join_options["join_with"]["value"], throws
+// std::runtime_error when `tables` does not know it, then runs
+// check_where_clause on the "on" condition against the given tables plus the
+// joined one. Returns that extended list of table names.
+vector<string> check_join_table(const vector<string> table_names,
+                                const json& join_options) {
+    let joined_table = join_options["join_with"]["value"];
+    if (!tables.exists(joined_table)) {
+        throw std::runtime_error(
+            fmt::format("table `{}` not exists\n", joined_table));
+    }
+
+    // The ON condition is checked against the joined table as well.
+    vector<string> visible_tables(table_names);
+    visible_tables.push_back(joined_table);
+    check_where_clause(visible_tables, join_options["on"]);
+    return visible_tables;
+}
+
 void process_sql(json& j) {
     let type = j.value("type", "none");
 
@@ -79,14 +115,20 @@ void process_sql(json& j) {
         }
         return;
     } else if (type == "select_stmt") {
-        let table_names = table_names_of(j);
+        vector<string> table_names = table_names_of(j);
         for (let table_name : table_names) {
             if (!tables.exists(table_name)) {
                 fmt::print("error: table `{}` does not exist\n", table_name);
                 return;
             }
         }
+        if (j.contains("join_options")) {
+            auto new_tables = check_join_table(table_names, j["join_options"]);
+            table_names = new_tables;
+        }
+
         check_select_table(table_names, j["select_cols"]);
+
         check_where_clause(table_names, j["where"]);
     } else {
         throw std::runtime_error(fmt::format(
@@ -121,12 +163,5 @@ int main(int argc, char** argv) {
         process_sql(stmt);
     }
 
-    // auto t = res.out[0].value("type", "default");
-    // fmt::print("{}\n", t);
-
-    // tables.set("t1", {"a", "b", "c"});
-    // tables.set("t3", {"asdge", "safaw", "qwer"});
-    // fmt::print("tables.exists {}\n", tables.exists("t1"));
-    // tables.save();
     return 0;
 }

+ 10 - 36
src/parser.y

@@ -77,9 +77,8 @@ cJSON* jroot;
 %type <jv> create_definition create_col_list create_table_stmt data_value
 %type <jv> insert_stmt insert_list 
 %type <jv> update_stmt update_list single_assign_item
-%type <jv> where_condition_item identifier identifier_or_const_value 
+%type <jv> identifier identifier_or_const_value 
 %type <jv> delete_stmt select_stmt select_item select_items drop_stmt
-%type <jv> data_value_list identifier_or_const_value_or_const_value_list
 %type <jv> compare_expr single_expr where_expr logical_expr negative_expr op_where_expr expr_list contains_expr
 %type <jv> op_join table_field column_name
 %type <jv> table_name_list
@@ -252,8 +251,10 @@ op_where_expr: {$$=cJSON_CreateObject();}
 ;
 
 compare_expr: compare_expr bin_cmp_op compare_expr {
-	fprintf(stderr, "compare_expr %s\n", $2);
-	SIMPLE_OP_NODE($$, $2, $1, $3);}
+		fprintf(stderr, "compare_expr %s\n", $2);
+		SIMPLE_OP_NODE($$, $2, $1, $3);
+		cJSON_AddStringToObject($$, "op_type", "bin_cmp_op");
+	}
 	| single_expr {MEET_JSON(single_expr from compare_expr, $1); $$=$1;}
 	| '(' where_expr ')' {MEET(括号where_expr from compare_expr); $$=$2;}
 ;
@@ -265,11 +266,13 @@ negative_expr: NOT negative_expr {SIMPLE_OP_NODE_ONLY_LEFT($$, "非", $2);}
 
 contains_expr: identifier bin_contains_op '(' select_stmt ')' {
 		MEET_JSON(logical_expr bin_contains_op select_stmt, $2);
-		SIMPLE_OP_NODE($$, $2, $1, $4)
+		SIMPLE_OP_NODE($$, $2, $1, $4);
+		cJSON_AddStringToObject($$, "op_type", "bin_contains_op");
 	}
 	| identifier bin_contains_op '(' expr_list ')' {
 		MEET_VAR(logical_expr bin_contains_op expr_list, $2);
-		SIMPLE_OP_NODE($$, $2, $1, $4)
+		SIMPLE_OP_NODE($$, $2, $1, $4);
+		cJSON_AddStringToObject($$, "op_type", "bin_contains_op");
 	}
 	| negative_expr {MEET_JSON(negative_expr from contains_expr, $1); $$=$1;}
 ;
@@ -277,6 +280,7 @@ contains_expr: identifier bin_contains_op '(' select_stmt ')' {
 logical_expr: logical_expr bin_logical_op contains_expr {
 		fprintf(stderr, "logical_expr %s\n", $2);
 		SIMPLE_OP_NODE($$, $2, $1, $3);
+		cJSON_AddStringToObject($$, "op_type", "bin_logical_op");
 	}
 	| contains_expr {MEET_JSON(contains_expr from logical_expr, $1); $$=$1;}
 	| single_expr {MEET_JSON(single_expr from logical_expr, $1) $$=$1;}
@@ -314,36 +318,6 @@ identifier: IDENTIFIER {
 identifier_or_const_value: identifier {$$=$1;} | data_value {$$=$1;}
 ;
 
-data_value_list: data_value {
-		cJSON* node = cJSON_CreateArray();
-		cJSON_AddItemToArray(node, $1);
-		$$=node;
-	}
-	| data_value_list ',' data_value {
-		cJSON_AddItemToArray($1, $3);
-		$$=$1;
-	}
-;
-
-identifier_or_const_value_or_const_value_list: identifier_or_const_value{$$=$1;}
-	| data_value_list {$$=$1;}
-;
-
-where_condition_item: identifier_or_const_value bin_cmp_op identifier_or_const_value {
-		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", $2);
-		cJSON_AddItemToObject(node, "left", $1);
-		cJSON_AddItemToObject(node, "right", $3);
-		$$=node;
-	}
-	| identifier_or_const_value bin_contains_op '(' data_value_list ')' {
-		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", $2);
-		cJSON_AddItemToObject(node, "left", $1);
-		cJSON_AddItemToObject(node, "right", $4);
-		$$=node;
-	}
-;
 
 bin_cmp_op: '=' {$$ = "相等";}
 	| '>' {$$ = "大于";}

+ 21 - 0
src/utils.cpp

@@ -148,3 +148,24 @@ string select_column_get_name(const json& j) {
     throw std::runtime_error(
         fmt::format("无法将 json 视为 select_column: {}", j.dump(2)));
 }
+
+// Reports whether `j` is a leaf operand of an expression: a node whose "type"
+// is identifier, table_field, select_all_column, or one of the literal kinds
+// int / float / string. Every other node yields false.
+bool is_column(const json& j) {
+    // json::value() throws type_error on anything that is not an object (for
+    // example an expression list, which is represented as a JSON array). A
+    // predicate should answer "no" for such nodes instead of throwing.
+    if (!j.is_object()) {
+        return false;
+    }
+    let type = j.value("type", "none");
+    return type == "identifier" || type == "int" || type == "string" ||
+           type == "float" || type == "select_all_column" ||
+           type == "table_field";
+}
+
+// Maps the "op_type" tag of an expression node to its OpType category.
+// Known tags are bin_cmp_op, bin_logical_op and bin_contains_op; a missing or
+// unrecognised tag yields OpType::OP_UNKNOWN.
+OpType op_type_of(const nlohmann::json& j) {
+    let tag = j.value("op_type", "none");
+
+    if (tag == "bin_cmp_op") {
+        return OpType::OP_COMPARE;
+    }
+    if (tag == "bin_logical_op") {
+        return OpType::OP_LOGIC;
+    }
+    if (tag == "bin_contains_op") {
+        return OpType::OP_CONTAINS;
+    }
+    return OpType::OP_UNKNOWN;
+}

+ 11 - 0
src/utils.h

@@ -63,3 +63,14 @@ inline auto table_names_of(const nlohmann::json& j) {
 inline auto type_of(const nlohmann::json& j) { return j.value("type", "none"); }
 
 std::string select_column_get_name(const nlohmann::json& j);
+
+bool is_column(const nlohmann::json& j);  // true if j's "type" names a column reference or an int/float/string literal
+
+enum OpType {  // operator category, derived from a node's "op_type" tag by op_type_of
+    OP_COMPARE,  // tag "bin_cmp_op"
+    OP_LOGIC,  // tag "bin_logical_op"
+    OP_CONTAINS,  // tag "bin_contains_op"
+    OP_UNKNOWN,  // tag missing or not one of the above
+};
+
+OpType op_type_of(const nlohmann::json& j);  // classifies j by its "op_type" field; OP_UNKNOWN when absent or unrecognised

+ 104 - 5
tests_config.py

@@ -144,12 +144,15 @@ sql_parser_tests = [
                         "type": "相等",
                         "left": {"type": "identifier", "value": "col1"},
                         "right": {"type": "int", "value": 2},
+                        "op_type": "bin_cmp_op",
                     },
                     "right": {
                         "type": "相等",
                         "left": {"type": "identifier", "value": "col2"},
                         "right": {"type": "int", "value": 4},
+                        "op_type": "bin_cmp_op",
                     },
+                    "op_type": "bin_logical_op",
                 },
             }
         ],
@@ -186,6 +189,7 @@ sql_parser_tests = [
                                         "type": "相等",
                                         "left": {"type": "identifier", "value": "col1"},
                                         "right": {"type": "int", "value": 2},
+                                        "op_type": "bin_cmp_op",
                                     },
                                 },
                             },
@@ -194,13 +198,17 @@ sql_parser_tests = [
                             "type": "相等",
                             "left": {"type": "identifier", "value": "col2"},
                             "right": {"type": "int", "value": 4},
+                            "op_type": "bin_cmp_op",
                         },
+                        "op_type": "bin_logical_op",
                     },
                     "right": {
                         "type": "相等",
                         "left": {"type": "identifier", "value": "col3"},
                         "right": {"type": "identifier", "value": "col2"},
+                        "op_type": "bin_cmp_op",
                     },
+                    "op_type": "bin_logical_op",
                 },
             }
         ],
@@ -214,23 +222,28 @@ sql_parser_tests = [
                 "where": {
                     "type": "或",
                     "left": {
+                        "type": "且",
                         "left": {
                             "type": "相等",
                             "left": {"type": "identifier", "value": "c1"},
                             "right": {"type": "int", "value": 1},
+                            "op_type": "bin_cmp_op",
                         },
-                        "type": "且",
                         "right": {
                             "type": "相等",
                             "left": {"type": "identifier", "value": "c2"},
                             "right": {"type": "int", "value": 3},
+                            "op_type": "bin_cmp_op",
                         },
+                        "op_type": "bin_logical_op",
                     },
                     "right": {
                         "type": "相等",
                         "left": {"type": "identifier", "value": "c3"},
                         "right": {"type": "int", "value": 3},
+                        "op_type": "bin_cmp_op",
                     },
+                    "op_type": "bin_logical_op",
                 },
             }
         ],
@@ -249,6 +262,7 @@ sql_parser_tests = [
                             "type": "相等",
                             "left": {"type": "identifier", "value": "c1"},
                             "right": {"type": "int", "value": 1},
+                            "op_type": "bin_cmp_op",
                         },
                         "right": {
                             "type": "或",
@@ -256,26 +270,47 @@ sql_parser_tests = [
                                 "type": "相等",
                                 "left": {"type": "identifier", "value": "c2"},
                                 "right": {"type": "int", "value": 3},
+                                "op_type": "bin_cmp_op",
                             },
                             "right": {
                                 "type": "相等",
                                 "left": {"type": "identifier", "value": "c3"},
                                 "right": {"type": "int", "value": 3},
+                                "op_type": "bin_cmp_op",
                             },
+                            "op_type": "bin_logical_op",
                         },
+                        "op_type": "bin_logical_op",
                     },
                     "right": {
                         "type": "相等",
                         "left": {"type": "identifier", "value": "c4"},
                         "right": {"type": "string", "value": "'asd'"},
+                        "op_type": "bin_cmp_op",
                     },
+                    "op_type": "bin_logical_op",
                 },
             }
         ],
     ),
     (
         "select * from t2;",
-        [{'type': 'select_stmt', 'select_cols': [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}], 'table_names': ['t2'], 'where': {}}]
+        [
+            {
+                "type": "select_stmt",
+                "select_cols": [
+                    {
+                        "type": "select_all_column",
+                        "target": {
+                            "type": "select_all_column",
+                            "value": "select_all_column",
+                        },
+                    }
+                ],
+                "table_names": ["t2"],
+                "where": {},
+            }
+        ],
     ),
     (
         "select c2 as t from t2 where col1>2;",
@@ -294,6 +329,7 @@ sql_parser_tests = [
                     "type": "大于",
                     "left": {"type": "identifier", "value": "col1"},
                     "right": {"type": "int", "value": 2},
+                    "op_type": "bin_cmp_op",
                 },
             }
         ],
@@ -319,6 +355,7 @@ sql_parser_tests = [
                             {"type": "int", "value": 1},
                             {"type": "int", "value": 2},
                         ],
+                        "op_type": "bin_contains_op",
                     },
                     "right": {
                         "type": "包含于",
@@ -328,7 +365,9 @@ sql_parser_tests = [
                             {"type": "int", "value": 4},
                             {"type": "int", "value": 5},
                         ],
+                        "op_type": "bin_contains_op",
                     },
+                    "op_type": "bin_logical_op",
                 },
             }
         ],
@@ -376,8 +415,10 @@ sql_parser_tests = [
                                 "field": "Cno",
                             },
                             "right": {"type": "string", "value": "'81003'"},
+                            "op_type": "bin_cmp_op",
                         },
                     },
+                    "op_type": "bin_contains_op",
                 },
             }
         ],
@@ -414,12 +455,14 @@ sql_parser_tests = [
                             "field": "Sno",
                         },
                         "right": {"type": "table_field", "table": "SC", "field": "Sno"},
+                        "op_type": "bin_cmp_op",
                     },
                 },
                 "where": {
                     "type": "相等",
                     "left": {"type": "table_field", "table": "SC", "field": "Cno"},
                     "right": {"type": "string", "value": "'81003'"},
+                    "op_type": "bin_cmp_op",
                 },
             }
         ],
@@ -462,6 +505,7 @@ sql_parser_tests = [
                                 "table": "SC",
                                 "field": "Sno",
                             },
+                            "op_type": "bin_cmp_op",
                         },
                         "right": {
                             "type": "相等",
@@ -471,7 +515,9 @@ sql_parser_tests = [
                                 "field": "Cno",
                             },
                             "right": {"type": "string", "value": "'81003'"},
+                            "op_type": "bin_cmp_op",
                         },
+                        "op_type": "bin_logical_op",
                     },
                 },
                 "where": {},
@@ -484,7 +530,15 @@ sql_parser_tests = [
         [
             {
                 "type": "select_stmt",
-                "select_cols": [{'type': 'select_all_column', 'target': {'type': 'select_all_column', 'value': 'select_all_column'}}],
+                "select_cols": [
+                    {
+                        "type": "select_all_column",
+                        "target": {
+                            "type": "select_all_column",
+                            "value": "select_all_column",
+                        },
+                    }
+                ],
                 "table_names": ["tb1", "tb2"],
                 "where": {},
             }
@@ -497,14 +551,59 @@ sql_checker_tests = [
     ("create table person(name string, age int, classId int);", True),
     ("select age from person;", True),
     ("select * from person;", True),
-    ("select gender from person;", 'column `gender` not exists in `person`'),
+    ("select gender from person;", "column `gender` not exists in `person`"),
     ("select 123 from person;", True),
     ("drop table class;", True),
     ("create table class (id int, grade int, faculty string);", True),
     ("select * from class where grade = 2 and faculty = 'Computer Science';", True),
     (
         "select * from class where grade = 2 and count=33;",
-        'column `count` not exists in `class`',
+        "column `count` not exists in `class`",
     ),
     ("select age, class.grade from class, person;", True),
+    (
+        "select age, person.grade from class, person;",
+        "column `person.grade` not exists in `class, person`",
+    ),
+    (
+        """SELECT person.name, grade, faculty
+            FROM person
+            WHERE name = '张三' and classId IN (
+                SELECT *
+                FROM class
+                WHERE class.id = classId
+            );
+        """,
+        "column `grade` not exists in `person`",
+    ),
+    (
+        """select person.name 
+        from person join class 
+        on class.id=person.classId and person.grade=2
+        where age=22;
+        """,
+        "column `person.grade` not exists in `person, class`",
+    ),
+    (
+        """select person.name 
+        from person join class 
+        on class.id=person.classId and class.grade=2
+        where age=22;
+        """,
+        True,
+    ),
+    (
+        """select person.name 
+        from person join class 
+        on class.id=person.classId and class.grade = 'zxc'
+        where age=22;
+        """,
+        "column `class.grade` type is `int`, but `string:'zxc'` type is `string`, cannot `相等`",
+    ),
+    (
+        """select person.name 
+        from person where age>name;
+        """,
+        "column `age` type is `int`, but `name` type is `string`, cannot `大于`",
+    ),
 ]