Postgres Support

This commit is contained in:
Roman Hergenreder 2020-04-02 15:08:14 +02:00
parent 067d48d2cf
commit cf726ab177
13 changed files with 288 additions and 66 deletions

@ -2,7 +2,6 @@
namespace Api; namespace Api;
use \Driver\SQL\Keyword;
use \Driver\SQL\Condition\Compare; use \Driver\SQL\Condition\Compare;
class GetApiKeys extends Request { class GetApiKeys extends Request {
@ -21,7 +20,7 @@ class GetApiKeys extends Request {
$res = $sql->select("uid", "api_key", "valid_until") $res = $sql->select("uid", "api_key", "valid_until")
->from("ApiKey") ->from("ApiKey")
->where(new Compare("user_id", $this->user->getId())) ->where(new Compare("user_id", $this->user->getId()))
->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">")) ->where(new Compare("valid_until", $sql->currentTimestamp(), ">"))
->where(new Compare("active", true)) ->where(new Compare("active", true))
->execute(); ->execute();

@ -3,7 +3,6 @@
namespace Api; namespace Api;
use \Api\Parameter\Parameter; use \Api\Parameter\Parameter;
use \Driver\SQL\Keyword;
use \Driver\SQL\Condition\Compare; use \Driver\SQL\Condition\Compare;
class RefreshApiKey extends Request { class RefreshApiKey extends Request {
@ -23,7 +22,7 @@ class RefreshApiKey extends Request {
->from("ApiKey") ->from("ApiKey")
->where(new Compare("uid", $id)) ->where(new Compare("uid", $id))
->where(new Compare("user_id", $this->user->getId())) ->where(new Compare("user_id", $this->user->getId()))
->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">")) ->where(new Compare("valid_until", $sql->currentTimestamp(), ">"))
->where(new Compare("active", 1)) ->where(new Compare("active", 1))
->execute(); ->execute();

@ -3,7 +3,6 @@
namespace Api; namespace Api;
use \Api\Parameter\Parameter; use \Api\Parameter\Parameter;
use \Driver\SQL\Keyword;
use \Driver\SQL\Condition\Compare; use \Driver\SQL\Condition\Compare;
class RevokeApiKey extends Request { class RevokeApiKey extends Request {
@ -23,7 +22,7 @@ class RevokeApiKey extends Request {
->from("ApiKey") ->from("ApiKey")
->where(new Compare("uid", $id)) ->where(new Compare("uid", $id))
->where(new Compare("user_id", $this->user->getId())) ->where(new Compare("user_id", $this->user->getId()))
->where(new Compare("valid_until", new Keyword($sql->currentTimestamp()), ">")) ->where(new Compare("valid_until", $sql->currentTimestamp(), ">"))
->where(new Compare("active", 1)) ->where(new Compare("active", 1))
->execute(); ->execute();

@ -97,7 +97,9 @@ namespace Documents\Install {
return self::DATABASE_CONFIGURATION; return self::DATABASE_CONFIGURATION;
} }
$res = $user->getSQL()->select("COUNT(*) as count")->from("User")->execute(); $sql = $user->getSQL();
$countKeyword = $sql->count();
$res = $sql->select($countKeyword)->from("User")->execute();
if ($res === FALSE) { if ($res === FALSE) {
return self::DATABASE_CONFIGURATION; return self::DATABASE_CONFIGURATION;
} else { } else {

@ -26,14 +26,13 @@ class MySQL extends SQL {
public function __construct($connectionData) { public function __construct($connectionData) {
parent::__construct($connectionData); parent::__construct($connectionData);
$this->installLink = ;
} }
public function checkRequirements() { public function checkRequirements() {
return function_exists('mysqli_connect'); return function_exists('mysqli_connect');
} }
public abstract function getDriverName() { public function getDriverName() {
return 'mysqli'; return 'mysqli';
} }
@ -248,7 +247,16 @@ class MySQL extends SQL {
public function executeSelect($select) { public function executeSelect($select) {
$columns = implode(",", $select->getColumns()); $columns = array();
foreach($select->getColumns() as $col) {
if ($col instanceof Keyword) {
$columns[] = $col->getValue();
} else {
$columns[] = "`$col`";
}
}
$columns = implode(",", $columns);
$tables = $select->getTables(); $tables = $select->getTables();
$params = array(); $params = array();
@ -364,7 +372,7 @@ class MySQL extends SQL {
if (!is_null($column->getDefaultValue()) || !$column->notNull()) { if (!is_null($column->getDefaultValue()) || !$column->notNull()) {
$defaultValue = " DEFAULT " . $this->getValueDefinition($column->getDefaultValue()); $defaultValue = " DEFAULT " . $this->getValueDefinition($column->getDefaultValue());
} }
return "`$columnName` $type$notNull$defaultValue"; return "`$columnName` $type$notNull$defaultValue";
} }
@ -416,7 +424,28 @@ class MySQL extends SQL {
} }
} }
public function currentTimestamp() { protected function tableName($table) {
return "NOW()"; return "`$table`";
} }
protected function columnName($col) {
if ($col instanceof KeyWord) {
return $col->getValue();
} else {
return "`$col`";
}
}
public function currentTimestamp() {
return new KeyWord("NOW()");
}
public function count($col = NULL) {
if (is_null($col)) {
return new Keyword("COUNT(*)");
} else {
return new Keyword("COUNT($col)");
}
}
}; };

@ -137,7 +137,7 @@ class PostgreSQL extends SQL {
// Querybuilder // Querybuilder
public function executeCreateTable($createTable) { public function executeCreateTable($createTable) {
$tableName = $createTable->getTableName(); $tableName = $this->tableName($createTable->getTableName());
$ifNotExists = $createTable->ifNotExists() ? " IF NOT EXISTS": ""; $ifNotExists = $createTable->ifNotExists() ? " IF NOT EXISTS": "";
$entries = array(); $entries = array();
@ -156,13 +156,13 @@ class PostgreSQL extends SQL {
} }
$entries = implode(",", $entries); $entries = implode(",", $entries);
$query = "CREATE TABLE$ifNotExists \"$tableName\" ($entries)"; $query = "CREATE TABLE$ifNotExists $tableName ($entries)";
return $this->execute($query); return $this->execute($query);
} }
public function executeInsert($insert) { public function executeInsert($insert) {
$tableName = $insert->getTableName(); $tableName = $this->tableName($insert->getTableName());
$columns = $insert->getColumns(); $columns = $insert->getColumns();
$rows = $insert->getRows(); $rows = $insert->getRows();
$onDuplicateKey = $insert->onDuplicateKey() ?? ""; $onDuplicateKey = $insert->onDuplicateKey() ?? "";
@ -173,11 +173,15 @@ class PostgreSQL extends SQL {
} }
if (is_null($columns) || empty($columns)) { if (is_null($columns) || empty($columns)) {
$columns = ""; $columnStr = "";
$numColumns = count($rows[0]); $numColumns = count($rows[0]);
} else { } else {
$numColumns = count($columns); $numColumns = count($columns);
$columns = " (\"" . implode("\", \"", $columns) . "\")"; $columnStr = array();
foreach($columns as $col) {
$columnStr[] = $this->columnName($col);
}
$columnStr = " (" . implode(",", $columnStr) . ")";
} }
$numRows = count($rows); $numRows = count($rows);
@ -196,7 +200,7 @@ class PostgreSQL extends SQL {
$values = implode(",", $values); $values = implode(",", $values);
if ($onDuplicateKey) { if ($onDuplicateKey) {
if ($onDuplicateKey instanceof UpdateStrategy) { /*if ($onDuplicateKey instanceof UpdateStrategy) {
$updateValues = array(); $updateValues = array();
foreach($onDuplicateKey->getValues() as $key => $value) { foreach($onDuplicateKey->getValues() as $key => $value) {
if ($value instanceof Column) { if ($value instanceof Column) {
@ -208,7 +212,7 @@ class PostgreSQL extends SQL {
} }
$onDuplicateKey = " ON CONFLICT DO UPDATE SET " . implode(",", $updateValues); $onDuplicateKey = " ON CONFLICT DO UPDATE SET " . implode(",", $updateValues);
} else { } else*/ {
$strategy = get_class($onDuplicateKey); $strategy = get_class($onDuplicateKey);
$this->lastError = "ON DUPLICATE Strategy $strategy is not supported yet."; $this->lastError = "ON DUPLICATE Strategy $strategy is not supported yet.";
return false; return false;
@ -216,9 +220,9 @@ class PostgreSQL extends SQL {
} }
$returningCol = $insert->getReturning(); $returningCol = $insert->getReturning();
$returning = $returningCol ? " RETURNING \"$returningCol\"" : ""; $returning = $returningCol ? (" RETURNING " . $this->columnName($returningCol)) : "";
$query = "INSERT INTO \"$tableName\"$columns VALUES$values$onDuplicateKey$returning"; $query = "INSERT INTO $tableName$columnStr VALUES$values$onDuplicateKey$returning";
$res = $this->execute($query, $parameters, !empty($returning)); $res = $this->execute($query, $parameters, !empty($returning));
$success = ($res !== FALSE); $success = ($res !== FALSE);
@ -229,11 +233,93 @@ class PostgreSQL extends SQL {
return $success; return $success;
} }
// TODO: public function executeSelect($select) {
public function executeSelect($query) { }
public function executeDelete($query) { } $columns = array();
public function executeTruncate($query) { } foreach($select->getColumns() as $col) {
public function executeUpdate($query) { } $columns[] = $this->columnName($col);
}
$columns = implode(",", $columns);
$tables = $select->getTables();
$params = array();
if (is_null($tables) || empty($tables)) {
return "SELECT $columns";
} else {
$tableStr = array();
foreach($tables as $table) {
$tableStr[] = $this->tableName($table);
}
$tableStr = implode(",", $tableStr);
}
$conditions = $select->getConditions();
if (!empty($conditions)) {
$condition = " WHERE " . $this->buildCondition($conditions, $params);
} else {
$condition = "";
}
$joinStr = "";
$joins = $select->getJoins();
if (!empty($joins)) {
$joinStr = "";
foreach($joins as $join) {
$type = $join->getType();
$joinTable = $this->tableName($join->getTable());
$columnA = $this->columnName($join->getColumnA());
$columnB = $this->columnName($join->getColumnB());
$joinStr .= " $type JOIN $joinTable ON $columnA=$columnB";
}
}
$orderBy = "";
$limit = "";
$offset = "";
$query = "SELECT $columns FROM $tableStr$joinStr$condition$orderBy$limit$offset";
return $this->execute($query, $params, true);
}
public function executeDelete($delete) {
$table = $delete->getTable();
$conditions = $delete->getConditions();
if (!empty($conditions)) {
$condition = " WHERE " . $this->buildCondition($conditions, $params);
} else {
$condition = "";
}
$query = "DELETE FROM \"$table\"$condition";
return $this->execute($query);
}
public function executeTruncate($truncate) {
$table = $truncate->getTable();
return $this->execute("TRUNCATE \"$table\"");
}
public function executeUpdate($update) {
$params = array();
$table = $update->getTable();
$valueStr = array();
foreach($update->getValues() as $key => $val) {
$valueStr[] = "$key=" . $this->addValue($val, $params);
}
$valueStr = implode(",", $valueStr);
$conditions = $update->getConditions();
if (!empty($conditions)) {
$condition = " WHERE " . $this->buildCondition($conditions, $params);
} else {
$condition = "";
}
$query = "UPDATE \"$table\" SET $valueStr$condition";
return $this->execute($query, $params);
}
// UGLY but.. what should i do? // UGLY but.. what should i do?
private function createEnum($enumColumn) { private function createEnum($enumColumn) {
@ -344,9 +430,36 @@ class PostgreSQL extends SQL {
} }
} }
protected function tableName($table) {
return "\"$table\"";
}
protected function columnName($col) {
if ($col instanceof KeyWord) {
return $col->getValue();
} else {
$index = strrpos($col, ".");
if ($index === FALSE) {
return "\"$col\"";
} else {
$tableName = $this->tableName(substr($col, 0, $index));
$columnName = $this->columnName(substr($col, $index + 1));
return "$tableName.$columnName";
}
}
}
// Special Keywords and functions // Special Keywords and functions
public function currentTimestamp() { public function currentTimestamp() {
return "CURRENT_TIMESTAMP"; return new Keyword("CURRENT_TIMESTAMP");
}
public function count($col = NULL) {
if (is_null($col)) {
return new Keyword("COUNT(*)");
} else {
return new Keyword("COUNT(\"$col\")");
}
} }
} }
?> ?>

@ -72,8 +72,12 @@ abstract class SQL {
protected abstract function getValueDefinition($val); protected abstract function getValueDefinition($val);
protected abstract function addValue($val, &$params); protected abstract function addValue($val, &$params);
protected abstract function tableName($table);
protected abstract function columnName($col);
// Special Keywords and functions // Special Keywords and functions
public abstract function currentTimestamp(); public abstract function currentTimestamp();
public abstract function count($col = NULL);
// Statements // Statements
protected abstract function execute($query, $values=NULL, $returnValues=false); protected abstract function execute($query, $values=NULL, $returnValues=false);
@ -86,12 +90,12 @@ abstract class SQL {
} }
return "(" . implode(" OR ", $conditions) . ")"; return "(" . implode(" OR ", $conditions) . ")";
} else if ($condition instanceof \Driver\SQL\Condition\Compare) { } else if ($condition instanceof \Driver\SQL\Condition\Compare) {
$column = $condition->getColumn(); $column = $this->columnName($condition->getColumn());
$value = $condition->getValue(); $value = $condition->getValue();
$operator = $condition->getOperator(); $operator = $condition->getOperator();
return $column . $operator . $this->addValue($value, $params); return $column . $operator . $this->addValue($value, $params);
} else if ($condition instanceof \Driver\SQL\Condition\CondBool) { } else if ($condition instanceof \Driver\SQL\Condition\CondBool) {
return $condition->getValue(); return $this->columnName($condition->getValue());
} else if (is_array($condition)) { } else if (is_array($condition)) {
if (count($condition) == 1) { if (count($condition) == 1) {
return $this->buildCondition($condition[0], $params); return $this->buildCondition($condition[0], $params);

@ -2,7 +2,6 @@
namespace Objects; namespace Objects;
use Driver\SQL\Keyword;
use Driver\SQL\Column\Column; use Driver\SQL\Column\Column;
use Driver\SQL\Condition\Compare; use Driver\SQL\Condition\Compare;
use Driver\SQL\Condition\CondBool; use Driver\SQL\Condition\CondBool;
@ -113,7 +112,7 @@ class User extends ApiObject {
->where(new Compare("User.uid", $userId)) ->where(new Compare("User.uid", $userId))
->where(new Compare("Session.uid", $sessionId)) ->where(new Compare("Session.uid", $sessionId))
->where(new Compare("Session.active", true)) ->where(new Compare("Session.active", true))
->where(new CondBool("Session.stay_logged_in"), new Compare("Session.expires", new Keyword($this->sql->currentTimestamp()), '>')) ->where(new CondBool("Session.stay_logged_in"), new Compare("Session.expires", $this->sql->currentTimestamp(), '>'))
->execute(); ->execute();
$success = ($res !== FALSE); $success = ($res !== FALSE);
@ -189,7 +188,7 @@ class User extends ApiObject {
->innerJoin("User", "ApiKey.user_id", "User.uid") ->innerJoin("User", "ApiKey.user_id", "User.uid")
->leftJoin("Language", "User.language_id", "Language.uid") ->leftJoin("Language", "User.language_id", "Language.uid")
->where(new Compare("ApiKey.api_key", $apiKey)) ->where(new Compare("ApiKey.api_key", $apiKey))
->where(new Compare("valid_until", new Keyword($this->sql->currentTimestamp()), ">")) ->where(new Compare("valid_until", $this->sql->currentTimestamp(), ">"))
->where(new COmpare("ApiKey.active", 1)) ->where(new COmpare("ApiKey.active", 1))
->execute(); ->execute();

21
test/apiTest.py Normal file

@ -0,0 +1,21 @@
import requests
import json
from phpTest import PhpTest
class ApiTestCase(PhpTest):
def __init__(self):
super().__init__("test_api")
self.session = requests.Session()
def api(self, method):
return "%s/api/%s" % (self.url, method)
def test_api(self):
res = self.session.post(self.api("login"), data={ "username": PhpTest.ADMIN_USERNAME, "password": PhpTest.ADMIN_PASSWORD })
self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res))
obj = json.loads(res.text)
self.assertEquals(True, obj["success"], obj["msg"])

@ -1,30 +1,13 @@
import unittest
import requests import requests
import json
import re
import string
import random
class InstallTestCase(unittest.TestCase): from phpTest import PhpTest
class InstallTestCase(PhpTest):
def __init__(self, args): def __init__(self, args):
super().__init__("test_install") super().__init__("test_install")
self.args = args self.args = args
self.session = requests.Session() self.session = requests.Session()
self.url = "http://localhost/"
keywords = ["Fatal error", "Warning", "Notice", "Parse error", "Deprecated"]
self.phpPattern = re.compile("<b>(%s)</b>:" % ("|".join(keywords)))
def randomString(self, length):
letters = string.ascii_lowercase + string.ascii_uppercase + string.digits
return ''.join(random.choice(letters) for i in range(length))
def httpError(self, res):
return "Server returned: %d %s" % (res.status_code, res.reason)
def getPhpErrors(self, res):
return [line for line in res.text.split("\n") if self.phpPattern.search(line)]
def test_install(self): def test_install(self):
@ -39,46 +22,44 @@ class InstallTestCase(unittest.TestCase):
self.assertEquals([], self.getPhpErrors(res)) self.assertEquals([], self.getPhpErrors(res))
# Create User # Create User
valid_username = self.randomString(16)
valid_password = self.randomString(16)
# 1. Invalid username # 1. Invalid username
for username in ["a", "a"*33]: for username in ["a", "a"*33]:
res = self.session.post(self.url, data={ "username": username, "password": "123456", "confirmPassword": "123456" }) res = self.session.post(self.url, data={ "username": username, "password": "123456", "confirmPassword": "123456" })
self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res)) self.assertEquals([], self.getPhpErrors(res))
obj = json.loads(res.text) obj = self.getJson(res)
self.assertEquals(False, obj["success"]) self.assertEquals(False, obj["success"])
self.assertEquals("The username should be between 5 and 32 characters long", obj["msg"]) self.assertEquals("The username should be between 5 and 32 characters long", obj["msg"])
# 2. Invalid password # 2. Invalid password
res = self.session.post(self.url, data={ "username": valid_username, "password": "1", "confirmPassword": "1" }) res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": "1", "confirmPassword": "1" })
self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res)) self.assertEquals([], self.getPhpErrors(res))
obj = json.loads(res.text) obj = self.getJson(res)
self.assertEquals(False, obj["success"]) self.assertEquals(False, obj["success"])
self.assertEquals("The password should be at least 6 characters long", obj["msg"]) self.assertEquals("The password should be at least 6 characters long", obj["msg"])
# 3. Passwords do not match # 3. Passwords do not match
res = self.session.post(self.url, data={ "username": valid_username, "password": "1", "confirmPassword": "2" }) res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": "1", "confirmPassword": "2" })
self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res)) self.assertEquals([], self.getPhpErrors(res))
obj = json.loads(res.text) obj = self.getJson(res)
self.assertEquals(False, obj["success"]) self.assertEquals(False, obj["success"])
self.assertEquals("The given passwords do not match", obj["msg"]) self.assertEquals("The given passwords do not match", obj["msg"])
# 4. User creation OK # 4. User creation OK
res = self.session.post(self.url, data={ "username": valid_username, "password": valid_password, "confirmPassword": valid_password }) res = self.session.post(self.url, data={ "username": PhpTest.ADMIN_USERNAME, "password": PhpTest.ADMIN_PASSWORD, "confirmPassword": PhpTest.ADMIN_PASSWORD })
self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res)) self.assertEquals([], self.getPhpErrors(res))
obj = json.loads(res.text) obj = self.getJson(res)
self.assertEquals(True, obj["success"]) self.assertEquals(True, obj["success"])
# Mail: SKIP # Mail: SKIP
res = self.session.post(self.url, data={ "skip": "true" }) res = self.session.post(self.url, data={ "skip": "true" })
self.assertEquals(200, res.status_code, self.httpError(res)) self.assertEquals(200, res.status_code, self.httpError(res))
self.assertEquals([], self.getPhpErrors(res)) self.assertEquals([], self.getPhpErrors(res))
obj = json.loads(res.text) obj = self.getJson(res)
self.assertEquals(True, obj["success"]) self.assertEquals(True, obj["success"])
# Creation successful: # Creation successful:

