update
This commit is contained in:
@@ -5,6 +5,7 @@ from urllib.parse import urlparse
|
||||
import threading
|
||||
import requests
|
||||
import sys
|
||||
import time
|
||||
import os
|
||||
import ssl
|
||||
import util
|
||||
@@ -82,6 +83,9 @@ class FileServerRequestHandler(BaseHTTPRequestHandler):
|
||||
data = b"" if len(result) < 2 else result[1]
|
||||
headers = { } if len(result) < 3 else result[2]
|
||||
|
||||
if path in self.server.dumpRequests:
|
||||
headers["Access-Control-Allow-Origin"] = "*"
|
||||
|
||||
if len(headers) == 0:
|
||||
self.send_response(status_code)
|
||||
else:
|
||||
@@ -128,6 +132,8 @@ class HttpFileServer(HTTPServer):
|
||||
self.dumpRequests = []
|
||||
self.prefix_routes = { }
|
||||
self.is_running = True
|
||||
self.listen_thread = None
|
||||
self.has_exited = False
|
||||
|
||||
def cleanPath(self, path):
|
||||
|
||||
@@ -164,12 +170,14 @@ class HttpFileServer(HTTPServer):
|
||||
def enableLogging(self):
|
||||
self.logRequests = True
|
||||
|
||||
def enableSSL(self, keyFile=None, certFile=None):
|
||||
if keyFile is None:
|
||||
print("Generating certificate…")
|
||||
def enableSSL(self, keyFile="private.key", certFile="server.crt"):
|
||||
|
||||
if not os.path.isfile(keyFile):
|
||||
print("Generating private key and certificate…")
|
||||
os.system("openssl req -new -x509 -keyout private.key -out server.crt -days 365 -nodes")
|
||||
certFile = "server.crt"
|
||||
keyFile = "private.key"
|
||||
elif not os.path.isfile(certFile):
|
||||
print("Generating certificate…")
|
||||
os.system("openssl req -new -x509 -keyin private.key -out server.crt -days 365 -nodes")
|
||||
|
||||
self.socket = ssl.wrap_socket(self.socket,
|
||||
server_side=True,
|
||||
@@ -185,21 +193,42 @@ class HttpFileServer(HTTPServer):
|
||||
# pass
|
||||
|
||||
def startBackground(self):
|
||||
t = threading.Thread(target=self.serve_forever)
|
||||
t.start()
|
||||
return t
|
||||
self.listen_thread = threading.Thread(target=self.serve_forever)
|
||||
self.listen_thread.start()
|
||||
return self.listen_thread
|
||||
|
||||
def start(self):
|
||||
return self.serve_forever()
|
||||
|
||||
def get_base_url():
|
||||
addr, port = self.server_address
|
||||
if port != 80:
|
||||
port = f":{port}"
|
||||
protocol = "https" if gettype(self.socket) == ssl.SSLSocket else "http"
|
||||
return f"{protocol}://{addr}{port}"
|
||||
|
||||
def stop(self):
|
||||
self.is_running = False
|
||||
# dummy request
|
||||
requests.get(f"http://{self.server_name}:{self.server_port}/dummy")
|
||||
time.sleep(1)
|
||||
|
||||
try:
|
||||
# dummy request
|
||||
for i in range(3):
|
||||
requests.get(f"{self.get_base_url()}/dummy")
|
||||
if self.has_exited:
|
||||
break
|
||||
time.sleep(1)
|
||||
except:
|
||||
pass
|
||||
|
||||
if self.listen_thread != threading.currentThread():
|
||||
self.listen_thread.join()
|
||||
|
||||
def serve_forever(self):
|
||||
self.has_exited = False
|
||||
while self.is_running:
|
||||
self.handle_request()
|
||||
self.has_exited = True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user