From cf726ab1778bd6e801cdd7aa18c0a8a9633f30b2 Mon Sep 17 00:00:00 2001 From: Roman Hergenreder Date: Thu, 2 Apr 2020 15:08:14 +0200 Subject: [PATCH] Postgres Support --- core/Api/GetApiKeys.class.php | 3 +- core/Api/RefreshApiKey.class.php | 3 +- core/Api/RevokeApiKey.class.php | 3 +- core/Documents/Install.class.php | 4 +- core/Driver/SQL/MySQL.class.php | 41 ++++++-- core/Driver/SQL/PostgreSQL.class.php | 143 ++++++++++++++++++++++++--- core/Driver/SQL/SQL.class.php | 8 +- core/Objects/User.class.php | 5 +- test/apiTest.py | 21 ++++ test/installTest.py | 41 +++----- test/phpTest.py | 36 +++++++ test/requirements.txt | 1 + test/test.py | 45 ++++++++- 13 files changed, 288 insertions(+), 66 deletions(-) create mode 100644 test/apiTest.py create mode 100644 test/phpTest.py diff --git a/core/Api/GetApiKeys.class.php b/core/Api/GetApiKeys.class.php index 8d53594..c391188 100644 --- a/core/Api/GetApiKeys.class.php +++ b/core/Api/GetApiKeys.class.php @@ -2,7 +2,6 @@ namespace Api; -use \Driver\SQL\Keyword; use \Driver\SQL\Condition\Compare; class GetApiKeys extends Request { @@ -21,7 +20,7 @@ class GetApiKeys extends Request { $res = $sql->select("uid", "api_key", "valid_until") ->from("ApiKey") ->where(new Compare("user_id", $this->user->getId())) - ->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">")) + ->where(new Compare("valid_until", $sql->currentTimestamp(), ">")) ->where(new Compare("active", true)) ->execute(); diff --git a/core/Api/RefreshApiKey.class.php b/core/Api/RefreshApiKey.class.php index 3a24a30..f584f6b 100644 --- a/core/Api/RefreshApiKey.class.php +++ b/core/Api/RefreshApiKey.class.php @@ -3,7 +3,6 @@ namespace Api; use \Api\Parameter\Parameter; -use \Driver\SQL\Keyword; use \Driver\SQL\Condition\Compare; class RefreshApiKey extends Request { @@ -23,7 +22,7 @@ class RefreshApiKey extends Request { ->from("ApiKey") ->where(new Compare("uid", $id)) ->where(new Compare("user_id", $this->user->getId())) - ->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">")) + ->where(new Compare("valid_until", $sql->currentTimestamp(), ">")) ->where(new Compare("active", 1)) ->execute(); diff --git a/core/Api/RevokeApiKey.class.php b/core/Api/RevokeApiKey.class.php index f2b8448..a769449 100644 --- a/core/Api/RevokeApiKey.class.php +++ b/core/Api/RevokeApiKey.class.php @@ -3,7 +3,6 @@ namespace Api; use \Api\Parameter\Parameter; -use \Driver\SQL\Keyword; use \Driver\SQL\Condition\Compare; class RevokeApiKey extends Request { @@ -23,7 +22,7 @@ class RevokeApiKey extends Request { ->from("ApiKey") ->where(new Compare("uid", $id)) ->where(new Compare("user_id", $this->user->getId())) - ->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">")) + ->where(new Compare("valid_until", $sql->currentTimestamp(), ">")) ->where(new Compare("active", 1)) ->execute(); diff --git a/core/Documents/Install.class.php b/core/Documents/Install.class.php index 3f0edc8..1189317 100644 --- a/core/Documents/Install.class.php +++ b/core/Documents/Install.class.php @@ -97,7 +97,9 @@ namespace Documents\Install { return self::DATABASE_CONFIGURATION; } - $res = $user->getSQL()->select("COUNT(*) as count")->from("User")->execute(); + $sql = $user->getSQL(); + $countKeyword = $sql->count(); + $res = $sql->select($countKeyword)->from("User")->execute(); if ($res === FALSE) { return self::DATABASE_CONFIGURATION; } else { diff --git a/core/Driver/SQL/MySQL.class.php b/core/Driver/SQL/MySQL.class.php index 2e2e5f2..b72f03e 100644 --- a/core/Driver/SQL/MySQL.class.php +++ b/core/Driver/SQL/MySQL.class.php @@ -26,14 +26,13 @@ class MySQL extends SQL { public function __construct($connectionData) { parent::__construct($connectionData); - $this->installLink = ; } public function checkRequirements() { return function_exists('mysqli_connect'); } - public abstract function getDriverName() { + public function getDriverName() { return 'mysqli'; } @@ -248,7 +247,16 @@ class MySQL extends SQL { public function executeSelect($select) { - $columns = implode(",", $select->getColumns()); + $columns = array(); + foreach($select->getColumns() as $col) { + if ($col instanceof Keyword) { + $columns[] = $col->getValue(); + } else { + $columns[] = "`$col`"; + } + } + + $columns = implode(",", $columns); $tables = $select->getTables(); $params = array(); @@ -364,7 +372,7 @@ class MySQL extends SQL { if (!is_null($column->getDefaultValue()) || !$column->notNull()) { $defaultValue = " DEFAULT " . $this->getValueDefinition($column->getDefaultValue()); } - + return "`$columnName` $type$notNull$defaultValue"; } @@ -416,7 +424,28 @@ class MySQL extends SQL { } } - public function currentTimestamp() { - return "NOW()"; + protected function tableName($table) { + return "`$table`"; } + + protected function columnName($col) { + if ($col instanceof KeyWord) { + return $col->getValue(); + } else { + return "`$col`"; + } + } + + public function currentTimestamp() { + return new KeyWord("NOW()"); + } + + public function count($col = NULL) { + if (is_null($col)) { + return new Keyword("COUNT(*)"); + } else { + return new Keyword("COUNT($col)"); + } + } + }; diff --git a/core/Driver/SQL/PostgreSQL.class.php b/core/Driver/SQL/PostgreSQL.class.php index ca0d108..90a3f43 100644 --- a/core/Driver/SQL/PostgreSQL.class.php +++ b/core/Driver/SQL/PostgreSQL.class.php @@ -137,7 +137,7 @@ class PostgreSQL extends SQL { // Querybuilder public function executeCreateTable($createTable) { - $tableName = $createTable->getTableName(); + $tableName = $this->tableName($createTable->getTableName()); $ifNotExists = $createTable->ifNotExists() ? " IF NOT EXISTS": ""; $entries = array(); @@ -156,13 +156,13 @@ class PostgreSQL extends SQL { } $entries = implode(",", $entries); - $query = "CREATE TABLE$ifNotExists \"$tableName\" ($entries)"; + $query = "CREATE TABLE$ifNotExists $tableName ($entries)"; return $this->execute($query); } public function executeInsert($insert) { - $tableName = $insert->getTableName(); + $tableName = $this->tableName($insert->getTableName()); $columns = $insert->getColumns(); $rows = $insert->getRows(); $onDuplicateKey = $insert->onDuplicateKey() ?? ""; @@ -173,11 +173,15 @@ class PostgreSQL extends SQL { } if (is_null($columns) || empty($columns)) { - $columns = ""; + $columnStr = ""; $numColumns = count($rows[0]); } else { $numColumns = count($columns); - $columns = " (\"" . implode("\", \"", $columns) . "\")"; + $columnStr = array(); + foreach($columns as $col) { + $columnStr[] = $this->columnName($col); + } + $columnStr = " (" . implode(",", $columnStr) . ")"; } $numRows = count($rows); @@ -196,7 +200,7 @@ class PostgreSQL extends SQL { $values = implode(",", $values); if ($onDuplicateKey) { - if ($onDuplicateKey instanceof UpdateStrategy) { + /*if ($onDuplicateKey instanceof UpdateStrategy) { $updateValues = array(); foreach($onDuplicateKey->getValues() as $key => $value) { if ($value instanceof Column) { @@ -208,7 +212,7 @@ class PostgreSQL extends SQL { } $onDuplicateKey = " ON CONFLICT DO UPDATE SET " . implode(",", $updateValues); - } else { + } else*/ { $strategy = get_class($onDuplicateKey); $this->lastError = "ON DUPLICATE Strategy $strategy is not supported yet."; return false; @@ -216,9 +220,9 @@ class PostgreSQL extends SQL { } $returningCol = $insert->getReturning(); - $returning = $returningCol ? " RETURNING \"$returningCol\"" : ""; + $returning = $returningCol ? (" RETURNING " . $this->columnName($returningCol)) : ""; - $query = "INSERT INTO \"$tableName\"$columns VALUES$values$onDuplicateKey$returning"; + $query = "INSERT INTO $tableName$columnStr VALUES$values$onDuplicateKey$returning"; $res = $this->execute($query, $parameters, !empty($returning)); $success = ($res !== FALSE); @@ -229,11 +233,93 @@ class PostgreSQL extends SQL { return $success; } - // TODO: - public function executeSelect($query) { } - public function executeDelete($query) { } - public function executeTruncate($query) { } - public function executeUpdate($query) { } + public function executeSelect($select) { + + $columns = array(); + foreach($select->getColumns() as $col) { + $columns[] = $this->columnName($col); + } + + $columns = implode(",", $columns); + $tables = $select->getTables(); + $params = array(); + + if (is_null($tables) || empty($tables)) { + return "SELECT $columns"; + } else { + $tableStr = array(); + foreach($tables as $table) { + $tableStr[] = $this->tableName($table); + } + $tableStr = implode(",", $tableStr); + } + + $conditions = $select->getConditions(); + if (!empty($conditions)) { + $condition = " WHERE " . $this->buildCondition($conditions, $params); + } else { + $condition = ""; + } + + $joinStr = ""; + $joins = $select->getJoins(); + if (!empty($joins)) { + $joinStr = ""; + foreach($joins as $join) { + $type = $join->getType(); + $joinTable = $this->tableName($join->getTable()); + $columnA = $this->columnName($join->getColumnA()); + $columnB = $this->columnName($join->getColumnB()); + $joinStr .= " $type JOIN $joinTable ON $columnA=$columnB"; + } + } + + $orderBy = ""; + $limit = ""; + $offset = ""; + + $query = "SELECT $columns FROM $tableStr$joinStr$condition$orderBy$limit$offset"; + return $this->execute($query, $params, true); + } + + public function executeDelete($delete) { + $table = $delete->getTable(); + $conditions = $delete->getConditions(); + if (!empty($conditions)) { + $condition = " WHERE " . $this->buildCondition($conditions, $params); + } else { + $condition = ""; + } + + $query = "DELETE FROM \"$table\"$condition"; + return $this->execute($query); + } + + public function executeTruncate($truncate) { + $table = $truncate->getTable(); + return $this->execute("TRUNCATE \"$table\""); + } + + public function executeUpdate($update) { + $params = array(); + $table = $update->getTable(); + + $valueStr = array(); + foreach($update->getValues() as $key => $val) { + $valueStr[] = "$key=" . $this->addValue($val, $params); + } + $valueStr = implode(",", $valueStr); + + $conditions = $update->getConditions(); + if (!empty($conditions)) { + $condition = " WHERE " . $this->buildCondition($conditions, $params); + } else { + $condition = ""; + } + + $query = "UPDATE \"$table\" SET $valueStr$condition"; + return $this->execute($query, $params); + } // UGLY but.. what should i do? private function createEnum($enumColumn) { @@ -344,9 +430,36 @@ class PostgreSQL extends SQL { } } + protected function tableName($table) { + return "\"$table\""; + } + + protected function columnName($col) { + if ($col instanceof KeyWord) { + return $col->getValue(); + } else { + $index = strrpos($col, "."); + if ($index === FALSE) { + return "\"$col\""; + } else { + $tableName = $this->tableName(substr($col, 0, $index)); + $columnName = $this->columnName(substr($col, $index + 1)); + return "$tableName.$columnName"; + } + } + } + // Special Keywords and functions public function currentTimestamp() { - return "CURRENT_TIMESTAMP"; + return new Keyword("CURRENT_TIMESTAMP"); + } + + public function count($col = NULL) { + if (is_null($col)) { + return new Keyword("COUNT(*)"); + } else { + return new Keyword("COUNT(\"$col\")"); + } } } ?> diff --git a/core/Driver/SQL/SQL.class.php b/core/Driver/SQL/SQL.class.php index 9529a9e..12dda38 100644 --- a/core/Driver/SQL/SQL.class.php +++ b/core/Driver/SQL/SQL.class.php @@ -72,8 +72,12 @@ abstract class SQL { protected abstract function getValueDefinition($val); protected abstract function addValue($val, &$params); + protected abstract function tableName($table); + protected abstract function columnName($col); + // Special Keywords and functions public abstract function currentTimestamp(); + public abstract function count($col = NULL); // Statements protected abstract function execute($query, $values=NULL, $returnValues=false); @@ -86,12 +90,12 @@ abstract class SQL { } return "(" . implode(" OR ", $conditions) . ")"; } else if ($condition instanceof \Driver\SQL\Condition\Compare) { - $column = $condition->getColumn(); + $column = $this->columnName($condition->getColumn()); $value = $condition->getValue(); $operator = $condition->getOperator(); return $column . $operator . $this->addValue($value, $params); } else if ($condition instanceof \Driver\SQL\Condition\CondBool) { - return $condition->getValue(); + return $this->columnName($condition->getValue()); } else if (is_array($condition)) { if (count($condition) == 1) { return $this->buildCondition($condition[0], $params); diff --git a/core/Objects/User.class.php b/core/Objects/User.class.php index f1e9d21..2558bee 100644 --- a/core/Objects/User.class.php +++ b/core/Objects/User.class.php @@ -2,7 +2,6 @@ namespace Objects; -use Driver\SQL\Keyword; use Driver\SQL\Column\Column; use Driver\SQL\Condition\Compare; use Driver\SQL\Condition\CondBool; @@ -113,7 +112,7 @@ class User extends ApiObject { ->where(new Compare("User.uid", $userId)) ->where(new Compare("Session.uid", $sessionId)) ->where(new Compare("Session.active", true)) - ->where(new CondBool("Session.stay_logged_in"), new Compare("Session.expires", new Keyword($this->sql->currentTimestamp()), '>')) + ->where(new CondBool("Session.stay_logged_in"), new Compare("Session.expires", $this->sql->currentTimestamp(), '>')) ->execute(); $success = ($res !== FALSE); @@ -189,7 +188,7 @@ class User extends ApiObject { ->innerJoin("User", "ApiKey.user_id", "User.uid") ->leftJoin("Language", "User.language_id", "Language.uid") ->where(new Compare("ApiKey.api_key", $apiKey)) - ->where(new Compare("valid_until", new Keyword($this->sql->currentTimestamp()), ">")) + ->where(new Compare("valid_until", $this->sql->currentTimestamp(), ">")) ->where(new COmpare("ApiKey.active", 1)) ->execute(); diff --git a/test/apiTest.py b/test/apiTest.py new file mode 100644 index 0000000..3b01207 --- /dev/null +++ b/test/apiTest.py @@ -0,0 +1,21 @@ +import requests +import json + +from phpTest import PhpTest + +class ApiTestCase(PhpTest): + + def __init__(self): + super().__init__("test_api") + self.session = requests.Session() + + def api(self, method): + return "%s/api/%s" % (self.url, method) + + def test_api(self): + + res = self.session.post(self.api("login"), data={ "username": PhpTest.ADMIN_USERNAME, "password": PhpTest.ADMIN_PASSWORD }) + self.assertEquals(200, res.status_code, self.httpError(res)) + self.assertEquals([], self.getPhpErrors(res)) + obj = json.loads(res.text) + self.assertEquals(True, obj["success"], obj["msg"]) diff --git a/test/installTest.py b/test/installTest.py index 724b2db..621c064 100644 --- a/test/installTest.py +++ b/test/installTest.py @@ -1,30 +1,13 @@ -import unittest import requests -import json -import re -import string -import random -class InstallTestCase(unittest.TestCase): +from phpTest import PhpTest + +class InstallTestCase(PhpTest): def __init__(self, args): super().__init__("test_install") self.args = args self.session = requests.Session() - self.url = "http://localhost/" - - keywords = ["Fatal error", "Warning", "Notice", "Parse error", "Deprecated"] - self.phpPattern = re.compile("(%s):" % ("|".join(keywords))) - - def randomString(self, length): - letters = string.ascii_lowercase + string.ascii_uppercase + string.digits - return ''.join(random.choice(letters) for i in range(length)) - - def httpError(self, res): - return "Server returned: %d %s" % (res.status_code, res.reason) - - def getPhpErrors(self, res): - return [line for line in res.text.split("\n") if self.phpPattern.search(line)] def test_install(self): @@ -39,46 +22,44 @@ class InstallTestCase(unittest.TestCase): self.assertEquals([], self.getPhpErrors(res)) # Create User - valid_username = self.randomString(16) - valid_password = self.randomString(16) # 1. Invalid username for username in ["a", "a"*33]: res = self.session.post(self.url, data={ "username": username, "password": "123456", "confirmPassword": "123456" }) self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals([], self.getPhpErrors(res)) - obj = json.loads(res.text) + obj = self.getJson(res) self.assertEquals(False, obj["success"]) self.assertEquals("The username should be between 5 and 32 characters long", obj["msg"]) # 2. Invalid password - res = self.session.post(self.url, data={ "username": valid_username, "password": "1", "confirmPassword": "1" }) + res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": "1", "confirmPassword": "1" }) self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals([], self.getPhpErrors(res)) - obj = json.loads(res.text) + obj = self.getJson(res) self.assertEquals(False, obj["success"]) self.assertEquals("The password should be at least 6 characters long", obj["msg"]) # 3. Passwords do not match - res = self.session.post(self.url, data={ "username": valid_username, "password": "1", "confirmPassword": "2" }) + res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": "1", "confirmPassword": "2" }) self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals([], self.getPhpErrors(res)) - obj = json.loads(res.text) + obj = self.getJson(res) self.assertEquals(False, obj["success"]) self.assertEquals("The given passwords do not match", obj["msg"]) # 4. User creation OK - res = self.session.post(self.url, data={ "username": valid_username, "password": valid_password, "confirmPassword": valid_password }) + res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": PhpTest.ADMIN_PASSWORD, "confirmPassword": PhpTest.ADMIN_PASSWORD }) self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals([], self.getPhpErrors(res)) - obj = json.loads(res.text) + obj = self.getJson(res) self.assertEquals(True, obj["success"]) # Mail: SKIP res = self.session.post(self.url, data={ "skip": "true" }) self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals([], self.getPhpErrors(res)) - obj = json.loads(res.text) + obj = self.getJson(res) self.assertEquals(True, obj["success"]) # Creation successful: diff --git a/test/phpTest.py b/test/phpTest.py new file mode 100644 index 0000000..0c251f9 --- /dev/null +++ b/test/phpTest.py @@ -0,0 +1,36 @@ +import unittest +import string +import random +import re +import json + +class PhpTest(unittest.TestCase): + + def randomString(length): + letters = string.ascii_lowercase + string.ascii_uppercase + string.digits + return ''.join(random.choice(letters) for i in range(length)) + + ADMIN_USERNAME = "Administrator" + ADMIN_PASSWORD = randomString(16) + + def __init__(self, test_method): + super().__init__(test_method) + keywords = ["Fatal error", "Warning", "Notice", "Parse error", "Deprecated"] + self.phpPattern = re.compile("(%s):" % ("|".join(keywords))) + self.url = "http://localhost/" + + def httpError(self, res): + return "Server returned: %d %s" % (res.status_code, res.reason) + + def getPhpErrors(self, res): + return [line for line in res.text.split("\n") if self.phpPattern.search(line)] + + def getJson(self, res): + obj = None + try: + obj = json.loads(res.text) + except: + pass + finally: + self.assertTrue(isinstance(obj, dict), res.text) + return obj diff --git a/test/requirements.txt b/test/requirements.txt index eb910f7..2134b16 100644 --- a/test/requirements.txt +++ b/test/requirements.txt @@ -1,2 +1,3 @@ requests==2.23.0 +psycopg2==2.8.4 mysql_connector_repackaged==0.3.1 diff --git a/test/test.py b/test/test.py index fe61cca..50a57a1 100644 --- a/test/test.py +++ b/test/test.py @@ -6,9 +6,12 @@ import argparse import random import string import unittest + import mysql.connector +import psycopg2 from installTest import InstallTestCase +from apiTest import ApiTestCase CONFIG_FILES = ["../core/Configuration/Database.class.php","../core/Configuration/JWT.class.php","../core/Configuration/Mail.class.php"] @@ -19,15 +22,13 @@ def randomName(length): def performTest(args): suite = unittest.TestSuite() suite.addTest(InstallTestCase(args)) + suite.addTest(ApiTestCase()) runner = unittest.TextTestRunner() runner.run(suite) def testMysql(args): # Create a temporary database - cursor = None - database = None - connection = None if args.database is None: args.database = "webbase_test_%s" % randomName(6) config = { @@ -43,6 +44,7 @@ def testMysql(args): cursor = connection.cursor() print("[ ] Creating temporary databse %s" % args.database) cursor.execute("CREATE DATABASE %s" % args.database) + cursor.commit() print("[+] Success") # perform test @@ -60,6 +62,37 @@ def testMysql(args): print("[ ] Closing connection…") connection.close() +def testPostgres(args): + + # Create a temporary database + if args.database is None: + args.database = "webbase_test_%s" % randomName(6) + connection_string = "host=%s port=%d user=%s password=%s" % (args.host, args.port, args.username, args.password) + + print("[ ] Connecting to dbms…") + connection = psycopg2.connect(connection_string) + connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT) + print("[+] Success") + cursor = connection.cursor() + print("[ ] Creating temporary databse %s" % args.database) + cursor.execute("CREATE DATABASE %s" % args.database) + print("[+] Success") + + # perform test + try: + args.type = "postgres" + performTest(args) + finally: + if cursor is not None: + print("[ ] Deleting temporary database") + cursor.execute("DROP DATABASE %s" % args.database) + cursor.close() + print("[+] Success") + + if connection is not None: + print("[ ] Closing connection…") + connection.close() + if __name__ == "__main__": supportedDbms = { @@ -95,3 +128,9 @@ if __name__ == "__main__": if args.dbms == "mysql": testMysql(args) + elif args.dbms == "postgres": + testPostgres(args) + + for f in CONFIG_FILES: + if os.path.isfile(f): + os.remove(f)