diff --git a/core/Api/GetApiKeys.class.php b/core/Api/GetApiKeys.class.php
index 8d53594..c391188 100644
--- a/core/Api/GetApiKeys.class.php
+++ b/core/Api/GetApiKeys.class.php
@@ -2,7 +2,6 @@
namespace Api;
-use \Driver\SQL\Keyword;
use \Driver\SQL\Condition\Compare;
class GetApiKeys extends Request {
@@ -21,7 +20,7 @@ class GetApiKeys extends Request {
$res = $sql->select("uid", "api_key", "valid_until")
->from("ApiKey")
->where(new Compare("user_id", $this->user->getId()))
- ->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">"))
+ ->where(new Compare("valid_until", $sql->currentTimestamp(), ">"))
->where(new Compare("active", true))
->execute();
diff --git a/core/Api/RefreshApiKey.class.php b/core/Api/RefreshApiKey.class.php
index 3a24a30..f584f6b 100644
--- a/core/Api/RefreshApiKey.class.php
+++ b/core/Api/RefreshApiKey.class.php
@@ -3,7 +3,6 @@
namespace Api;
use \Api\Parameter\Parameter;
-use \Driver\SQL\Keyword;
use \Driver\SQL\Condition\Compare;
class RefreshApiKey extends Request {
@@ -23,7 +22,7 @@ class RefreshApiKey extends Request {
->from("ApiKey")
->where(new Compare("uid", $id))
->where(new Compare("user_id", $this->user->getId()))
- ->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">"))
+ ->where(new Compare("valid_until", $sql->currentTimestamp(), ">"))
->where(new Compare("active", 1))
->execute();
diff --git a/core/Api/RevokeApiKey.class.php b/core/Api/RevokeApiKey.class.php
index f2b8448..a769449 100644
--- a/core/Api/RevokeApiKey.class.php
+++ b/core/Api/RevokeApiKey.class.php
@@ -3,7 +3,6 @@
namespace Api;
use \Api\Parameter\Parameter;
-use \Driver\SQL\Keyword;
use \Driver\SQL\Condition\Compare;
class RevokeApiKey extends Request {
@@ -23,7 +22,7 @@ class RevokeApiKey extends Request {
->from("ApiKey")
->where(new Compare("uid", $id))
->where(new Compare("user_id", $this->user->getId()))
- ->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">"))
+ ->where(new Compare("valid_until", $sql->currentTimestamp(), ">"))
->where(new Compare("active", 1))
->execute();
diff --git a/core/Documents/Install.class.php b/core/Documents/Install.class.php
index 3f0edc8..1189317 100644
--- a/core/Documents/Install.class.php
+++ b/core/Documents/Install.class.php
@@ -97,7 +97,9 @@ namespace Documents\Install {
return self::DATABASE_CONFIGURATION;
}
- $res = $user->getSQL()->select("COUNT(*) as count")->from("User")->execute();
+ $sql = $user->getSQL();
+ $countKeyword = $sql->count();
+ $res = $sql->select($countKeyword)->from("User")->execute();
if ($res === FALSE) {
return self::DATABASE_CONFIGURATION;
} else {
diff --git a/core/Driver/SQL/MySQL.class.php b/core/Driver/SQL/MySQL.class.php
index 2e2e5f2..b72f03e 100644
--- a/core/Driver/SQL/MySQL.class.php
+++ b/core/Driver/SQL/MySQL.class.php
@@ -26,14 +26,13 @@ class MySQL extends SQL {
public function __construct($connectionData) {
parent::__construct($connectionData);
- $this->installLink = ;
}
public function checkRequirements() {
return function_exists('mysqli_connect');
}
- public abstract function getDriverName() {
+ public function getDriverName() {
return 'mysqli';
}
@@ -248,7 +247,16 @@ class MySQL extends SQL {
public function executeSelect($select) {
- $columns = implode(",", $select->getColumns());
+ $columns = array();
+ foreach($select->getColumns() as $col) {
+ if ($col instanceof Keyword) {
+ $columns[] = $col->getValue();
+ } else {
+ $columns[] = "`$col`";
+ }
+ }
+
+ $columns = implode(",", $columns);
$tables = $select->getTables();
$params = array();
@@ -364,7 +372,7 @@ class MySQL extends SQL {
if (!is_null($column->getDefaultValue()) || !$column->notNull()) {
$defaultValue = " DEFAULT " . $this->getValueDefinition($column->getDefaultValue());
}
-
+
return "`$columnName` $type$notNull$defaultValue";
}
@@ -416,7 +424,28 @@ class MySQL extends SQL {
}
}
- public function currentTimestamp() {
- return "NOW()";
+ protected function tableName($table) {
+ return "`$table`";
}
+
+ protected function columnName($col) {
+ if ($col instanceof KeyWord) {
+ return $col->getValue();
+ } else {
+ return "`$col`";
+ }
+ }
+
+ public function currentTimestamp() {
+ return new KeyWord("NOW()");
+ }
+
+ public function count($col = NULL) {
+ if (is_null($col)) {
+ return new Keyword("COUNT(*)");
+ } else {
+ return new Keyword("COUNT($col)");
+ }
+ }
+
};
diff --git a/core/Driver/SQL/PostgreSQL.class.php b/core/Driver/SQL/PostgreSQL.class.php
index ca0d108..90a3f43 100644
--- a/core/Driver/SQL/PostgreSQL.class.php
+++ b/core/Driver/SQL/PostgreSQL.class.php
@@ -137,7 +137,7 @@ class PostgreSQL extends SQL {
// Querybuilder
public function executeCreateTable($createTable) {
- $tableName = $createTable->getTableName();
+ $tableName = $this->tableName($createTable->getTableName());
$ifNotExists = $createTable->ifNotExists() ? " IF NOT EXISTS": "";
$entries = array();
@@ -156,13 +156,13 @@ class PostgreSQL extends SQL {
}
$entries = implode(",", $entries);
- $query = "CREATE TABLE$ifNotExists \"$tableName\" ($entries)";
+ $query = "CREATE TABLE$ifNotExists $tableName ($entries)";
return $this->execute($query);
}
public function executeInsert($insert) {
- $tableName = $insert->getTableName();
+ $tableName = $this->tableName($insert->getTableName());
$columns = $insert->getColumns();
$rows = $insert->getRows();
$onDuplicateKey = $insert->onDuplicateKey() ?? "";
@@ -173,11 +173,15 @@ class PostgreSQL extends SQL {
}
if (is_null($columns) || empty($columns)) {
- $columns = "";
+ $columnStr = "";
$numColumns = count($rows[0]);
} else {
$numColumns = count($columns);
- $columns = " (\"" . implode("\", \"", $columns) . "\")";
+ $columnStr = array();
+ foreach($columns as $col) {
+ $columnStr[] = $this->columnName($col);
+ }
+ $columnStr = " (" . implode(",", $columnStr) . ")";
}
$numRows = count($rows);
@@ -196,7 +200,7 @@ class PostgreSQL extends SQL {
$values = implode(",", $values);
if ($onDuplicateKey) {
- if ($onDuplicateKey instanceof UpdateStrategy) {
+ /*if ($onDuplicateKey instanceof UpdateStrategy) {
$updateValues = array();
foreach($onDuplicateKey->getValues() as $key => $value) {
if ($value instanceof Column) {
@@ -208,7 +212,7 @@ class PostgreSQL extends SQL {
}
$onDuplicateKey = " ON CONFLICT DO UPDATE SET " . implode(",", $updateValues);
- } else {
+ } else*/ {
$strategy = get_class($onDuplicateKey);
$this->lastError = "ON DUPLICATE Strategy $strategy is not supported yet.";
return false;
@@ -216,9 +220,9 @@ class PostgreSQL extends SQL {
}
$returningCol = $insert->getReturning();
- $returning = $returningCol ? " RETURNING \"$returningCol\"" : "";
+ $returning = $returningCol ? (" RETURNING " . $this->columnName($returningCol)) : "";
- $query = "INSERT INTO \"$tableName\"$columns VALUES$values$onDuplicateKey$returning";
+ $query = "INSERT INTO $tableName$columnStr VALUES$values$onDuplicateKey$returning";
$res = $this->execute($query, $parameters, !empty($returning));
$success = ($res !== FALSE);
@@ -229,11 +233,93 @@ class PostgreSQL extends SQL {
return $success;
}
- // TODO:
- public function executeSelect($query) { }
- public function executeDelete($query) { }
- public function executeTruncate($query) { }
- public function executeUpdate($query) { }
+ public function executeSelect($select) {
+
+ $columns = array();
+ foreach($select->getColumns() as $col) {
+ $columns[] = $this->columnName($col);
+ }
+
+ $columns = implode(",", $columns);
+ $tables = $select->getTables();
+ $params = array();
+
+ if (is_null($tables) || empty($tables)) {
+ return "SELECT $columns";
+ } else {
+ $tableStr = array();
+ foreach($tables as $table) {
+ $tableStr[] = $this->tableName($table);
+ }
+ $tableStr = implode(",", $tableStr);
+ }
+
+ $conditions = $select->getConditions();
+ if (!empty($conditions)) {
+ $condition = " WHERE " . $this->buildCondition($conditions, $params);
+ } else {
+ $condition = "";
+ }
+
+ $joinStr = "";
+ $joins = $select->getJoins();
+ if (!empty($joins)) {
+ $joinStr = "";
+ foreach($joins as $join) {
+ $type = $join->getType();
+ $joinTable = $this->tableName($join->getTable());
+ $columnA = $this->columnName($join->getColumnA());
+ $columnB = $this->columnName($join->getColumnB());
+ $joinStr .= " $type JOIN $joinTable ON $columnA=$columnB";
+ }
+ }
+
+ $orderBy = "";
+ $limit = "";
+ $offset = "";
+
+ $query = "SELECT $columns FROM $tableStr$joinStr$condition$orderBy$limit$offset";
+ return $this->execute($query, $params, true);
+ }
+
+ public function executeDelete($delete) {
+ $table = $delete->getTable();
+ $conditions = $delete->getConditions();
+ if (!empty($conditions)) {
+ $condition = " WHERE " . $this->buildCondition($conditions, $params);
+ } else {
+ $condition = "";
+ }
+
+ $query = "DELETE FROM \"$table\"$condition";
+ return $this->execute($query);
+ }
+
+ public function executeTruncate($truncate) {
+ $table = $truncate->getTable();
+ return $this->execute("TRUNCATE \"$table\"");
+ }
+
+ public function executeUpdate($update) {
+ $params = array();
+ $table = $update->getTable();
+
+ $valueStr = array();
+ foreach($update->getValues() as $key => $val) {
+ $valueStr[] = "$key=" . $this->addValue($val, $params);
+ }
+ $valueStr = implode(",", $valueStr);
+
+ $conditions = $update->getConditions();
+ if (!empty($conditions)) {
+ $condition = " WHERE " . $this->buildCondition($conditions, $params);
+ } else {
+ $condition = "";
+ }
+
+ $query = "UPDATE \"$table\" SET $valueStr$condition";
+ return $this->execute($query, $params);
+ }
// UGLY but.. what should i do?
private function createEnum($enumColumn) {
@@ -344,9 +430,36 @@ class PostgreSQL extends SQL {
}
}
+ protected function tableName($table) {
+ return "\"$table\"";
+ }
+
+ protected function columnName($col) {
+ if ($col instanceof KeyWord) {
+ return $col->getValue();
+ } else {
+ $index = strrpos($col, ".");
+ if ($index === FALSE) {
+ return "\"$col\"";
+ } else {
+ $tableName = $this->tableName(substr($col, 0, $index));
+ $columnName = $this->columnName(substr($col, $index + 1));
+ return "$tableName.$columnName";
+ }
+ }
+ }
+
// Special Keywords and functions
public function currentTimestamp() {
- return "CURRENT_TIMESTAMP";
+ return new Keyword("CURRENT_TIMESTAMP");
+ }
+
+ public function count($col = NULL) {
+ if (is_null($col)) {
+ return new Keyword("COUNT(*)");
+ } else {
+ return new Keyword("COUNT(\"$col\")");
+ }
}
}
?>
diff --git a/core/Driver/SQL/SQL.class.php b/core/Driver/SQL/SQL.class.php
index 9529a9e..12dda38 100644
--- a/core/Driver/SQL/SQL.class.php
+++ b/core/Driver/SQL/SQL.class.php
@@ -72,8 +72,12 @@ abstract class SQL {
protected abstract function getValueDefinition($val);
protected abstract function addValue($val, &$params);
+ protected abstract function tableName($table);
+ protected abstract function columnName($col);
+
// Special Keywords and functions
public abstract function currentTimestamp();
+ public abstract function count($col = NULL);
// Statements
protected abstract function execute($query, $values=NULL, $returnValues=false);
@@ -86,12 +90,12 @@ abstract class SQL {
}
return "(" . implode(" OR ", $conditions) . ")";
} else if ($condition instanceof \Driver\SQL\Condition\Compare) {
- $column = $condition->getColumn();
+ $column = $this->columnName($condition->getColumn());
$value = $condition->getValue();
$operator = $condition->getOperator();
return $column . $operator . $this->addValue($value, $params);
} else if ($condition instanceof \Driver\SQL\Condition\CondBool) {
- return $condition->getValue();
+ return $this->columnName($condition->getValue());
} else if (is_array($condition)) {
if (count($condition) == 1) {
return $this->buildCondition($condition[0], $params);
diff --git a/core/Objects/User.class.php b/core/Objects/User.class.php
index f1e9d21..2558bee 100644
--- a/core/Objects/User.class.php
+++ b/core/Objects/User.class.php
@@ -2,7 +2,6 @@
namespace Objects;
-use Driver\SQL\Keyword;
use Driver\SQL\Column\Column;
use Driver\SQL\Condition\Compare;
use Driver\SQL\Condition\CondBool;
@@ -113,7 +112,7 @@ class User extends ApiObject {
->where(new Compare("User.uid", $userId))
->where(new Compare("Session.uid", $sessionId))
->where(new Compare("Session.active", true))
- ->where(new CondBool("Session.stay_logged_in"), new Compare("Session.expires", new Keyword($this->sql->currentTimestamp()), '>'))
+ ->where(new CondBool("Session.stay_logged_in"), new Compare("Session.expires", $this->sql->currentTimestamp(), '>'))
->execute();
$success = ($res !== FALSE);
@@ -189,7 +188,7 @@ class User extends ApiObject {
->innerJoin("User", "ApiKey.user_id", "User.uid")
->leftJoin("Language", "User.language_id", "Language.uid")
->where(new Compare("ApiKey.api_key", $apiKey))
- ->where(new Compare("valid_until", new Keyword($this->sql->currentTimestamp()), ">"))
+ ->where(new Compare("valid_until", $this->sql->currentTimestamp(), ">"))
->where(new COmpare("ApiKey.active", 1))
->execute();
diff --git a/test/apiTest.py b/test/apiTest.py
new file mode 100644
index 0000000..3b01207
--- /dev/null
+++ b/test/apiTest.py
@@ -0,0 +1,21 @@
+import requests
+import json
+
+from phpTest import PhpTest
+
+class ApiTestCase(PhpTest):
+
+ def __init__(self):
+ super().__init__("test_api")
+ self.session = requests.Session()
+
+ def api(self, method):
+ return "%s/api/%s" % (self.url, method)
+
+ def test_api(self):
+
+ res = self.session.post(self.api("login"), data={ "username": PhpTest.ADMIN_USERNAME, "password": PhpTest.ADMIN_PASSWORD })
+ self.assertEquals(200, res.status_code, self.httpError(res))
+ self.assertEquals([], self.getPhpErrors(res))
+ obj = json.loads(res.text)
+ self.assertEquals(True, obj["success"], obj["msg"])
diff --git a/test/installTest.py b/test/installTest.py
index 724b2db..621c064 100644
--- a/test/installTest.py
+++ b/test/installTest.py
@@ -1,30 +1,13 @@
-import unittest
import requests
-import json
-import re
-import string
-import random
-class InstallTestCase(unittest.TestCase):
+from phpTest import PhpTest
+
+class InstallTestCase(PhpTest):
def __init__(self, args):
super().__init__("test_install")
self.args = args
self.session = requests.Session()
- self.url = "http://localhost/"
-
- keywords = ["Fatal error", "Warning", "Notice", "Parse error", "Deprecated"]
- self.phpPattern = re.compile("(%s):" % ("|".join(keywords)))
-
- def randomString(self, length):
- letters = string.ascii_lowercase + string.ascii_uppercase + string.digits
- return ''.join(random.choice(letters) for i in range(length))
-
- def httpError(self, res):
- return "Server returned: %d %s" % (res.status_code, res.reason)
-
- def getPhpErrors(self, res):
- return [line for line in res.text.split("\n") if self.phpPattern.search(line)]
def test_install(self):
@@ -39,46 +22,44 @@ class InstallTestCase(unittest.TestCase):
self.assertEquals([], self.getPhpErrors(res))
# Create User
- valid_username = self.randomString(16)
- valid_password = self.randomString(16)
# 1. Invalid username
for username in ["a", "a"*33]:
res = self.session.post(self.url, data={ "username": username, "password": "123456", "confirmPassword": "123456" })
self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res))
- obj = json.loads(res.text)
+ obj = self.getJson(res)
self.assertEquals(False, obj["success"])
self.assertEquals("The username should be between 5 and 32 characters long", obj["msg"])
# 2. Invalid password
- res = self.session.post(self.url, data={ "username": valid_username, "password": "1", "confirmPassword": "1" })
+ res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": "1", "confirmPassword": "1" })
self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res))
- obj = json.loads(res.text)
+ obj = self.getJson(res)
self.assertEquals(False, obj["success"])
self.assertEquals("The password should be at least 6 characters long", obj["msg"])
# 3. Passwords do not match
- res = self.session.post(self.url, data={ "username": valid_username, "password": "1", "confirmPassword": "2" })
+ res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": "1", "confirmPassword": "2" })
self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res))
- obj = json.loads(res.text)
+ obj = self.getJson(res)
self.assertEquals(False, obj["success"])
self.assertEquals("The given passwords do not match", obj["msg"])
# 4. User creation OK
- res = self.session.post(self.url, data={ "username": valid_username, "password": valid_password, "confirmPassword": valid_password })
+ res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": PhpTest.ADMIN_PASSWORD, "confirmPassword": PhpTest.ADMIN_PASSWORD })
self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res))
- obj = json.loads(res.text)
+ obj = self.getJson(res)
self.assertEquals(True, obj["success"])
# Mail: SKIP
res = self.session.post(self.url, data={ "skip": "true" })
self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res))
- obj = json.loads(res.text)
+ obj = self.getJson(res)
self.assertEquals(True, obj["success"])
# Creation successful:
diff --git a/test/phpTest.py b/test/phpTest.py
new file mode 100644
index 0000000..0c251f9
--- /dev/null
+++ b/test/phpTest.py
@@ -0,0 +1,36 @@
+import unittest
+import string
+import random
+import re
+import json
+
+class PhpTest(unittest.TestCase):
+
+ def randomString(length):
+ letters = string.ascii_lowercase + string.ascii_uppercase + string.digits
+ return ''.join(random.choice(letters) for i in range(length))
+
+ ADMIN_USERNAME = "Administrator"
+ ADMIN_PASSWORD = randomString(16)
+
+ def __init__(self, test_method):
+ super().__init__(test_method)
+ keywords = ["Fatal error", "Warning", "Notice", "Parse error", "Deprecated"]
+ self.phpPattern = re.compile("(%s):" % ("|".join(keywords)))
+ self.url = "http://localhost/"
+
+ def httpError(self, res):
+ return "Server returned: %d %s" % (res.status_code, res.reason)
+
+ def getPhpErrors(self, res):
+ return [line for line in res.text.split("\n") if self.phpPattern.search(line)]
+
+ def getJson(self, res):
+ obj = None
+ try:
+ obj = json.loads(res.text)
+ except:
+ pass
+ finally:
+ self.assertTrue(isinstance(obj, dict), res.text)
+ return obj
diff --git a/test/requirements.txt b/test/requirements.txt
index eb910f7..2134b16 100644
--- a/test/requirements.txt
+++ b/test/requirements.txt
@@ -1,2 +1,3 @@
requests==2.23.0
+psycopg2==2.8.4
mysql_connector_repackaged==0.3.1
diff --git a/test/test.py b/test/test.py
index fe61cca..50a57a1 100644
--- a/test/test.py
+++ b/test/test.py
@@ -6,9 +6,12 @@ import argparse
import random
import string
import unittest
+
import mysql.connector
+import psycopg2
from installTest import InstallTestCase
+from apiTest import ApiTestCase
CONFIG_FILES = ["../core/Configuration/Database.class.php","../core/Configuration/JWT.class.php","../core/Configuration/Mail.class.php"]
@@ -19,15 +22,13 @@ def randomName(length):
def performTest(args):
suite = unittest.TestSuite()
suite.addTest(InstallTestCase(args))
+ suite.addTest(ApiTestCase())
runner = unittest.TextTestRunner()
runner.run(suite)
def testMysql(args):
# Create a temporary database
- cursor = None
- database = None
- connection = None
if args.database is None:
args.database = "webbase_test_%s" % randomName(6)
config = {
@@ -43,6 +44,7 @@ def testMysql(args):
cursor = connection.cursor()
print("[ ] Creating temporary databse %s" % args.database)
cursor.execute("CREATE DATABASE %s" % args.database)
+ cursor.commit()
print("[+] Success")
# perform test
@@ -60,6 +62,37 @@ def testMysql(args):
print("[ ] Closing connection…")
connection.close()
+def testPostgres(args):
+
+ # Create a temporary database
+ if args.database is None:
+ args.database = "webbase_test_%s" % randomName(6)
+ connection_string = "host=%s port=%d user=%s password=%s" % (args.host, args.port, args.username, args.password)
+
+ print("[ ] Connecting to dbms…")
+ connection = psycopg2.connect(connection_string)
+ connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
+ print("[+] Success")
+ cursor = connection.cursor()
+ print("[ ] Creating temporary databse %s" % args.database)
+ cursor.execute("CREATE DATABASE %s" % args.database)
+ print("[+] Success")
+
+ # perform test
+ try:
+ args.type = "postgres"
+ performTest(args)
+ finally:
+ if cursor is not None:
+ print("[ ] Deleting temporary database")
+ cursor.execute("DROP DATABASE %s" % args.database)
+ cursor.close()
+ print("[+] Success")
+
+ if connection is not None:
+ print("[ ] Closing connection…")
+ connection.close()
+
if __name__ == "__main__":
supportedDbms = {
@@ -95,3 +128,9 @@ if __name__ == "__main__":
if args.dbms == "mysql":
testMysql(args)
+ elif args.dbms == "postgres":
+ testPostgres(args)
+
+ for f in CONFIG_FILES:
+ if os.path.isfile(f):
+ os.remove(f)