1
0

22 Commits 01d7f9acd6 ... 4dd66e21f8

Autor SHA1 Nachricht Datum
  myuan 4dd66e21f8 select table 修改为 表列表 vor 2 Jahren
  myuan 40a2dedf6f show rebuild vor 2 Jahren
  myuan 3d25c98cd2 check where vor 2 Jahren
  myuan 7c608695e2 添加更多监视和测试 vor 2 Jahren
  myuan 591c80b4ce 添加drop vor 2 Jahren
  myuan a9a13f8579 初级select vor 2 Jahren
  myuan 0eb11c9d94 删除报错 vor 2 Jahren
  myuan f54cdcc5b1 添加包 vor 2 Jahren
  myuan 3e4e7c83e2 update vor 2 Jahren
  myuan 0faa15ac4d 创建表相关的utils和JSON构造 vor 2 Jahren
  myuan ab790f9351 解析命令行和新建表 vor 2 Jahren
  myuan 3c1f33616b del print vor 2 Jahren
  myuan e5a969dae9 添加包 vor 2 Jahren
  myuan 3e79d1ee4a 输出stderr vor 2 Jahren
  myuan 09c88d99cb 表相关操作 vor 2 Jahren
  myuan 89020f77bc 修改监视 vor 2 Jahren
  myuan 6c47ebee34 parse_sql vor 2 Jahren
  myuan d1363f6a11 修改部分命名 vor 2 Jahren
  myuan 6bfda18d2a del vor 2 Jahren
  myuan 081c87ce02 添加drop测试 vor 2 Jahren
  myuan 7c8f8b20d5 拆解测试代码 vor 2 Jahren
  myuan 0f26f06095 rename vor 2 Jahren
11 geänderte Dateien mit 541 neuen und 172 gelöschten Zeilen
  1. 2 0
      .gitignore
  2. 137 0
      run_sql_parser_test.py
  3. 160 0
      src/checker.cpp
  4. 0 0
      src/checker.h
  5. 0 7
      src/main.c
  6. 1 0
      src/parser.l
  7. 40 12
      src/parser.y
  8. 76 0
      src/utils.cpp
  9. 48 0
      src/utils.h
  10. 69 150
      tests_config.py
  11. 8 3
      xmake.lua

+ 2 - 0
.gitignore

@@ -9,3 +9,5 @@ build/
 *.tab.*
 *.yy.*
 .vscode
+*.pyc
+tables.json

+ 137 - 0
run_sql_parser_test.py

