Browse Source

check where

myuan 2 years ago
parent
commit
3d25c98cd2
3 changed files with 40 additions and 11 deletions
  1. 33 10
      src/checker.cpp
  2. 2 0
      src/utils.h
  3. 5 1
      tests_config.py

+ 33 - 10
src/checker.cpp

@@ -24,8 +24,8 @@ void create_table(const json& j) {
     tables.save();
 }
 
-optional<TableCol> get_column_define(const string table_name,
-                                     const json& select_col) {
+optional<TableCol> get_select_column_define(const string table_name,
+                                           const json& select_col) {
     auto type = type_of(select_col);
     if (type == "identifier") {
         auto cols = tables[table_name];
@@ -50,26 +50,48 @@ optional<TableCol> get_column_define(const string table_name,
     return nullopt;
 }
 
-void check_select_table(const json& j) {
-    auto table_name = table_name_of(j);
-    auto select_cols = j["select_cols"];
+void check_where_clause(const string table_name, const json& j) {
+    if (j.empty()) {
+        return;
+    }
+
+    auto type = type_of(j);
+    auto left = j["left"];
+    auto right = j["right"];
+    for(const auto curr_branch_name : {"left", "right"}) {
+        auto curr_branch = j[curr_branch_name];
+        if (left.contains("value")) {
+            // 说明到叶节点了
+            auto col_def = get_select_column_define(table_name, curr_branch);
+            if (!col_def.has_value()) {
+                throw std::runtime_error(fmt::format(
+                    "column `{}` not exists in `{}`\n",
+                    curr_branch["value"].dump(2), table_name));
+            }
+        } else {
+            check_where_clause(table_name, curr_branch);
+        }
+    }
+}
 
+void check_select_table(const string table_name, const json& select_cols) {
     for (auto& select_col : select_cols) {
         if (select_col["type"] == "select_all_column") {
             continue;
         }
-        auto col_def = get_column_define(table_name, select_col["target"]);
+        auto col_def =
+            get_select_column_define(table_name, select_col["target"]);
 
         if (!col_def.has_value()) {
             throw std::runtime_error(
                 fmt::format("column `{}` not exists in `{}`\n",
-                            select_col["target"].dump(2), table_name));
+                            select_col["target"]["value"].dump(2), table_name));
         }
     }
 }
 
 void process_sql(json& j) {
-    auto type = j.value("type", "none");
+    let type = j.value("type", "none");
 
     if (type == "create_stmt") {
         // 创建表
@@ -84,14 +106,15 @@ void process_sql(json& j) {
         }
 
         // 其余都会要求表存在
-        auto table_name = table_name_of(j);
+        let table_name = table_name_of(j);
         if (!tables.exists(table_name)) {
             fmt::print("error: table `{}` does not exist\n", table_name);
             return;
         }
 
         if (type == "select_stmt") {
-            check_select_table(j);
+            check_select_table(table_name, j["select_cols"]);
+            check_where_clause(table_name, j["where"]);
         } else {
             throw std::runtime_error(fmt::format(
                 "Unknown expression type `{}` total: \n{}", type, j.dump(2)));

+ 2 - 0
src/utils.h

@@ -4,6 +4,8 @@
 #include <nlohmann/json.hpp>
 #include <vector>
 
+#define let const auto
+
 struct SQLParserRes {
     int exit_code;
     nlohmann::json out;

+ 5 - 1
tests_config.py

@@ -490,9 +490,13 @@ sql_parser_tests = [
 
 sql_checker_tests = [
     ('drop table person;', True),
-    ('create table person(name string, age int);', True),
+    ('create table person(name string, age int, classId int);', True),
     ('select age from person;', True),
     ('select * from person;', True),
     ('select gender from person;', 'column `"gender"` not exists in `person`'),
     ('select 123 from person;', True),
+
+    ('drop table class;', True),
+    ('create table class (id int, grade int, faculty string);', True),
+    ('select * from class where grade = 2 and faculty = \'Computer Science\';', True),
 ]