瀏覽代碼

加入区间优化

myuan 2 年之前
父節點
當前提交
23b5938017
共有 6 個文件被更改,包括 266 次插入14 次删除
  1. 39 3
      run_sql_parser_test.py
  2. 173 0
      src/optimizer.cpp
  3. 11 5
      src/utils.cpp
  4. 8 4
      src/utils.h
  5. 25 0
      tests_config.py
  6. 10 2
      xmake.lua

+ 39 - 3
run_sql_parser_test.py

@@ -12,9 +12,10 @@ import importlib
 
 importlib.reload(tests_config)
 
-sql_parser_tests, sql_checker_tests = (
+sql_parser_tests, sql_checker_tests, sql_optimizer_tests = (
     tests_config.sql_parser_tests,
     tests_config.sql_checker_tests,
+    tests_config.sql_optimizer_tests
 )
 
 
@@ -111,6 +112,30 @@ async def on_checker_modified():
     else:
         print(datetime.now(), colored("all checker tests right!", "green"))
 
+async def assert_sql_optimizer_check():
+    for sql, res in sql_optimizer_tests:
+        stdout, stderr = await run_and_output("xmake", "run", "sql-optimizer", "-s", sql)
+        print(sql, res)
+
+        try:
+            output = orjson.loads(stdout)
+            if output != res:
+                print(stdout.decode("utf-8") + "\n" + stderr.decode("utf-8"))
+                assert False
+            
+        except Exception as e:
+            print(e)
+
+
+
+async def on_optimizer_modified():
+    print(datetime.now(), colored("run optimizer tests...", "yellow"))
+    try:
+        await assert_sql_optimizer_check()
+    except Exception as e:
+        print(colored(e, "red"))
+    else:
+        print(datetime.now(), colored("all optimizer tests right!", "green"))
 
 async def restart():
     async for _ in awatch(__file__, "./tests_config.py"):
@@ -131,14 +156,25 @@ async def watch_checker():
         if await rebuild():
             await on_checker_modified()
 
+async def watch_optimizer():
+    async for changes in awatch("src/optimizer.cpp"):
+        if await rebuild():
+            await on_optimizer_modified()
+
+async def rerun_tests():
+    await asyncio.gather(
+        on_parser_modified(),
+        on_checker_modified(),
+    )
+    await on_optimizer_modified()
 
 async def main():
     await asyncio.gather(
         restart(),
         watch_parser(),
         watch_checker(),
-        on_parser_modified(),
-        on_checker_modified(),
+        watch_optimizer(), 
+        rerun_tests(),
     )
 
 

+ 173 - 0
src/optimizer.cpp

@@ -0,0 +1,173 @@
+#include <fmt/core.h>
+#include <fmt/ranges.h>
+#include <stdio.h>
+
+#include <CLI/CLI.hpp>
+#include <algorithm>
+#include <limits>
+#include <optional>
+
+#include "utils.h"
+
+using json = nlohmann::json;
+using std::nullopt;
+using std::optional;
+using std::string;
+using std::vector;
+
+template <typename T = int>
+class Range {
+   public:
+    T up;
+    T down;
+
+    static T max() { return std::numeric_limits<T>::max(); }
+    static T min() { return std::numeric_limits<T>::min(); }
+
+    static Range<T> find_range(const json& j) {
+        Range<T> r{.up = std::numeric_limits<T>::min(),
+                   .down = std::numeric_limits<T>::max()};
+        let type = type_of(j);
+
+        if (type == "小于") {
+            r.up = right_value(j);
+            r.down = std::numeric_limits<T>::min();
+        } else if (type == "大于") {
+            r.up = std::numeric_limits<T>::max();
+            r.down = right_value(j);
+        } else if (type == "相等") {
+            r.up = r.down = right_value(j);
+        }
+        return r;
+    }
+
+    Range<T> operator&(const Range<T>& right) const {
+        Range<T> new_range{.up = std::min(this->up, right.up),
+                           .down = std::max(this->down, right.down)};
+        return new_range;
+    }
+    Range<T> operator|(const Range<T>& right) const {
+        Range<T> new_range{.up = std::max(this->up, right.up),
+                           .down = std::min(this->down, right.down)};
+        return new_range;
+    }
+
+    bool empty() const { return this->down > this->up; }
+    bool full() const {
+        return this->down == std::numeric_limits<T>::min() &&
+               this->up == std::numeric_limits<T>::max();
+    }
+};
+
+json& check_compare_op(json& j) {
+    if (j.empty()) {
+        return j;
+    }
+    if (is_column(j)) {
+        return j;
+    }
+
+    let op_type = op_type_of(j);
+    let type = type_of(j);
+
+    auto& left = j["left"];
+    auto& right = j["right"];
+
+    left = check_compare_op(left);
+    right = check_compare_op(right);
+    let left_op_type = op_type_of(left);
+    let right_op_type = op_type_of(right);
+
+    if (op_type == OP_LOGIC && left_op_type == OP_COMPARE &&
+        right_op_type == OP_COMPARE) {
+        if (is_const_column(left["right"]) && is_const_column(right["right"])) {
+            // 说明有一个逻辑连接词, 两边都是比较, 并且比较用符号都相同,
+            // 并且比较的右侧都是常值
+            let left_range = Range<int>::find_range(left);
+            let right_range = Range<int>::find_range(right);
+            Range<int> new_range;
+
+            if (type == "且") {
+                new_range = left_range & right_range;
+            } else if (type == "或") {
+                new_range = left_range | right_range;
+            }
+            if (new_range.empty()) {
+                j = {{"type", "bool"}, {"value", false}};
+            } else if (new_range.full()) {
+                j = {{"type", "bool"}, {"value", true}};
+            } else {
+                if (new_range.max() == new_range.up &&
+                    new_range.min() == new_range.min()) {
+                    j["left"] = {{"type", "小于"},
+                                 {"left", j["left"]["left"]},
+                                 {"right", new_range.up}};
+                    j["right"] = {{"type", "大于"},
+                                  {"left", j["right"]["left"]},
+                                  {"right", new_range.down}};
+
+                } else if (new_range.max() == new_range.up) {
+                    j = {
+                        {"type", "大于"},
+                        {"left", j["left"]["left"]},
+                        {"right", {{"type", "int"}, {"value", new_range.down}}},
+                        {"op_type", "bin_cmp_op"}};
+                } else if (new_range.min() == new_range.down) {
+                    j = {
+                        {"type", "小于"},
+                        {"left", j["left"]["left"]},
+                        {"right", {{"type", "int"}, {"value", new_range.up}}},
+                        {"op_type", "bin_cmp_op"}};
+                }
+            }
+        }
+    }
+
+    return j;
+}
+
+void process_sql(json& j) {
+    let type = j.value("type", "none");
+
+    if (type != "select_stmt") {
+        throw std::runtime_error(
+            fmt::format("Unknown expression type `{}`\n", type));
+    }
+
+    j["where"] = check_compare_op(j["where"]);
+}
+
+int main(int argc, char** argv) {
+    CLI::App app{"sql 优化器"};
+    string sql = "", filename = "";
+
+    app.add_option("-f,--file", filename, "输入的 sql 解析后的 json 文件")
+        ->check(CLI::ExistingFile);
+    app.add_option("-s,--sql", sql, "输入的 sql");
+
+    CLI11_PARSE(app, argc, argv);
+
+    json stmts;
+    if (sql.length() > 0) {
+        stmts = parse_sql(sql).out;
+    } else if (filename.length() > 0) {
+        stmts = json::parse(std::ifstream(filename));
+    } else {
+        // std::cout << "请输入 sql 或者 sql 文件" << std::endl;
+        // return 1;
+        stmts = parse_sql(
+                    "select * from person where (age < 18) or (age > 60 and "
+                    "age > 65);")
+                    .out;
+    }
+
+    for (auto& stmt : stmts) {
+        // for(auto& item : stmt.items()) {
+        //     fmt::print("{}: {}\n", item.key(), item.value().dump());
+        // }
+        process_sql(stmt);
+        fmt::print("{}\n", stmt.dump(2));
+    }
+
+    return 0;
+}

+ 11 - 5
src/utils.cpp

@@ -13,14 +13,15 @@ using json = nlohmann::json;
 using std::string;
 using std::vector;
 
-SQLParserRes parse_sql(const std::string& sql) {
+SQLParserRes parse_sql(const std::string& sql,
+                       const string& sql_parser_prefix) {
     SQLParserRes res;
     // auto stdout_name = std::tmpnam(nullptr);
     // auto stderr_name = std::tmpnam(nullptr);
     auto stdout_name = "/tmp/sql-parser-stdout";
     auto stderr_name = "/tmp/sql-parser-stderr";
 
-    auto cmd = fmt::format("xmake run sql-parser \"{}\" > {} 2> {}", sql,
+    auto cmd = fmt::format("{} \"{}\" > {} 2> {}", sql_parser_prefix, sql,
                            stdout_name, stderr_name);
     res.exit_code = system(cmd.c_str());
 
@@ -151,13 +152,18 @@ string select_column_get_name(const json& j) {
 
 bool is_column(const json& j) {
     let type = j.value("type", "none");
-    return type == "identifier" || type == "int" || type == "string" ||
-           type == "float" || type == "select_all_column" ||
-           type == "table_field";
+    return type == "identifier" || is_const_column(j) ||
+           type == "select_all_column" || type == "table_field";
+}
+
+bool is_const_column(const json& j) {
+    let type = j.value("type", "none");
+    return type == "int" || type == "string" || type == "float";
 }
 
 OpType op_type_of(const nlohmann::json& j) {
     // bin_contains_op, bin_logical_op, bin_cmp_op
+    if (j.empty()) return OpType::OP_UNKNOWN;
 
     let type = j.value("op_type", "none");
     if (type == "bin_contains_op") {

+ 8 - 4
src/utils.h

@@ -26,8 +26,8 @@ using TableCols = std::vector<TableCol>;
 using std::nullopt;
 using std::optional;
 
-SQLParserRes parse_sql(const std::string& sql);
-
+SQLParserRes parse_sql(const std::string& sql, const std::string& sql_parser_prefix =
+                                                   "xmake run sql-parser");
 class ExistTables {
    public:
     std::string table_file_name;
@@ -60,11 +60,15 @@ inline auto table_names_of(const nlohmann::json& j) {
     // return j.value("table_names", json::array());
     return j["table_names"];
 }
-inline auto type_of(const nlohmann::json& j) { return j.value("type", "none"); }
+inline std::string type_of(const nlohmann::json& j) { return j.value("type", "none"); }
+inline nlohmann::json right_value(const nlohmann::json& j) {
+    return j["right"]["value"];
+}
 
 std::string select_column_get_name(const nlohmann::json& j);
 
 bool is_column(const nlohmann::json& j);
+bool is_const_column(const nlohmann::json& j);
 
 enum OpType {
     OP_COMPARE,
@@ -73,4 +77,4 @@ enum OpType {
     OP_UNKNOWN,
 };
 
-OpType op_type_of(const nlohmann::json& j);
+OpType op_type_of(const nlohmann::json& j);

+ 25 - 0
tests_config.py

@@ -607,3 +607,28 @@ sql_checker_tests = [
         "column `age` type is `int`, but `name` type is `string`, cannot `大于`",
     ),
 ]
+
+sql_optimizer_tests = [
+    (
+        "select * from person where (age < 18) or (age > 60 and age < 35);",
+        {
+            "select_cols": [
+                {
+                    "target": {
+                        "type": "select_all_column",
+                        "value": "select_all_column",
+                    },
+                    "type": "select_all_column",
+                }
+            ],
+            "table_names": ["person"],
+            "type": "select_stmt",
+            "where": {
+                "left": {"type": "identifier", "value": "age"},
+                "op_type": "bin_cmp_op",
+                "right": {"type": "int", "value": 18},
+                "type": "小于",
+            },
+        },
+    ),
+]

+ 10 - 2
xmake.lua

@@ -17,12 +17,20 @@ target("sql-checker")
     set_toolset("cxx", "clang++")
 
     set_kind("binary")
-    add_files("src/*.cpp")
+    add_files("src/utils.cpp", "src/checker.cpp")
     add_includedirs("src")
     add_packages("fmt", "nlohmann_json", "cli11")
-    -- add_cxxflags("-ftime-trace", {force = true})
     -- set_optimize("none")
+target_end()
+
+target("sql-optimizer")
+    set_languages("c++20")
+    set_toolset("cxx", "clang++")
 
+    set_kind("binary")
+    add_files("src/utils.cpp", "src/optimizer.cpp")
+    add_includedirs("src")
+    add_packages("fmt", "nlohmann_json", "cli11")
 target_end()
 
 --