Forráskód Böngészése

将select相关表检查改为vector

myuan 2 éve
szülő
commit
7b2ba25f97
3 módosított fájl, 81 hozzáadás és 46 törlés
  1. 37 43
      src/checker.cpp
  2. 30 0
      src/utils.cpp
  3. 14 3
      src/utils.h

+ 37 - 43
src/checker.cpp

@@ -1,4 +1,5 @@
 #include <fmt/core.h>
+#include <fmt/ranges.h>
 #include <stdio.h>
 
 #include <CLI/CLI.hpp>
@@ -10,6 +11,7 @@ using json = nlohmann::json;
 using std::nullopt;
 using std::optional;
 using std::string;
+using std::vector;
 
 auto tables = ExistTables("./tables.json");
 
@@ -24,33 +26,26 @@ void create_table(const json& j) {
     tables.save();
 }
 
-optional<TableCol> get_select_column_define(const string table_name,
-                                           const json& select_col) {
+optional<TableCol> get_select_column_define(const vector<string> table_names,
+                                            const json& select_col) {
     auto type = type_of(select_col);
     if (type == "identifier") {
-        auto cols = tables[table_name];
-        string colname = select_col["value"];
-        for (auto& col : cols) {
-            if (col.column_name == colname) {
-                return col;
-            }
-        }
-        return nullopt;
+        return tables.find_col_in_tables(table_names, select_col["value"]);
     } else if (type == "int" || type == "string" || type == "float") {
         return TableCol{.column_name = select_col["value"].dump(),
-                        .data_type = type,
                         .type = "const",
+                        .data_type = type,
                         .primary_key = false};
     } else if (type == "select_all_column") {
         return TableCol{.column_name = "*",
-                        .data_type = "",
                         .type = "select_all_column",
+                        .data_type = "",
                         .primary_key = false};
     }
     return nullopt;
 }
 
-void check_where_clause(const string table_name, const json& j) {
+void check_where_clause(const vector<string> table_names, const json& j) {
     if (j.empty()) {
         return;
     }
@@ -58,34 +53,37 @@ void check_where_clause(const string table_name, const json& j) {
     auto type = type_of(j);
     auto left = j["left"];
     auto right = j["right"];
-    for(const auto curr_branch_name : {"left", "right"}) {
+    for (const auto curr_branch_name : {"left", "right"}) {
         auto curr_branch = j[curr_branch_name];
         if (left.contains("value")) {
             // 说明到叶节点了
-            auto col_def = get_select_column_define(table_name, curr_branch);
+            auto col_def = get_select_column_define(table_names, curr_branch);
             if (!col_def.has_value()) {
-                throw std::runtime_error(fmt::format(
-                    "column `{}` not exists in `{}`\n",
-                    curr_branch["value"].dump(2), table_name));
+                throw std::runtime_error(
+                    fmt::format("column `{}` not exists in `{}`\n",
+                                curr_branch["value"].dump(2),
+                                fmt::join(table_names, ", ")));
             }
         } else {
-            check_where_clause(table_name, curr_branch);
+            check_where_clause(table_names, curr_branch);
         }
     }
 }
 
-void check_select_table(const string table_name, const json& select_cols) {
+void check_select_table(const vector<string> table_names,
+                        const json& select_cols) {
     for (auto& select_col : select_cols) {
         if (select_col["type"] == "select_all_column") {
             continue;
         }
         auto col_def =
-            get_select_column_define(table_name, select_col["target"]);
+            get_select_column_define(table_names, select_col["target"]);
 
         if (!col_def.has_value()) {
             throw std::runtime_error(
                 fmt::format("column `{}` not exists in `{}`\n",
-                            select_col["target"]["value"].dump(2), table_name));
+                            select_col["target"]["value"].dump(2),
+                            fmt::join(table_names, ", ")));
         }
     }
 }
@@ -96,29 +94,25 @@ void process_sql(json& j) {
     if (type == "create_stmt") {
         // 创建表
         create_table(j);
-    } else {
-        if (type == "drop_stmt") {
-            if (tables.exists(j["table_name"])) {
-                tables.remove(j["table_name"]);
-                tables.save();
-            }
-            return;
+    } else if (type == "drop_stmt") {
+        if (tables.exists(j["table_name"])) {
+            tables.remove(j["table_name"]);
+            tables.save();
         }
-
-        // 其余都会要求表存在
-        let table_name = table_name_of(j);
-        if (!tables.exists(table_name)) {
-            fmt::print("error: table `{}` does not exist\n", table_name);
-            return;
-        }
-
-        if (type == "select_stmt") {
-            check_select_table(table_name, j["select_cols"]);
-            check_where_clause(table_name, j["where"]);
-        } else {
-            throw std::runtime_error(fmt::format(
-                "Unknown expression type `{}` total: \n{}", type, j.dump(2)));
+        return;
+    } else if (type == "select_stmt") {
+        let table_names = table_names_of(j);
+        for (let table_name : table_names) {
+            if (!tables.exists(table_name)) {
+                fmt::print("error: table `{}` does not exist\n", table_name);
+                return;
+            }
         }
+        check_select_table(table_names, j["select_cols"]);
+        check_where_clause(table_names, j["where"]);
+    } else {
+        throw std::runtime_error(fmt::format(
+            "Unknown expression type `{}` total: \n{}", type, j.dump(2)));
     }
 }
 

+ 30 - 0
src/utils.cpp

@@ -2,6 +2,7 @@
 
 #include <fmt/core.h>
 
+#include <algorithm>
 #include <fstream>
 #include <map>
 #include <nlohmann/json.hpp>
@@ -61,8 +62,37 @@ void ExistTables::set(const std::string& table_name, const TableCols& cols) {
     tables[table_name] = cols;
 }
 TableCols ExistTables::operator[](const std::string& table_name) {
+    return this->get(table_name);
+}
+TableCols ExistTables::operator[](const std::vector<std::string>& table_names) {
+    return this->get(table_names);
+}
+TableCols ExistTables::get(const std::string& table_name) {
     return tables[table_name];
 }
+TableCols ExistTables::get(const std::vector<std::string>& table_names) {
+    TableCols res = {};
+    for (let table_name : table_names) {
+        let cols = this->get(table_name);
+        res.assign(cols.begin(), cols.end());
+    }
+    return res;
+}
+
+optional<TableCol> ExistTables::find_col_in_tables(
+        const std::vector<std::string>& table_names,
+        const std::string col_name) {
+    for (let table_name : table_names) {
+        let cols = this->get(table_name);
+        for(let col : cols) {
+            if (col.column_name == col_name) {
+                return col;
+            }
+        }
+    }
+    return nullopt;
+}
+
 void ExistTables::remove(const std::string& table_name) {
     tables.erase(table_name);
 }

+ 14 - 3
src/utils.h

@@ -2,6 +2,7 @@
 
 #include <map>
 #include <nlohmann/json.hpp>
+#include <optional>
 #include <vector>
 
 #define let const auto
@@ -22,6 +23,8 @@ struct TableCol {
 };
 
 using TableCols = std::vector<TableCol>;
+using std::nullopt;
+using std::optional;
 
 SQLParserRes parse_sql(const std::string& sql);
 
@@ -35,6 +38,13 @@ class ExistTables {
     bool exists(const std::string& table_name);
     void set(const std::string& table_name, const TableCols& cols);
     TableCols operator[](const std::string& table_name);
+    TableCols operator[](const std::vector<std::string>& table_names);
+    TableCols get(const std::string& table_name);
+    TableCols get(const std::vector<std::string>& table_names);
+
+    optional<TableCol> find_col_in_tables(
+        const std::vector<std::string>& table_names,
+        const std::string col_name);
     void remove(const std::string& table_name);
     void save();
 };
@@ -42,7 +52,8 @@ class ExistTables {
 inline auto table_name_of(const nlohmann::json& j) {
     return j.value("table_name", "none");
 }
-inline auto type_of(const nlohmann::json& j) {
-    return j.value("type", "none");
+inline auto table_names_of(const nlohmann::json& j) {
+    // return j.value("table_names", json::array());
+    return j["table_names"];
 }
-
+inline auto type_of(const nlohmann::json& j) { return j.value("type", "none"); }