myuan vor 2 Jahren
Ursprung
Commit
2adfa2df13
6 geänderte Dateien mit 192 neuen und 13 gelöschten Zeilen
  1. 3 3
      run_sql_parser_test.py
  2. 87 4
      src/optimizer.cpp
  3. 23 5
      src/utils.cpp
  4. 1 1
      src/utils.h
  5. 7 0
      test_configs/sql_checker.py
  6. 71 0
      test_configs/sql_optimizer.py

+ 3 - 3
run_sql_parser_test.py

@@ -123,14 +123,14 @@ async def assert_sql_optimizer_check():
         try:
             output = orjson.loads(stdout)
             if output != res:
-                print(colored(res, "yellow"))
-                print(colored(output, "red"))
+                print('', colored(res, "yellow"))
+                print('real', colored(output, "red"))
                 print(colored(stdout.decode("utf-8"), "yellow"))
                 print(colored(stderr.decode("utf-8"), "yellow"))
                 assert False
-
         except Exception as e:
             print(e)
+            break
 
 
 async def on_optimizer_modified():

+ 87 - 4
src/optimizer.cpp

@@ -15,6 +15,8 @@ using std::optional;
 using std::string;
 using std::vector;
 
+auto tables = ExistTables("./tables.json");
+
 template <typename T = int>
 class Range {
    public:
@@ -147,7 +149,7 @@ json& check_logic_op(json& j) {
             j = {{"type", "bool"}, {"value", left["value"] || right["value"]}};
         }
     } else if (left["type"] == "bool" || right["type"] == "bool") {
-        if (left["type"] == "bool"){
+        if (left["type"] == "bool") {
             let i = right;
             right = left;
             left = i;
@@ -157,9 +159,9 @@ json& check_logic_op(json& j) {
         if (j["type"] == "或" && !right_v) {
             j = left;
         } else if (j["type"] == "且" && !right_v) {
-            j =  {{"type", "bool"}, {"value", false}};
+            j = {{"type", "bool"}, {"value", false}};
         } else if (j["type"] == "或" && right_v) {
-            j =  {{"type", "bool"}, {"value", true}};
+            j = {{"type", "bool"}, {"value", true}};
         } else if (j["type"] == "且" && right_v) {
             j = left;
         }
@@ -169,6 +171,82 @@ json& check_logic_op(json& j) {
     return j;
 }
 
+bool 谓词下推(json& total_stmt, json& parents_node, json& current_node, json& bro_node);
+
+// Entry point for predicate pushdown ("谓词下推") on a statement with a JOIN.
+// Starts the recursive walk at total_stmt["join_options"]["on"], passing that
+// same node as parent, current and sibling node, since the root of the ON
+// expression has neither a parent nor a sibling inside the expression.
+// Returns the result of that recursive call.
+// NOTE(review): assuming `json` is nlohmann::json, non-const operator[]
+// creates missing keys, so the caller should check that "join_options"
+// exists before calling this — confirm.
+bool 谓词下推_start(json& total_stmt) {
+    return 谓词下推(total_stmt, total_stmt["join_options"]["on"],
+                    total_stmt["join_options"]["on"], total_stmt["join_options"]["on"]);
+}
+
+// Predicate pushdown: recursively looks for one predicate inside the JOIN's ON
+// expression that can be moved into total_stmt["where"], and moves it.
+// Intended to return true after a predicate has been pushed down.
+//
+//   total_stmt    the whole statement; "table_names" is read, "where" is written
+//   parents_node  node containing current_node; it is overwritten when
+//                 current_node is moved out
+//   current_node  expression node being examined
+//   bro_node      the other operand of parents_node; it replaces parents_node
+//                 when current_node is moved out
+//
+// NOTE(review): total_stmt["where"] is overwritten, not combined with any
+// condition already stored there — confirm this is intended.
+bool 谓词下推(json& total_stmt, json& parents_node, json& current_node, json& bro_node) {
+    // Nothing to examine in an empty node.
+    if (current_node.empty()) {
+        return false;
+    }
+    // A column node is an operand, not a predicate; it cannot be pushed down.
+    if (is_column(current_node)) {
+        return false;
+    }
+    // let op_type = op_type_of(current_node);
+    // if (op_type != OP_LOGIC) {
+    //     return false;
+    // }
+    auto& left = current_node["left"];
+    auto& right = current_node["right"];
+
+    // fmt::print("left: {}\n", left.dump(2));
+    // fmt::print("right: {}\n", right.dump(2));
+
+    // Try the operands first. When one of them was pushed down the tree has
+    // been modified, so the walk is restarted from the root of the ON clause.
+    // NOTE(review): the value returned here is the result of the restarted
+    // pass, so it can be false even though a predicate was just moved.
+    if (谓词下推(total_stmt, current_node, left, right)) {
+        return 谓词下推_start(total_stmt);
+    }
+    if (谓词下推(total_stmt, current_node, right, left)) {
+        return 谓词下推_start(total_stmt);
+    }
+
+    // Bare names of both operands (only_center_name = true).
+    let left_col_name = select_column_get_name(left, true);
+    let right_col_name = select_column_get_name(right, true);
+
+    const vector<string> all_table_names = total_stmt["table_names"];
+
+    // fmt::print("bro_node: {} parents_node == current_node: {}\n", bro_node.dump(2), parents_node == current_node);
+
+    // Look each operand name up among the columns of the statement's tables.
+    let left_res = tables.find_col_in_tables(all_table_names, left_col_name);
+    let right_res = tables.find_col_in_tables(all_table_names, right_col_name);
+
+    // Both operands are in the target tables -> pushdown is possible.
+    // Only one is in the target tables && the other is a constant -> pushdown is possible.
+    if(!left_res.has_value() && !right_res.has_value()) {
+        return false;
+    }
+    if(!left_res.has_value() && !is_const_column(left)) {
+        return false;
+    }
+    if(!right_res.has_value() && !is_const_column(right)) {
+        return false;
+    }
+
+    // At this point at least one lookup succeeded, so one of the two branches
+    // below always returns. Both branches perform the same steps.
+    if (left_res.has_value()) {
+        total_stmt["where"] = current_node;
+        // Parent and current are the same node at the root of the ON clause:
+        // the whole ON expression was moved, leaving an empty object behind.
+        // Otherwise the parent is replaced by the remaining sibling operand.
+        if (parents_node == current_node) {
+            parents_node = json::object();
+        } else {
+            parents_node = bro_node;
+        }
+
+        return true;
+    }
+    if (right_res.has_value()) {
+        total_stmt["where"] = current_node;
+        if (parents_node == current_node) {
+            parents_node = json::object();
+        } else {
+            parents_node = bro_node;
+        }
+        return true;
+    }
+
+    // Not reachable given the guards above; kept as a fallback.
+    return false;
+}
 void process_sql(json& j) {
     let type = j.value("type", "none");
 
@@ -179,6 +257,9 @@ void process_sql(json& j) {
 
     j["where"] = check_compare_op(j["where"]);
     j["where"] = check_logic_op(j["where"]);
+    if (j.contains("join_options")) {
+        谓词下推_start(j);
+    }
 }
 
 int main(int argc, char** argv) {
@@ -200,7 +281,9 @@ int main(int argc, char** argv) {
         // std::cout << "请输入 sql 或者 sql 文件" << std::endl;
         // return 1;
         stmts = parse_sql(
-                    "select * from person where (age < 18) or (age > 60 and age < 35);")
+                    "select person.name "
+                    "from person join class "
+                    "on age=22 and class.id=person.classId;")
                     .out;
     }
 

+ 23 - 5
src/utils.cpp

@@ -135,19 +135,37 @@ optional<TableCol> ExistTables::get_select_column_define(
     return nullopt;
 }
 
-string select_column_get_name(const json& j) {
+// Returns a display name for a select-column JSON node, chosen by j["type"]:
+//   identifier             -> j["value"]
+//   select_all_column      -> the literal "select_all_column"
+//   int / string / float   -> the value text, prefixed with "<type>:" unless
+//                             only_center_name is set
+//   table_field            -> "<table>.<field>", or only the field when
+//                             only_center_name is set
+// Any other type yields an empty string.
+string select_column_get_name(const json& j, bool only_center_name) {
     let type = j["type"];
+    // fmt::print("jdump: {}\n", j.dump());
+
     if (type == "identifier") {
         return j["value"];
     } else if (type == "select_all_column") {
         return "select_all_column";
     } else if (type == "int" || type == "string" || type == "float") {
-        return fmt::format("{}:{}", type, j["value"]);
+        // Non-string values are serialized with dump(); string values are
+        // taken as-is.
+        string temp_v;
+        if (type != "string") {
+            temp_v = j["value"].dump();
+        } else {
+            temp_v = j["value"];
+        }
+
+        if (only_center_name)
+            return fmt::format("{}", temp_v);
+        else
+            return fmt::format("{}:{}", type, temp_v);
     } else if (type == "table_field") {
-        return fmt::format("{}.{}", j["table"], j["field"]);
+        if (only_center_name)
+            return j["field"];
+        else
+            return fmt::format("{}.{}", j["table"], j["field"]);
     }
-    throw std::runtime_error(
-        fmt::format("无法将 json 视为 select_column: {}", j.dump(2)));
+    // Unrecognized node types now return an empty string; the former
+    // std::runtime_error is left commented out below.
+    return "";
+    // if (!only_center_name) {
+    //     throw std::runtime_error(
+    //         fmt::format("无法将 json 视为 select_column: {}", j.dump(2)));
+    // }
 }
 
 bool is_column(const json& j) {

+ 1 - 1
src/utils.h

@@ -65,7 +65,7 @@ inline nlohmann::json right_value(const nlohmann::json& j) {
     return j["right"]["value"];
 }
 
-std::string select_column_get_name(const nlohmann::json& j);
+std::string select_column_get_name(const nlohmann::json& j, bool only_center_name = false);
 
 bool is_column(const nlohmann::json& j);
 bool is_const_column(const nlohmann::json& j);

+ 7 - 0
test_configs/sql_checker.py

@@ -58,5 +58,12 @@ sql_checker_tests = [
         """,
         "column `age` type is `int`, but `name` type is `string`, cannot `大于`",
     ),
+    (
+        """select person.name 
+        from person join class 
+        on age=22 and class.id=person.classId ;
+        """,
+        True
+    )
 ]
 

+ 71 - 0
test_configs/sql_optimizer.py

@@ -60,4 +60,75 @@ sql_optimizer_tests = [
             "where": {"type": "bool", "value": False},
         },
     ),
+    (
+        """select person.name 
+        from person join class 
+        on age=22 and class.id=person.classId ;
+        """,
+        {
+            "join_options": {
+                "join_with": {"type": "identifier", "value": "class"},
+                "on": {
+                    "left": {"field": "id", "table": "class", "type": "table_field"},
+                    "op_type": "bin_cmp_op",
+                    "right": {
+                        "field": "classId",
+                        "table": "person",
+                        "type": "table_field",
+                    },
+                    "type": "相等",
+                },
+                "type": "join_options",
+            },
+            "select_cols": [
+                {
+                    "target": {
+                        "field": "name",
+                        "table": "person",
+                        "type": "table_field",
+                    },
+                    "type": "select_column",
+                }
+            ],
+            "table_names": ["person"],
+            "type": "select_stmt",
+            "where": {
+                "left": {"type": "identifier", "value": "age"},
+                "op_type": "bin_cmp_op",
+                "right": {"type": "int", "value": 22},
+                "type": "相等",
+            },
+        },
+    ),
+    (
+        """select person.name 
+        from person join class 
+        on age=22;
+        """,
+        {
+            "join_options": {
+                "join_with": {"type": "identifier", "value": "class"},
+                "on": {},
+                "type": "join_options",
+            },
+            "select_cols": [
+                {
+                    "target": {
+                        "field": "name",
+                        "table": "person",
+                        "type": "table_field",
+                    },
+                    "type": "select_column",
+                }
+            ],
+            "table_names": ["person"],
+            "type": "select_stmt",
+            "where": {
+                "left": {"type": "identifier", "value": "age"},
+                "op_type": "bin_cmp_op",
+                "right": {"type": "int", "value": 22},
+                "type": "相等",
+            },
+        },
+    ),
 ]