36
test/phpTest.py Normal file

@ -0,0 +1,36 @@
import unittest
import string
import random
import re
import json
class PhpTest(unittest.TestCase):
def randomString(length):
letters = string.ascii_lowercase + string.ascii_uppercase + string.digits
return ''.join(random.choice(letters) for i in range(length))
ADMIN_USERNAME = "Administrator"
ADMIN_PASSWORD = randomString(16)
def __init__(self, test_method):
super().__init__(test_method)
keywords = ["Fatal error", "Warning", "Notice", "Parse error", "Deprecated"]
self.phpPattern = re.compile("<b>(%s)</b>:" % ("|".join(keywords)))
self.url = "http://localhost/"
def httpError(self, res):
return "Server returned: %d %s" % (res.status_code, res.reason)
def getPhpErrors(self, res):
return [line for line in res.text.split("\n") if self.phpPattern.search(line)]
def getJson(self, res):
obj = None
try:
obj = json.loads(res.text)
except:
pass
finally:
self.assertTrue(isinstance(obj, dict), res.text)
return obj

@ -1,2 +1,3 @@
requests==2.23.0 requests==2.23.0
psycopg2==2.8.4
mysql_connector_repackaged==0.3.1 mysql_connector_repackaged==0.3.1

@ -6,9 +6,12 @@ import argparse
import random import random
import string import string
import unittest import unittest
import mysql.connector import mysql.connector
import psycopg2
from installTest import InstallTestCase from installTest import InstallTestCase
from apiTest import ApiTestCase
CONFIG_FILES = ["../core/Configuration/Database.class.php","../core/Configuration/JWT.class.php","../core/Configuration/Mail.class.php"] CONFIG_FILES = ["../core/Configuration/Database.class.php","../core/Configuration/JWT.class.php","../core/Configuration/Mail.class.php"]
@ -19,15 +22,13 @@ def randomName(length):
def performTest(args): def performTest(args):
suite = unittest.TestSuite() suite = unittest.TestSuite()
suite.addTest(InstallTestCase(args)) suite.addTest(InstallTestCase(args))
suite.addTest(ApiTestCase())
runner = unittest.TextTestRunner() runner = unittest.TextTestRunner()
runner.run(suite) runner.run(suite)
def testMysql(args): def testMysql(args):
# Create a temporary database # Create a temporary database
cursor = None
database = None
connection = None
if args.database is None: if args.database is None:
args.database = "webbase_test_%s" % randomName(6) args.database = "webbase_test_%s" % randomName(6)
config = { config = {
@ -43,6 +44,7 @@ def testMysql(args):
cursor = connection.cursor() cursor = connection.cursor()
print("[ ] Creating temporary databse %s" % args.database) print("[ ] Creating temporary databse %s" % args.database)
cursor.execute("CREATE DATABASE %s" % args.database) cursor.execute("CREATE DATABASE %s" % args.database)
cursor.commit()
print("[+] Success") print("[+] Success")
# perform test # perform test
@ -60,6 +62,37 @@ def testMysql(args):
print("[ ] Closing connection…") print("[ ] Closing connection…")
connection.close() connection.close()
def testPostgres(args):
# Create a temporary database
if args.database is None:
args.database = "webbase_test_%s" % randomName(6)
connection_string = "host=%s port=%d user=%s password=%s" % (args.host, args.port, args.username, args.password)
print("[ ] Connecting to dbms…")
connection = psycopg2.connect(connection_string)
connection.set_isolation_level(psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT)
print("[+] Success")
cursor = connection.cursor()
print("[ ] Creating temporary databse %s" % args.database)
cursor.execute("CREATE DATABASE %s" % args.database)
print("[+] Success")
# perform test
try:
args.type = "postgres"
performTest(args)
finally:
if cursor is not None:
print("[ ] Deleting temporary database")
cursor.execute("DROP DATABASE %s" % args.database)
cursor.close()
print("[+] Success")
if connection is not None:
print("[ ] Closing connection…")
connection.close()
if __name__ == "__main__": if __name__ == "__main__":
supportedDbms = { supportedDbms = {
@ -95,3 +128,9 @@ if __name__ == "__main__":
if args.dbms == "mysql": if args.dbms == "mysql":
testMysql(args) testMysql(args)
elif args.dbms == "postgres":
testPostgres(args)
for f in CONFIG_FILES:
if os.path.isfile(f):
os.remove(f)