@@ -0,0 +1,137 @@
+from pathlib import Path
+import asyncio.subprocess as subprocess
+import asyncio
+from watchfiles import awatch
+from termcolor import colored
+from datetime import datetime
+import orjson
+import os
+import tempfile
+import tests_config
+import importlib
+importlib.reload(tests_config)
+
+sql_parser_tests, sql_checker_tests = tests_config.sql_parser_tests, tests_config.sql_checker_tests
+
+
+async def run_and_output(
+    *args: str, timeout=10
+) -> tuple[bytes, bytes]:
+    p = await subprocess.create_subprocess_exec(
+        *args,
+        stdout=subprocess.PIPE,
+        stderr=subprocess.PIPE,
+    )
+    stdout, stderr = await asyncio.wait_for(p.communicate(), timeout=timeout)
+    return stdout, stderr
+
+async def rebuild() -> bool:
+    print(datetime.now(), colored('rebuild...', "grey"))
+    stdout, _ = await run_and_output('xmake')
+    if b"error" in stdout:
+        print(stdout.decode("utf-8"))
+        print(datetime.now(), "-" * 40)
+        return False
+    else:
+        return True
+
+async def assert_sql(sql: str, expected: dict):
+    stdout, stderr = await run_and_output('xmake', 'run', "sql-parser", sql)
+
+    if b"error" in stdout:
+        print(stdout.decode("utf-8"))
+        print(datetime.now(), "-" * 40)
+        print(f'other: {colored(stderr.decode("utf-8"), "yellow")}')
+        assert False, "sql-parser error"
+
+    try:
+        output = orjson.loads(stdout)
+    except Exception as e:
+        output = {"error": e, "output": stdout.decode("utf-8")}
+    open("/tmp/temp/test.py", "wb").write(
+        f'"{sql}"\n\n'.encode("utf-8")
+        + orjson.dumps(output, option=orjson.OPT_INDENT_2)
+        + (b"\n\n" + stderr).replace(b"\n", b"\n# ")
+    )
+    assert (
+        output == expected
+    ), f"""{colored("sql-parser error", "red")}
+input: {colored(sql, "yellow")}
+expect: {colored(expected, "green")}
+actual: {colored(output, "red")}
+other: {colored(stderr.decode("utf-8"), "yellow")}
+
+"""
+
+
+async def assert_sqls():
+    for sql, excepted in sql_parser_tests:
+        await assert_sql(sql, excepted)
+
+
+async def on_parser_modified():
+    print(datetime.now(), colored("run parser tests...", "yellow"))
+
+    try:
+        await assert_sqls()
+    except Exception as e:
+        print(e)
+    else:
+        print(datetime.now(), colored("all parser tests right!", "green"))
+
+
+async def assert_checks():
+    for sql, res in sql_checker_tests:
+        stdout, stderr = await run_and_output(
+            'xmake', 'run', "sql-checker", 
+            "-s", sql
+        )
+        print(sql, res)
+        if res is True:
+            assert b'error' not in stdout, stdout.decode("utf-8")
+            assert b'error' not in stderr, stderr.decode('utf-8')
+        elif isinstance(res, str):
+            res = res.encode('utf-8') 
+            assert res in stderr, stderr.decode("utf-8")
+        else:
+            assert False, f"{res} 不是合适的结果"
+
+async def on_checker_modified():
+    print(datetime.now(), colored("run checker tests...", "yellow"))
+    try:
+        await assert_checks()
+    except Exception as e:
+        print(e)
+    print(datetime.now(), colored("all checker tests right!", "green"))
+
+
+async def restart():
+    async for _ in awatch(__file__, "./tests_config.py"):
+        print("restart")
+        os.execl("/bin/python", Path(__file__).as_posix(), Path(__file__).as_posix())
+
+
+async def watch_parser():
+    async for changes in awatch("./src/parser.y", "./src/parser.l"):
+        if await rebuild():
+            await asyncio.wait_for(on_parser_modified(), 10)
+
+
+async def watch_checker():
+    async for changes in awatch("./src/checker.cpp", "./src/checker.h", "./src/utils.h", "./src/utils.cpp"):
+        if await rebuild():
+            await on_checker_modified()
+
+
+async def main():
+    await asyncio.gather(
+        restart(),
+        watch_parser(),
+        watch_checker(),
+        on_parser_modified(),
+        on_checker_modified(),
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())

+ 160 - 0
src/checker.cpp

