|
@@ -0,0 +1,173 @@
|
|
|
+#include <fmt/core.h>
|
|
|
+#include <fmt/ranges.h>
|
|
|
+#include <stdio.h>
|
|
|
+
|
|
|
+#include <CLI/CLI.hpp>
|
|
|
+#include <algorithm>
|
|
|
+#include <limits>
|
|
|
+#include <optional>
|
|
|
+
|
|
|
+#include "utils.h"
|
|
|
+
|
|
|
+using json = nlohmann::json;
|
|
|
+using std::nullopt;
|
|
|
+using std::optional;
|
|
|
+using std::string;
|
|
|
+using std::vector;
|
|
|
+
|
|
|
+template <typename T = int>
|
|
|
+class Range {
|
|
|
+ public:
|
|
|
+ T up;
|
|
|
+ T down;
|
|
|
+
|
|
|
+ static T max() { return std::numeric_limits<T>::max(); }
|
|
|
+ static T min() { return std::numeric_limits<T>::min(); }
|
|
|
+
|
|
|
+ static Range<T> find_range(const json& j) {
|
|
|
+ Range<T> r{.up = std::numeric_limits<T>::min(),
|
|
|
+ .down = std::numeric_limits<T>::max()};
|
|
|
+ let type = type_of(j);
|
|
|
+
|
|
|
+ if (type == "小于") {
|
|
|
+ r.up = right_value(j);
|
|
|
+ r.down = std::numeric_limits<T>::min();
|
|
|
+ } else if (type == "大于") {
|
|
|
+ r.up = std::numeric_limits<T>::max();
|
|
|
+ r.down = right_value(j);
|
|
|
+ } else if (type == "相等") {
|
|
|
+ r.up = r.down = right_value(j);
|
|
|
+ }
|
|
|
+ return r;
|
|
|
+ }
|
|
|
+
|
|
|
+ Range<T> operator&(const Range<T>& right) const {
|
|
|
+ Range<T> new_range{.up = std::min(this->up, right.up),
|
|
|
+ .down = std::max(this->down, right.down)};
|
|
|
+ return new_range;
|
|
|
+ }
|
|
|
+ Range<T> operator|(const Range<T>& right) const {
|
|
|
+ Range<T> new_range{.up = std::max(this->up, right.up),
|
|
|
+ .down = std::min(this->down, right.down)};
|
|
|
+ return new_range;
|
|
|
+ }
|
|
|
+
|
|
|
+ bool empty() const { return this->down > this->up; }
|
|
|
+ bool full() const {
|
|
|
+ return this->down == std::numeric_limits<T>::min() &&
|
|
|
+ this->up == std::numeric_limits<T>::max();
|
|
|
+ }
|
|
|
+};
|
|
|
+
|
|
|
+json& check_compare_op(json& j) {
|
|
|
+ if (j.empty()) {
|
|
|
+ return j;
|
|
|
+ }
|
|
|
+ if (is_column(j)) {
|
|
|
+ return j;
|
|
|
+ }
|
|
|
+
|
|
|
+ let op_type = op_type_of(j);
|
|
|
+ let type = type_of(j);
|
|
|
+
|
|
|
+ auto& left = j["left"];
|
|
|
+ auto& right = j["right"];
|
|
|
+
|
|
|
+ left = check_compare_op(left);
|
|
|
+ right = check_compare_op(right);
|
|
|
+ let left_op_type = op_type_of(left);
|
|
|
+ let right_op_type = op_type_of(right);
|
|
|
+
|
|
|
+ if (op_type == OP_LOGIC && left_op_type == OP_COMPARE &&
|
|
|
+ right_op_type == OP_COMPARE) {
|
|
|
+ if (is_const_column(left["right"]) && is_const_column(right["right"])) {
|
|
|
+ // 说明有一个逻辑连接词, 两边都是比较, 并且比较用符号都相同,
|
|
|
+ // 并且比较的右侧都是常值
|
|
|
+ let left_range = Range<int>::find_range(left);
|
|
|
+ let right_range = Range<int>::find_range(right);
|
|
|
+ Range<int> new_range;
|
|
|
+
|
|
|
+ if (type == "且") {
|
|
|
+ new_range = left_range & right_range;
|
|
|
+ } else if (type == "或") {
|
|
|
+ new_range = left_range | right_range;
|
|
|
+ }
|
|
|
+ if (new_range.empty()) {
|
|
|
+ j = {{"type", "bool"}, {"value", false}};
|
|
|
+ } else if (new_range.full()) {
|
|
|
+ j = {{"type", "bool"}, {"value", true}};
|
|
|
+ } else {
|
|
|
+ if (new_range.max() == new_range.up &&
|
|
|
+ new_range.min() == new_range.min()) {
|
|
|
+ j["left"] = {{"type", "小于"},
|
|
|
+ {"left", j["left"]["left"]},
|
|
|
+ {"right", new_range.up}};
|
|
|
+ j["right"] = {{"type", "大于"},
|
|
|
+ {"left", j["right"]["left"]},
|
|
|
+ {"right", new_range.down}};
|
|
|
+
|
|
|
+ } else if (new_range.max() == new_range.up) {
|
|
|
+ j = {
|
|
|
+ {"type", "大于"},
|
|
|
+ {"left", j["left"]["left"]},
|
|
|
+ {"right", {{"type", "int"}, {"value", new_range.down}}},
|
|
|
+ {"op_type", "bin_cmp_op"}};
|
|
|
+ } else if (new_range.min() == new_range.down) {
|
|
|
+ j = {
|
|
|
+ {"type", "小于"},
|
|
|
+ {"left", j["left"]["left"]},
|
|
|
+ {"right", {{"type", "int"}, {"value", new_range.up}}},
|
|
|
+ {"op_type", "bin_cmp_op"}};
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return j;
|
|
|
+}
|
|
|
+
|
|
|
+void process_sql(json& j) {
|
|
|
+ let type = j.value("type", "none");
|
|
|
+
|
|
|
+ if (type != "select_stmt") {
|
|
|
+ throw std::runtime_error(
|
|
|
+ fmt::format("Unknown expression type `{}`\n", type));
|
|
|
+ }
|
|
|
+
|
|
|
+ j["where"] = check_compare_op(j["where"]);
|
|
|
+}
|
|
|
+
|
|
|
+int main(int argc, char** argv) {
|
|
|
+ CLI::App app{"sql 优化器"};
|
|
|
+ string sql = "", filename = "";
|
|
|
+
|
|
|
+ app.add_option("-f,--file", filename, "输入的 sql 解析后的 json 文件")
|
|
|
+ ->check(CLI::ExistingFile);
|
|
|
+ app.add_option("-s,--sql", sql, "输入的 sql");
|
|
|
+
|
|
|
+ CLI11_PARSE(app, argc, argv);
|
|
|
+
|
|
|
+ json stmts;
|
|
|
+ if (sql.length() > 0) {
|
|
|
+ stmts = parse_sql(sql).out;
|
|
|
+ } else if (filename.length() > 0) {
|
|
|
+ stmts = json::parse(std::ifstream(filename));
|
|
|
+ } else {
|
|
|
+ // std::cout << "请输入 sql 或者 sql 文件" << std::endl;
|
|
|
+ // return 1;
|
|
|
+ stmts = parse_sql(
|
|
|
+ "select * from person where (age < 18) or (age > 60 and "
|
|
|
+ "age > 65);")
|
|
|
+ .out;
|
|
|
+ }
|
|
|
+
|
|
|
+ for (auto& stmt : stmts) {
|
|
|
+ // for(auto& item : stmt.items()) {
|
|
|
+ // fmt::print("{}: {}\n", item.key(), item.value().dump());
|
|
|
+ // }
|
|
|
+ process_sql(stmt);
|
|
|
+ fmt::print("{}\n", stmt.dump(2));
|
|
|
+ }
|
|
|
+
|
|
|
+ return 0;
|
|
|
+}
|