|
@@ -15,6 +15,8 @@ using std::optional;
|
|
|
using std::string;
|
|
|
using std::vector;
|
|
|
|
|
|
+auto tables = ExistTables("./tables.json");
|
|
|
+
|
|
|
template <typename T = int>
|
|
|
class Range {
|
|
|
public:
|
|
@@ -147,7 +149,7 @@ json& check_logic_op(json& j) {
|
|
|
j = {{"type", "bool"}, {"value", left["value"] || right["value"]}};
|
|
|
}
|
|
|
} else if (left["type"] == "bool" || right["type"] == "bool") {
|
|
|
- if (left["type"] == "bool"){
|
|
|
+ if (left["type"] == "bool") {
|
|
|
let i = right;
|
|
|
right = left;
|
|
|
left = i;
|
|
@@ -157,9 +159,9 @@ json& check_logic_op(json& j) {
|
|
|
if (j["type"] == "或" && !right_v) {
|
|
|
j = left;
|
|
|
} else if (j["type"] == "且" && !right_v) {
|
|
|
- j = {{"type", "bool"}, {"value", false}};
|
|
|
+ j = {{"type", "bool"}, {"value", false}};
|
|
|
} else if (j["type"] == "或" && right_v) {
|
|
|
- j = {{"type", "bool"}, {"value", true}};
|
|
|
+ j = {{"type", "bool"}, {"value", true}};
|
|
|
} else if (j["type"] == "且" && right_v) {
|
|
|
j = left;
|
|
|
}
|
|
@@ -169,6 +171,82 @@ json& check_logic_op(json& j) {
|
|
|
return j;
|
|
|
}
|
|
|
|
|
|
+bool 谓词下推(json& total_stmt, json& parents_node, json& current_node, json& bro_node);
|
|
|
+
|
|
|
+bool 谓词下推_start(json& total_stmt) {
|
|
|
+ return 谓词下推(total_stmt, total_stmt["join_options"]["on"],
|
|
|
+ total_stmt["join_options"]["on"], total_stmt["join_options"]["on"]);
|
|
|
+}
|
|
|
+
|
|
|
+// 成功进行谓词下推后返回真
|
|
|
+bool 谓词下推(json& total_stmt, json& parents_node, json& current_node, json& bro_node) {
|
|
|
+ if (current_node.empty()) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if (is_column(current_node)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ // let op_type = op_type_of(current_node);
|
|
|
+ // if (op_type != OP_LOGIC) {
|
|
|
+ // return false;
|
|
|
+ // }
|
|
|
+ auto& left = current_node["left"];
|
|
|
+ auto& right = current_node["right"];
|
|
|
+
|
|
|
+ // fmt::print("left: {}\n", left.dump(2));
|
|
|
+ // fmt::print("right: {}\n", right.dump(2));
|
|
|
+
|
|
|
+ if (谓词下推(total_stmt, current_node, left, right)) {
|
|
|
+ return 谓词下推_start(total_stmt);
|
|
|
+ }
|
|
|
+ if (谓词下推(total_stmt, current_node, right, left)) {
|
|
|
+ return 谓词下推_start(total_stmt);
|
|
|
+ }
|
|
|
+
|
|
|
+ let left_col_name = select_column_get_name(left, true);
|
|
|
+ let right_col_name = select_column_get_name(right, true);
|
|
|
+
|
|
|
+ const vector<string> all_table_names = total_stmt["table_names"];
|
|
|
+
|
|
|
+ // fmt::print("bro_node: {} parents_node == current_node: {}\n", bro_node.dump(2), parents_node == current_node);
|
|
|
+
|
|
|
+ let left_res = tables.find_col_in_tables(all_table_names, left_col_name);
|
|
|
+ let right_res = tables.find_col_in_tables(all_table_names, right_col_name);
|
|
|
+
|
|
|
+ // 两个都在目标表中 -> 可以进行谓词下推
|
|
|
+ // 只有一个在目标表中 && 另一个是常量 -> 可以进行谓词下推
|
|
|
+ if(!left_res.has_value() && !right_res.has_value()) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if(!left_res.has_value() && !is_const_column(left)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+ if(!right_res.has_value() && !is_const_column(right)) {
|
|
|
+ return false;
|
|
|
+ }
|
|
|
+
|
|
|
+ if (left_res.has_value()) {
|
|
|
+ total_stmt["where"] = current_node;
|
|
|
+ if (parents_node == current_node) {
|
|
|
+ parents_node = json::object();
|
|
|
+ } else {
|
|
|
+ parents_node = bro_node;
|
|
|
+ }
|
|
|
+
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+ if (right_res.has_value()) {
|
|
|
+ total_stmt["where"] = current_node;
|
|
|
+ if (parents_node == current_node) {
|
|
|
+ parents_node = json::object();
|
|
|
+ } else {
|
|
|
+ parents_node = bro_node;
|
|
|
+ }
|
|
|
+ return true;
|
|
|
+ }
|
|
|
+
|
|
|
+ return false;
|
|
|
+}
|
|
|
void process_sql(json& j) {
|
|
|
let type = j.value("type", "none");
|
|
|
|
|
@@ -179,6 +257,9 @@ void process_sql(json& j) {
|
|
|
|
|
|
j["where"] = check_compare_op(j["where"]);
|
|
|
j["where"] = check_logic_op(j["where"]);
|
|
|
+ if (j.contains("join_options")) {
|
|
|
+ 谓词下推_start(j);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
int main(int argc, char** argv) {
|
|
@@ -200,7 +281,9 @@ int main(int argc, char** argv) {
|
|
|
// std::cout << "请输入 sql 或者 sql 文件" << std::endl;
|
|
|
// return 1;
|
|
|
stmts = parse_sql(
|
|
|
- "select * from person where (age < 18) or (age > 60 and age < 35);")
|
|
|
+ "select person.name "
|
|
|
+ "from person join class "
|
|
|
+ "on age=22 and class.id=person.classId;")
|
|
|
.out;
|
|
|
}
|
|
|
|