@@ -0,0 +1,160 @@
+#include <fmt/core.h>
+#include <stdio.h>
+
+#include <CLI/CLI.hpp>
+#include <optional>
+
+#include "utils.h"
+
+using json = nlohmann::json;
+using std::nullopt;
+using std::optional;
+using std::string;
+
+auto tables = ExistTables("./tables.json");
+
+void create_table(const json& j) {
+    auto table_name = table_name_of(j);
+
+    if (tables.exists(table_name)) {
+        throw std::runtime_error(
+            fmt::format("table `{}` exists\n", table_name));
+    }
+    tables.set(table_name, j["cols"]);
+    tables.save();
+}
+
+optional<TableCol> get_select_column_define(const string table_name,
+                                           const json& select_col) {
+    auto type = type_of(select_col);
+    if (type == "identifier") {
+        auto cols = tables[table_name];
+        string colname = select_col["value"];
+        for (auto& col : cols) {
+            if (col.column_name == colname) {
+                return col;
+            }
+        }
+        return nullopt;
+    } else if (type == "int" || type == "string" || type == "float") {
+        return TableCol{.column_name = select_col["value"].dump(),
+                        .data_type = type,
+                        .type = "const",
+                        .primary_key = false};
+    } else if (type == "select_all_column") {
+        return TableCol{.column_name = "*",
+                        .data_type = "",
+                        .type = "select_all_column",
+                        .primary_key = false};
+    }
+    return nullopt;
+}
+
+void check_where_clause(const string table_name, const json& j) {
+    if (j.empty()) {
+        return;
+    }
+
+    auto type = type_of(j);
+    auto left = j["left"];
+    auto right = j["right"];
+    for(const auto curr_branch_name : {"left", "right"}) {
+        auto curr_branch = j[curr_branch_name];
+        if (left.contains("value")) {
+            // 说明到叶节点了
+            auto col_def = get_select_column_define(table_name, curr_branch);
+            if (!col_def.has_value()) {
+                throw std::runtime_error(fmt::format(
+                    "column `{}` not exists in `{}`\n",
+                    curr_branch["value"].dump(2), table_name));
+            }
+        } else {
+            check_where_clause(table_name, curr_branch);
+        }
+    }
+}
+
+void check_select_table(const string table_name, const json& select_cols) {
+    for (auto& select_col : select_cols) {
+        if (select_col["type"] == "select_all_column") {
+            continue;
+        }
+        auto col_def =
+            get_select_column_define(table_name, select_col["target"]);
+
+        if (!col_def.has_value()) {
+            throw std::runtime_error(
+                fmt::format("column `{}` not exists in `{}`\n",
+                            select_col["target"]["value"].dump(2), table_name));
+        }
+    }
+}
+
+void process_sql(json& j) {
+    let type = j.value("type", "none");
+
+    if (type == "create_stmt") {
+        // 创建表
+        create_table(j);
+    } else {
+        if (type == "drop_stmt") {
+            if (tables.exists(j["table_name"])) {
+                tables.remove(j["table_name"]);
+                tables.save();
+            }
+            return;
+        }
+
+        // 其余都会要求表存在
+        let table_name = table_name_of(j);
+        if (!tables.exists(table_name)) {
+            fmt::print("error: table `{}` does not exist\n", table_name);
+            return;
+        }
+
+        if (type == "select_stmt") {
+            check_select_table(table_name, j["select_cols"]);
+            check_where_clause(table_name, j["where"]);
+        } else {
+            throw std::runtime_error(fmt::format(
+                "Unknown expression type `{}` total: \n{}", type, j.dump(2)));
+        }
+    }
+}
+
+int main(int argc, char** argv) {
+    CLI::App app{"sql 检查器, 附带创建表相关功能"};
+    string sql = "", filename = "";
+
+    app.add_option("-f,--file", filename, "输入的 sql 解析后的 json 文件")
+        ->check(CLI::ExistingFile);
+    app.add_option("-s,--sql", sql, "输入的 sql");
+
+    CLI11_PARSE(app, argc, argv);
+
+    json stmts;
+    if (sql.length() > 0) {
+        stmts = parse_sql(sql).out;
+    } else if (filename.length() > 0) {
+        stmts = json::parse(std::ifstream(filename));
+    } else {
+        std::cout << "请输入 sql 或者 sql 文件" << std::endl;
+        return 1;
+    }
+
+    for (auto& stmt : stmts) {
+        // for(auto& item : stmt.items()) {
+        //     fmt::print("{}: {}\n", item.key(), item.value().dump());
+        // }
+        process_sql(stmt);
+    }
+
+    // auto t = res.out[0].value("type", "default");
+    // fmt::print("{}\n", t);
+
+    // tables.set("t1", {"a", "b", "c"});
+    // tables.set("t3", {"asdge", "safaw", "qwer"});
+    // fmt::print("tables.exists {}\n", tables.exists("t1"));
+    // tables.save();
+    return 0;
+}

+ 0 - 0
src/checker.h


+ 0 - 7
src/main.c

@@ -1,7 +0,0 @@
-#include <stdio.h>
-
-int main(int argc, char** argv)
-{
-    printf("hello world!\n");
-    return 0;
-}

+ 1 - 0
src/parser.l

@@ -36,6 +36,7 @@ extern YYSTYPE yylval;
 "SET"		{return SET;}
 "JOIN"      {return JOIN;}
 "TABLE"		{return TABLE;}
+"DROP"		{return DROP;}
 
 "INT"       {cp_yylval_and_return(INT_T);}
 "FLOAT"     {cp_yylval_and_return(FLOAT_T);}

+ 40 - 12
src/parser.y

@@ -59,7 +59,7 @@ cJSON* jroot;
 
 %token IDENTIFIER 
 
-%token SELECT FROM WHERE INSERT INTO VALUES DELETE UPDATE SET JOIN CREATE TABLE 
+%token SELECT FROM WHERE INSERT INTO VALUES DELETE UPDATE SET JOIN CREATE TABLE DROP
 %token AS ON
 %token AND OR NOT IN
 %token INT_V FLOAT_V STRING_V // 作为 value 出现的
@@ -72,16 +72,17 @@ cJSON* jroot;
 
 %type <iv> INT_V
 %type <fv> FLOAT_V
-%type <sv> STRING_V NOT
+%type <sv> STRING_V NOT table_name
 %type <sv> IDENTIFIER data_type PRIMARY_KEY col_options bin_cmp_op bin_logical_op unary_compare_op bin_contains_op
 %type <jv> create_definition create_col_list create_table_stmt data_value
 %type <jv> insert_stmt insert_list 
 %type <jv> update_stmt update_list single_assign_item
 %type <jv> where_condition_item identifier identifier_or_const_value 
-%type <jv> delete_stmt select_stmt select_item select_items 
+%type <jv> delete_stmt select_stmt select_item select_items drop_stmt
 %type <jv> data_value_list identifier_or_const_value_or_const_value_list
-%type <jv> search_expr compare_expr single_expr expr where_expr logical_expr negative_expr op_where_expr expr_list contains_expr
+%type <jv> compare_expr single_expr where_expr logical_expr negative_expr op_where_expr expr_list contains_expr
 %type <jv> op_join table_field column_name
+%type <jv> table_name_list
 
 
 %left OR
@@ -109,12 +110,13 @@ sql_statement: create_table_stmt NEWLINE {cJSON_AddItemToArray(jroot, $1);}
 	| update_stmt NEWLINE {cJSON_AddItemToArray(jroot, $1);}
 	| delete_stmt NEWLINE {cJSON_AddItemToArray(jroot, $1);}
 	| select_stmt NEWLINE {cJSON_AddItemToArray(jroot, $1);}
+	| drop_stmt   NEWLINE {cJSON_AddItemToArray(jroot, $1);}
 ;
 
 create_table_stmt: CREATE TABLE IDENTIFIER
 	{
 		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", "create_table");
+		cJSON_AddStringToObject(node, "type", "create_stmt");
 		cJSON_AddStringToObject(node, "table_name", $3);
 		cJSON_AddItemToObject(node, "cols", cJSON_CreateArray());
 		$$ = node;
@@ -123,7 +125,7 @@ create_table_stmt: CREATE TABLE IDENTIFIER
 
 create_table_stmt: CREATE TABLE IDENTIFIER '(' create_col_list ')' {
 		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", "create_table");
+		cJSON_AddStringToObject(node, "type", "create_stmt");
 		cJSON_AddStringToObject(node, "table_name", $3);
 		cJSON_AddItemToObject(node, "cols", $5);
 		$$ = node;
@@ -169,7 +171,7 @@ data_type: INT_T
 
 insert_stmt: INSERT INTO IDENTIFIER VALUES '(' insert_list ')' {
 		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", "insert");
+		cJSON_AddStringToObject(node, "type", "insert_stmt");
 		cJSON_AddStringToObject(node, "table_name", $3);
 		cJSON_AddItemToObject(node, "values", $6);
 		$$=node;
@@ -194,7 +196,7 @@ data_value: INT_V {SIMPLE_TYPE_VALUE_OBJECT($$, int, Number, $1);}
 
 update_stmt: UPDATE IDENTIFIER SET update_list WHERE where_expr {
 		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", "update");
+		cJSON_AddStringToObject(node, "type", "update_stmt");
 		cJSON_AddStringToObject(node, "table_name", $2);
 		cJSON_AddItemToObject(node, "set", $4);
 		cJSON_AddItemToObject(node, "where", $6);
@@ -363,7 +365,7 @@ bin_contains_op: IN {$$ = "包含于";}
 
 delete_stmt: DELETE FROM IDENTIFIER op_where_expr {
 		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", "delete");
+		cJSON_AddStringToObject(node, "type", "delete_stmt");
 		cJSON_AddStringToObject(node, "table_name", $3);
 		cJSON_AddItemToObject(node, "where", $4);
 		$$=node;
@@ -380,11 +382,29 @@ op_join: {$$ = NULL;}
 	}
 ;
 
-select_stmt: SELECT select_items FROM IDENTIFIER op_join op_where_expr {
+table_name: IDENTIFIER {$$=$1;}
+;
+
+table_name_list: table_name {
+		MEET_VAR(table_name, $1);
+		cJSON* node = cJSON_CreateArray();
+		cJSON_AddItemToArray(node, cJSON_CreateString($1));
+		$$=node;
+	}
+	| table_name_list ',' table_name {
+		MEET_VAR(table_name, $3);
+		MEET_VAR(table_name_list, $1);
+
+		cJSON_AddItemToArray($1, $3);
+		$$=$1;
+	}
+;
+
+select_stmt: SELECT select_items FROM table_name_list op_join op_where_expr {
 		cJSON* node = cJSON_CreateObject();
-		cJSON_AddStringToObject(node, "type", "select");
+		cJSON_AddStringToObject(node, "type", "select_stmt");
 		cJSON_AddItemToObject(node, "select_cols", $2);
-		cJSON_AddStringToObject(node, "table_name", $4);
+		cJSON_AddItemToObject(node, "table_names", $4);
 		if ($5 != NULL) {
 			cJSON_AddItemToObject(node, "join_options", $5);
 		}
@@ -424,6 +444,14 @@ select_item: single_expr {
 	}
 ;
 
+drop_stmt: DROP TABLE IDENTIFIER {
+		cJSON* node = cJSON_CreateObject();
+		cJSON_AddStringToObject(node, "type", "drop_stmt");
+		cJSON_AddStringToObject(node, "table_name", $3);
+		$$=node;
+	}
+;
+
 %%
 
 int main(int ac, char** av) {

+ 76 - 0
src/utils.cpp

@@ -0,0 +1,76 @@
+#include "utils.h"
+
+#include <fmt/core.h>
+
+#include <fstream>
+#include <map>
+#include <nlohmann/json.hpp>
+#include <string>
+#include <vector>
+
+using json = nlohmann::json;
+
+SQLParserRes parse_sql(const std::string& sql) {
+    SQLParserRes res;
+    // auto stdout_name = std::tmpnam(nullptr);
+    // auto stderr_name = std::tmpnam(nullptr);
+    auto stdout_name = "/tmp/sql-parser-stdout";
+    auto stderr_name = "/tmp/sql-parser-stderr";
+
+    auto cmd = fmt::format("xmake run sql-parser \"{}\" > {} 2> {}", sql,
+                           stdout_name, stderr_name);
+    res.exit_code = system(cmd.c_str());
+
+    std::ifstream stdout_f(stdout_name);
+    json stdout_j = json::parse(stdout_f);
+    res.out = stdout_j;
+
+    std::ifstream stderr_f(stderr_name);
+
+    res.err.assign((std::istreambuf_iterator<char>(stderr_f)),
+                   (std::istreambuf_iterator<char>()));
+
+    return res;
+}
+
+ExistTables::ExistTables(const char* file_name) {
+    table_file_name = file_name;
+    std::ifstream ifile(table_file_name);
+
+    if (!ifile.good()) {
+        std::ofstream ofile(table_file_name);
+        ofile << "{}";
+        ofile.close();
+    }
+    ifile.close();
+    read_from_file();
+}
+ExistTables* ExistTables::read_from_file() {
+    std::ifstream f(table_file_name);
+    json j = json::parse(f);
+    for (auto& item : j.items()) {
+        tables[item.key()] = item.value();
+    }
+    return this;
+};
+
+bool ExistTables::exists(const std::string& table_name) {
+    return tables.find(table_name) != tables.end();
+}
+void ExistTables::set(const std::string& table_name, const TableCols& cols) {
+    tables[table_name] = cols;
+}
+TableCols ExistTables::operator[](const std::string& table_name) {
+    return tables[table_name];
+}
+void ExistTables::remove(const std::string& table_name) {
+    tables.erase(table_name);
+}
+void ExistTables::save() {
+    json j;
+    for (auto& item : tables) {
+        j[item.first] = item.second;
+    }
+    std::ofstream f(table_file_name);
+    f << j.dump(2);
+}

+ 48 - 0
src/utils.h

@@ -0,0 +1,48 @@
+#pragma once
+
+#include <map>
+#include <nlohmann/json.hpp>
+#include <vector>
+
+#define let const auto
+
+struct SQLParserRes {
+    int exit_code;
+    nlohmann::json out;
+    std::string err;
+};
+
+struct TableCol {
+    std::string column_name;
+    std::string type;
+    std::string data_type;
+    bool primary_key;
+    NLOHMANN_DEFINE_TYPE_INTRUSIVE(TableCol, column_name, type, data_type,
+                                   primary_key);
+};
+
+using TableCols = std::vector<TableCol>;
+
+SQLParserRes parse_sql(const std::string& sql);
+
+class ExistTables {
+   public:
+    std::string table_file_name;
+    std::map<std::string, TableCols> tables;
+
+    ExistTables(const char* file_name = "./tables.json");
+    ExistTables* read_from_file();
+    bool exists(const std::string& table_name);
+    void set(const std::string& table_name, const TableCols& cols);
+    TableCols operator[](const std::string& table_name);
+    void remove(const std::string& table_name);
+    void save();
+};
+
+inline auto table_name_of(const nlohmann::json& j) {
+    return j.value("table_name", "none");
+}
+inline auto type_of(const nlohmann::json& j) {
+    return j.value("type", "none");
+}
+

+ 69 - 150
run_test.py → tests_config.py

@@ -1,64 +1,10 @@
-from pathlib import Path
-import sys
-import time
-from watchdog.observers import Observer
-from watchdog.events import FileSystemEventHandler
-from watchdog.events import LoggingEventHandler
-import asyncio.subprocess as subprocess
-import asyncio
-from watchfiles import awatch, watch
-from termcolor import colored
-from datetime import datetime
-import orjson
-import os
-
-
-async def assert_sql(sql, target):
-    p = await subprocess.create_subprocess_exec(
-        "xmake",
-        "run",
-        "sql-parser",
-        sql,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE,
-    )
-    stdout, stderr = await asyncio.wait_for(p.communicate(), timeout=5)
-
-    if b"error" in stdout:
-        print(stdout.decode("utf-8"))
-        print(datetime.now(), "-" * 40)
-        print(f'other: {colored(stderr.decode("utf-8"), "yellow")}')
-        assert False, "sql-parser error"
-
-    try:
-        output = orjson.loads(stdout)
-    except Exception as e:
-        output = {"error": e, "output": stdout.decode("utf-8")}
-    open("/tmp/temp/test.py", "wb").write(
-        f'"{sql}"\n\n'.encode("utf-8")
-        + orjson.dumps(output, option=orjson.OPT_INDENT_2)
-        + (b"\n\n" + stderr).replace(b"\n", b"\n# ")
-    )
-    assert (
-        output == target
-    ), f"""{colored("sql-parser error", "red")}
-input: {colored(sql, "yellow")}
-expect: {colored(target, "green")}
-actual: {colored(output, "red")}
-other: {colored(stderr.decode("utf-8"), "yellow")}
-
-"""
-
-
-async def assert_sqls():
-    await assert_sql(
-        "create table asd;", [{"type": "create_table", "table_name": "asd", "cols": []}]
-    )
-    await assert_sql(
+sql_parser_tests = [
+    ("create table asd;", [{"type": "create_stmt", "table_name": "asd", "cols": []}]),
+    (
         "create table tb (col1 INT, col2 string, col3 FLOAT);",
         [
             {
-                "type": "create_table",
+                "type": "create_stmt",
                 "table_name": "tb",
                 "cols": [
                     {
@@ -82,8 +28,8 @@ async def assert_sqls():
                 ],
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """
     create table tb1 (
         col1 int primary key, 
@@ -92,7 +38,7 @@ async def assert_sqls():
     """,
         [
             {
-                "type": "create_table",
+                "type": "create_stmt",
                 "table_name": "tb1",
                 "cols": [
                     {
@@ -110,8 +56,8 @@ async def assert_sqls():
                 ],
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """
         create table tb2 (
             x float,
@@ -121,7 +67,7 @@ async def assert_sqls():
         """,
         [
             {
-                "type": "create_table",
+                "type": "create_stmt",
                 "table_name": "tb2",
                 "cols": [
                     {
@@ -145,12 +91,12 @@ async def assert_sqls():
                 ],
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """insert into tb1 values (1, 'foo');""",
         [
             {
-                "type": "insert",
+                "type": "insert_stmt",
                 "table_name": "tb1",
                 "values": [
                     {"type": "int", "value": 1},
@@ -158,12 +104,12 @@ async def assert_sqls():
                 ],
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """insert into tb1 values (2, 'foo', 'zxc', 1234.234);""",
         [
             {
-                "type": "insert",
+                "type": "insert_stmt",
                 "table_name": "tb1",
                 "values": [
                     {"type": "int", "value": 2},
@@ -173,13 +119,12 @@ async def assert_sqls():
                 ],
             }
         ],
-    )
-
-    await assert_sql(
+    ),
+    (
         "update tb1 set col1=3, col4=4 where col1=2 and col2=4;",
         [
             {
-                "type": "update",
+                "type": "update_stmt",
                 "table_name": "tb1",
                 "set": [
                     {
@@ -208,12 +153,12 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         "update tb1 set col1=3, col4=4 where not not not col1=2 and col2=4 or col3=col2;",
         [
             {
-                "type": "update",
+                "type": "update_stmt",
                 "table_name": "tb1",
                 "set": [
                     {
@@ -259,12 +204,12 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         "delete from tb1 where c1 = 1 and c2= 3 or c3=3;",
         [
             {
-                "type": "delete",
+                "type": "delete_stmt",
                 "table_name": "tb1",
                 "where": {
                     "type": "或",
@@ -289,12 +234,12 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         "delete from tb1 where c1 = 1 and (c2= 3 or c3=3) or (c4='asd');",
         [
             {
-                "type": "delete",
+                "type": "delete_stmt",
                 "table_name": "tb1",
                 "where": {
                     "type": "或",
@@ -327,23 +272,23 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         "select * from t2;",
         [
             {
-                "type": "select",
+                "type": "select_stmt",
                 "select_cols": [{"type": "select_all_column"}],
-                "table_name": "t2",
+                "table_names": ["t2"],
                 "where": {},
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         "select c2 as t from t2 where col1>2;",
         [
             {
-                "type": "select",
+                "type": "select_stmt",
                 "select_cols": [
                     {
                         "type": "select_column",
@@ -351,7 +296,7 @@ async def assert_sqls():
                         "alias": "t",
                     }
                 ],
-                "table_name": "t2",
+                "table_names": ["t2"],
                 "where": {
                     "type": "大于",
                     "left": {"type": "identifier", "value": "col1"},
@@ -359,19 +304,19 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """SELECT Sname FROM Student WHERE Sno IN (1,2) and c in (3, 4, 5);""",
         [
             {
-                "type": "select",
+                "type": "select_stmt",
                 "select_cols": [
                     {
                         "type": "select_column",
                         "target": {"type": "identifier", "value": "Sname"},
                     }
                 ],
-                "table_name": "Student",
+                "table_names": ["Student"],
                 "where": {
                     "type": "且",
                     "left": {
@@ -394,9 +339,8 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-
-    await assert_sql(
+    ),
+    (
         """SELECT Student.Sname
             FROM Student
             WHERE Sno IN (
@@ -407,7 +351,7 @@ async def assert_sqls():
         """,
         [
             {
-                "type": "select",
+                "type": "select_stmt",
                 "select_cols": [
                     {
                         "type": "select_column",
@@ -418,19 +362,19 @@ async def assert_sqls():
                         },
                     }
                 ],
-                "table_name": "Student",
+                "table_names": ["Student"],
                 "where": {
                     "type": "包含于",
                     "left": {"type": "identifier", "value": "Sno"},
                     "right": {
-                        "type": "select",
+                        "type": "select_stmt",
                         "select_cols": [
                             {
                                 "type": "select_column",
                                 "target": {"type": "identifier", "value": "Sno"},
                             }
                         ],
-                        "table_name": "SC",
+                        "table_names": ["SC"],
                         "where": {
                             "type": "相等",
                             "left": {
@@ -444,8 +388,8 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """
         select Student.Sname 
         from Student join SC 
@@ -454,7 +398,7 @@ async def assert_sqls():
         """,
         [
             {
-                "type": "select",
+                "type": "select_stmt",
                 "select_cols": [
                     {
                         "type": "select_column",
@@ -465,7 +409,7 @@ async def assert_sqls():
                         },
                     }
                 ],
-                "table_name": "Student",
+                "table_names": ["Student"],
                 "join_options": {
                     "type": "join_options",
                     "join_with": {"type": "identifier", "value": "SC"},
@@ -486,8 +430,8 @@ async def assert_sqls():
                 },
             }
         ],
-    )
-    await assert_sql(
+    ),
+    (
         """
         select Student.Sname 
         from   Student join SC 
@@ -496,7 +440,7 @@ async def assert_sqls():
         """,
         [
             {
-                "type": "select",
+                "type": "select_stmt",
                 "select_cols": [
                     {
                         "type": "select_column",
@@ -507,7 +451,7 @@ async def assert_sqls():
                         },
                     }
                 ],
-                "table_name": "Student",
+                "table_names": ["Student"],
                 "join_options": {
                     "type": "join_options",
                     "join_with": {"type": "identifier", "value": "SC"},
@@ -540,46 +484,21 @@ async def assert_sqls():
                 "where": {},
             }
         ],
-    )
-
-
-async def on_modified(event):
-    p = await subprocess.create_subprocess_shell(
-        "xmake", stdout=subprocess.PIPE, stderr=subprocess.PIPE
-    )
-    stdout, _ = await p.communicate()
-    if b"error" in stdout:
-        print(stdout.decode("utf-8"))
-        print(datetime.now(), "-" * 40)
-        return
-
-    try:
-        await assert_sqls()
-    except Exception as e:
-        print(e)
-    else:
-        print(datetime.now(), colored("all tests right!", "green"))
-
-
-async def restart():
-    async for _ in awatch(__file__):
-        print("restart")
-        os.execl("/bin/python", Path(__file__).as_posix(), Path(__file__).as_posix())
-
-
-async def watch_src():
-    async for changes in awatch("src"):
-        print(datetime.now(), "re run...")
-        await asyncio.wait_for(on_modified(changes), 10)
-
+    ),
+    ('drop table t1;', [{"type": "drop_stmt", "table_name": "t1"}]),
+]
 
-async def main():
-    try:
-        await assert_sqls()
-    except Exception as e:
-        print(e)
-    await asyncio.gather(restart(), watch_src())
+sql_checker_tests = [
+    ('drop table person;', True),
+    ('create table person(name string, age int, classId int);', True),
+    ('select age from person;', True),
+    ('select * from person;', True),
+    ('select gender from person;', 'column `"gender"` not exists in `person`'),
+    ('select 123 from person;', True),
 
+    ('drop table class;', True),
+    ('create table class (id int, grade int, faculty string);', True),
+    ('select * from class where grade = 2 and faculty = \'Computer Science\';', True),
+    ('select * from class where grade = 2 and count=33;', 'column `"count"` not exists in `class`'),
 
-if __name__ == "__main__":
-    asyncio.run(main())
+]

+ 8 - 3
xmake.lua

@@ -1,6 +1,6 @@
 add_rules("mode.debug", "mode.release")
-add_requires("cjson")
-add_requires("fmt")
+add_requires("cjson", "nlohmann_json")
+add_requires("fmt", "cli11")
 
 target("sql-parser")
     add_rules("lex", "yacc")
@@ -14,10 +14,15 @@ target_end()
 
 target("sql-checker")
     set_languages("c++20")
+    set_toolset("cxx", "clang++")
+
     set_kind("binary")
     add_files("src/*.cpp")
     add_includedirs("src")
-    add_packages("fmt")
+    add_packages("fmt", "nlohmann_json", "cli11")
+    -- add_cxxflags("-ftime-trace", {force = true})
+    -- set_optimize("none")
+
 target_end()
 
 --