1
0
mirror of https://github.com/calebstewart/pwncat.git synced 2024-11-27 19:04:15 +01:00

Incremental changes mostly moving command parser out of victim

This commit is contained in:
Caleb Stewart 2020-10-09 18:15:02 -04:00
parent f69542f0b4
commit 33003592ab
15 changed files with 351 additions and 109 deletions

View File

@ -1,8 +1,89 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from typing import Optional from typing import Optional
from io import TextIOWrapper
import sys
import os
from .config import Config import selectors
from sqlalchemy.exc import InvalidRequestError
# These need to be assigned prior to importing other
# parts of pwncat
victim: Optional["pwncat.remote.Victim"] = None victim: Optional["pwncat.remote.Victim"] = None
from .config import Config
from .commands import parser
from .util import console
from .db import get_session
config: Config = Config() config: Config = Config()
def interactive(platform):
""" Run the interactive pwncat shell with the given initialized victim.
This function handles the pwncat and remote prompts and does not return
until explicitly exited by the user.
This doesn't work yet. It's dependant on the new platform and channel
interface that isn't working yet, but it's what I'd like the eventual
interface to look like.
:param platform: an initialized platform object with a valid channel
:type platform: pwncat.platform.Platform
"""
global victim
global config
# Initialize a new victim
victim = platform
# Ensure the prompt is initialized
parser.setup_prompt()
# Ensure our stdin reference is unbuffered
sys.stdin = TextIOWrapper(
os.fdopen(sys.stdin.fileno(), "br", buffering=0),
write_through=True,
line_buffering=False,
)
# Ensure we are in raw mode
parser.raw_mode()
# Create selector for asynchronous IO
selector = selectors.DefaultSelector()
selector.register(sys.stdin, selectors.EVENT_READ, None)
selector.register(victim.channel, selectors.EVENT_READ, None)
# Main loop state
done = False
try:
while not done:
for key, _ in selector.select():
if key.fileobj is sys.stdin:
data = sys.stdin.buffer.read(64)
data = parser.parse_prefix(data)
if data:
victim.channel.send(data)
else:
data = victim.channel.recv(4096)
if data is None or not data:
done = True
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except ConnectionResetError:
console.log("[yellow]warning[/yellow]: connection reset by remote host")
except SystemExit:
console.log("closing connection")
finally:
# Ensure the terminal is back to normal
parser.restore_term()
try:
# Commit any pending changes to the database
get_session().commit()
except InvalidRequestError:
pass

View File

@ -16,6 +16,7 @@ from paramiko.buffered_pipe import BufferedPipe
import pwncat import pwncat
from pwncat.util import console from pwncat.util import console
from pwncat.remote import Victim from pwncat.remote import Victim
from pwncat.db import get_session
def main(): def main():
@ -113,7 +114,7 @@ def main():
pwncat.victim.restore_local_term() pwncat.victim.restore_local_term()
try: try:
# Make sure everything was committed # Make sure everything was committed
pwncat.victim.session.commit() get_session().commit()
except InvalidRequestError: except InvalidRequestError:
pass pass

View File

@ -22,7 +22,7 @@ class Reconnect(Channel):
host = "0.0.0.0" host = "0.0.0.0"
if port is None: if port is None:
raise ChannelError(f"no port specified") raise ChannelError("no port specified")
with Progress( with Progress(
f"bound to [blue]{host}[/blue]:[cyan]{port}[/cyan]", f"bound to [blue]{host}[/blue]:[cyan]{port}[/cyan]",

View File

@ -24,9 +24,14 @@ from prompt_toolkit.history import InMemoryHistory, History
from typing import Dict, Any, List, Iterable from typing import Dict, Any, List, Iterable
from colorama import Fore from colorama import Fore
from enum import Enum, auto from enum import Enum, auto
from io import TextIOWrapper
import argparse import argparse
import pkgutil import pkgutil
import shlex import shlex
import sys
import fcntl
import termios
import tty
import os import os
import re import re
@ -36,6 +41,7 @@ import pwncat
import pwncat.db import pwncat.db
from pwncat.commands.base import CommandDefinition, Complete from pwncat.commands.base import CommandDefinition, Complete
from pwncat.util import State, console from pwncat.util import State, console
from pwncat.db import get_session
def resolve_blocks(source: str): def resolve_blocks(source: str):
@ -101,7 +107,8 @@ class DatabaseHistory(History):
def load_history_strings(self) -> Iterable[str]: def load_history_strings(self) -> Iterable[str]:
""" Load the history from the database """ """ Load the history from the database """
for history in ( for history in (
pwncat.victim.session.query(pwncat.db.History) get_session()
.query(pwncat.db.History)
.order_by(pwncat.db.History.id.desc()) .order_by(pwncat.db.History.id.desc())
.all() .all()
): ):
@ -110,12 +117,15 @@ class DatabaseHistory(History):
def store_string(self, string: str) -> None: def store_string(self, string: str) -> None:
""" Store a command in the database """ """ Store a command in the database """
history = pwncat.db.History(host_id=pwncat.victim.host.id, command=string) history = pwncat.db.History(host_id=pwncat.victim.host.id, command=string)
pwncat.victim.session.add(history) get_session().add(history)
class CommandParser: class CommandParser:
""" Handles dynamically loading command classes, parsing input, and """ Handles dynamically loading command classes, parsing input, and
dispatching commands. """ dispatching commands. This class effectively has complete control over
the terminal whenever in an interactive pwncat session. It will change
termios modes for the control tty at will in order to support raw vs
command mode. """
def __init__(self): def __init__(self):
""" We need to dynamically load commands from pwncat.commands """ """ We need to dynamically load commands from pwncat.commands """
@ -134,6 +144,10 @@ class CommandParser:
self.loading_complete = False self.loading_complete = False
self.aliases: Dict[str, CommandDefinition] = {} self.aliases: Dict[str, CommandDefinition] = {}
self.shortcuts: Dict[str, CommandDefinition] = {} self.shortcuts: Dict[str, CommandDefinition] = {}
self.found_prefix: bool = False
# Saved terminal state to support switching between raw and normal
# mode.
self.saved_term_state = None
def setup_prompt(self): def setup_prompt(self):
""" This needs to happen after __init__ when the database is fully """ This needs to happen after __init__ when the database is fully
@ -210,10 +224,7 @@ class CommandParser:
if pwncat.config.module: if pwncat.config.module:
self.prompt.message = [ self.prompt.message = [
( ("fg:ansiyellow bold", f"({pwncat.config.module.name}) ",),
"fg:ansiyellow bold",
f"({pwncat.config.module.name}) ",
),
("fg:ansimagenta bold", "pwncat"), ("fg:ansimagenta bold", "pwncat"),
("", "$ "), ("", "$ "),
] ]
@ -315,6 +326,113 @@ class CommandParser:
# The arguments were incorrect # The arguments were incorrect
return return
def parse_prefix(self, channel, data: bytes):
""" Parse data received from the user when in pwncat's raw mode.
This will intercept key presses from the user and interpret the
prefix and any bound keyboard shortcuts. It also sends any data
without a prefix to the remote channel.
:param data: input data from user
:type data: bytes
"""
buffer = b""
for c in data:
if not self.found_prefix and c != pwncat.config["prefix"].value:
buffer += c
continue
elif not self.found_prefix and c == pwncat.config["prefix"].value:
self.found_prefix = True
channel.send(buffer)
buffer = b""
continue
elif self.found_prefix:
try:
binding = pwncat.config.binding(c)
if binding.strip() == "pass":
buffer += c
else:
# Restore the normal terminal
self.restore_term()
# Run the binding script
self.eval(binding, "<binding>")
# Drain any channel output
channel.drain()
channel.send(b"\n")
# Go back to a raw terminal
self.raw_mode()
except KeyError:
pass
self.found_prefix = False
# Flush any remaining raw data bound for the victim
channel.send(buffer)
def raw_mode(self):
""" Save the current terminal state and enter raw mode.
If the terminal is already in raw mode, this function
does nothing. """
if self.saved_term_state is not None:
return
# Ensure we don't have any weird buffering issues
sys.stdout.flush()
# Python doesn't provide a way to use setvbuf, so we reopen stdout
# and specify no buffering. Duplicating stdin allows the user to press C-d
# at the local prompt, and still be able to return to the remote prompt.
try:
os.dup2(sys.stdin.fileno(), sys.stdout.fileno())
except OSError:
pass
sys.stdout = TextIOWrapper(
os.fdopen(os.dup(sys.stdin.fileno()), "bw", buffering=0),
write_through=True,
line_buffering=False,
)
# Grab and duplicate current attributes
fild = sys.stdin.fileno()
old = termios.tcgetattr(fild)
new = termios.tcgetattr(fild)
# Remove ECHO from lflag and ensure we won't block
new[3] &= ~(termios.ECHO | termios.ICANON)
new[6][termios.VMIN] = 0
new[6][termios.VTIME] = 0
termios.tcsetattr(fild, termios.TCSADRAIN, new)
# Set raw mode
tty.setraw(sys.stdin)
orig_fl = fcntl.fcntl(sys.stdin, fcntl.F_GETFL)
fcntl.fcntl(sys.stdin, fcntl.F_SETFL, orig_fl)
self.saved_term_state = old, orig_fl
def restore_term(self, new_line=True):
""" Restores the normal terminal settings. This does nothing if the
terminal is not currently in raw mode. """
if self.saved_term_state is None:
return
termios.tcsetattr(
sys.stdin.fileno(), termios.TCSADRAIN, self.saved_term_state[0]
)
# tty.setcbreak(sys.stdin)
fcntl.fcntl(sys.stdin, fcntl.F_SETFL, self.saved_term_state[1])
if new_line:
sys.stdout.write("\n")
self.saved_term_state = None
class CommandLexer(RegexLexer): class CommandLexer(RegexLexer):
@ -537,3 +655,9 @@ class CommandCompleter(Completer):
yield from next_completer.get_completions(document, complete_event) yield from next_completer.get_completions(document, complete_event)
elif this_completer is not None: elif this_completer is not None:
yield from this_completer.get_completions(document, complete_event) yield from this_completer.get_completions(document, complete_event)
# Here, we allocate the global parser object and initialize in-memory
# settings
parser: CommandParser = CommandParser()
parser.setup_prompt()

View File

@ -11,6 +11,7 @@ from pwncat.commands.base import (
StoreForAction, StoreForAction,
) )
from pwncat.util import console from pwncat.util import console
from pwncat.db import get_session
class Command(CommandDefinition): class Command(CommandDefinition):
@ -71,10 +72,14 @@ class Command(CommandDefinition):
return return
# Find all binaries which are provided by busybox # Find all binaries which are provided by busybox
provides = pwncat.victim.session.query(pwncat.db.Binary).filter( provides = (
get_session()
.query(pwncat.db.Binary)
.filter(
pwncat.db.Binary.path.contains(pwncat.victim.host.busybox), pwncat.db.Binary.path.contains(pwncat.victim.host.busybox),
pwncat.db.Binary.host_id == pwncat.victim.host.id, pwncat.db.Binary.host_id == pwncat.victim.host.id,
) )
)
for binary in provides: for binary in provides:
console.print(f" - {binary.name}") console.print(f" - {binary.name}")
@ -88,7 +93,8 @@ class Command(CommandDefinition):
# Find all binaries which are provided from busybox # Find all binaries which are provided from busybox
nprovides = ( nprovides = (
pwncat.victim.session.query(pwncat.db.Binary) get_session()
.query(pwncat.db.Binary)
.filter( .filter(
pwncat.db.Binary.path.contains(pwncat.victim.host.busybox), pwncat.db.Binary.path.contains(pwncat.victim.host.busybox),
pwncat.db.Binary.host_id == pwncat.victim.host.id, pwncat.db.Binary.host_id == pwncat.victim.host.id,

View File

@ -21,6 +21,7 @@ from pwncat.commands.base import (
# from pwncat.persist import PersistenceError # from pwncat.persist import PersistenceError
from pwncat.modules.persist import PersistError from pwncat.modules.persist import PersistError
from pwncat.db import get_session
class Command(CommandDefinition): class Command(CommandDefinition):
@ -119,7 +120,7 @@ class Command(CommandDefinition):
# persistence methods # persistence methods
hosts = { hosts = {
host.hash: (host, []) host.hash: (host, [])
for host in pwncat.victim.session.query(pwncat.db.Host).all() for host in get_session().query(pwncat.db.Host).all()
} }
for module in modules: for module in modules:
@ -201,9 +202,7 @@ class Command(CommandDefinition):
try: try:
addr = ipaddress.ip_address(socket.gethostbyname(host)) addr = ipaddress.ip_address(socket.gethostbyname(host))
row = ( row = (
pwncat.victim.session.query(pwncat.db.Host) get_session().query(pwncat.db.Host).filter_by(ip=str(addr)).first()
.filter_by(ip=str(addr))
.first()
) )
if row is None: if row is None:
console.log(f"{level}: {str(addr)}: not found in database") console.log(f"{level}: {str(addr)}: not found in database")

View File

@ -6,15 +6,14 @@ from sqlalchemy.orm import sessionmaker
import pwncat import pwncat
from pwncat.commands.base import CommandDefinition, Complete, Parameter from pwncat.commands.base import CommandDefinition, Complete, Parameter
from pwncat.util import console, State from pwncat.util import console, State
from pwncat.db import get_session, reset_engine
class Command(CommandDefinition): class Command(CommandDefinition):
""" Set variable runtime variable parameters for pwncat """ """ Set variable runtime variable parameters for pwncat """
def get_config_variables(self): def get_config_variables(self):
options = ( options = ["state"] + list(pwncat.config.values) + list(pwncat.victim.users)
["state"] + list(pwncat.config.values) + list(pwncat.victim.users)
)
if pwncat.config.module: if pwncat.config.module:
options.extend(pwncat.config.module.ARGUMENTS.keys()) options.extend(pwncat.config.module.ARGUMENTS.keys())
@ -82,16 +81,14 @@ class Command(CommandDefinition):
if args.variable == "db": if args.variable == "db":
# We handle this specially to ensure the database is available # We handle this specially to ensure the database is available
# as soon as this config is set # as soon as this config is set
pwncat.victim.engine = create_engine( reset_engine()
pwncat.config["db"], echo=False if pwncat.victim.host is not None:
pwncat.victim.host = (
get_session()
.query(pwncat.db.Host)
.filter_by(id=pwncat.victim.host.id)
.scalar()
) )
pwncat.db.Base.metadata.create_all(pwncat.victim.engine)
# Create the session_maker and default session
pwncat.victim.session_maker = sessionmaker(
bind=pwncat.victim.engine
)
pwncat.victim.session = pwncat.victim.session_maker()
except ValueError as exc: except ValueError as exc:
console.log(f"[red]error[/red]: {exc}") console.log(f"[red]error[/red]: {exc}")
elif args.variable is not None: elif args.variable is not None:

View File

@ -9,7 +9,6 @@ from prompt_toolkit.input.ansi_escape_sequences import (
ANSI_SEQUENCES, ANSI_SEQUENCES,
) )
from prompt_toolkit.keys import ALL_KEYS, Keys from prompt_toolkit.keys import ALL_KEYS, Keys
import commentjson as json
from pwncat.modules import BaseModule from pwncat.modules import BaseModule

View File

@ -1,5 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from sqlalchemy.engine import Engine, create_engine
from sqlalchemy.orm import Session, sessionmaker
import pwncat
from pwncat.db.base import Base from pwncat.db.base import Base
from pwncat.db.binary import Binary from pwncat.db.binary import Binary
from pwncat.db.history import History from pwncat.db.history import History
@ -9,3 +13,53 @@ from pwncat.db.suid import SUID
from pwncat.db.tamper import Tamper from pwncat.db.tamper import Tamper
from pwncat.db.user import User, Group, SecondaryGroupAssociation from pwncat.db.user import User, Group, SecondaryGroupAssociation
from pwncat.db.fact import Fact from pwncat.db.fact import Fact
ENGINE: Engine = None
SESSION_MAKER = None
SESSION: Session = None
def get_engine() -> Engine:
"""
Get a copy of the database engine
"""
global ENGINE
if ENGINE is not None:
return ENGINE
ENGINE = create_engine(pwncat.config["db"], echo=False)
Base.metadata.create_all(ENGINE)
return ENGINE
def get_session() -> Session:
"""
Get a new session object
"""
global SESSION_MAKER
global SESSION
if SESSION_MAKER is None:
SESSION_MAKER = sessionmaker(bind=get_engine())
if SESSION is None:
SESSION = SESSION_MAKER()
return SESSION
def reset_engine():
"""
Reload the engine and session
"""
global ENGINE
global SESSION
global SESSION_MAKER
ENGINE = None
SESSION = None
SESSION_MAKER = None

View File

@ -8,6 +8,7 @@ import sqlalchemy
import pwncat import pwncat
from pwncat.platform import Platform from pwncat.platform import Platform
from pwncat.modules import BaseModule, Status, Argument, List from pwncat.modules import BaseModule, Status, Argument, List
from pwncat.db import get_session
class Schedule(Enum): class Schedule(Enum):
@ -61,14 +62,17 @@ class EnumerateModule(BaseModule):
if clear: if clear:
# Delete enumerated facts # Delete enumerated facts
query = pwncat.victim.session.query(pwncat.db.Fact).filter_by( query = (
source=self.name, host_id=pwncat.victim.host.id get_session()
.query(pwncat.db.Fact)
.filter_by(source=self.name, host_id=pwncat.victim.host.id)
) )
query.delete(synchronize_session=False) query.delete(synchronize_session=False)
# Delete our marker # Delete our marker
if self.SCHEDULE != Schedule.ALWAYS: if self.SCHEDULE != Schedule.ALWAYS:
query = ( query = (
pwncat.victim.session.query(pwncat.db.Fact) get_session()
.query(pwncat.db.Fact)
.filter_by(host_id=pwncat.victim.host.id, type="marker") .filter_by(host_id=pwncat.victim.host.id, type="marker")
.filter(pwncat.db.Fact.source.startswith(self.name)) .filter(pwncat.db.Fact.source.startswith(self.name))
) )
@ -77,7 +81,8 @@ class EnumerateModule(BaseModule):
# Yield all the know facts which have already been enumerated # Yield all the know facts which have already been enumerated
existing_facts = ( existing_facts = (
pwncat.victim.session.query(pwncat.db.Fact) get_session()
.query(pwncat.db.Fact)
.filter_by(source=self.name, host_id=pwncat.victim.host.id) .filter_by(source=self.name, host_id=pwncat.victim.host.id)
.filter(pwncat.db.Fact.type != "marker") .filter(pwncat.db.Fact.type != "marker")
) )
@ -92,7 +97,8 @@ class EnumerateModule(BaseModule):
if self.SCHEDULE != Schedule.ALWAYS: if self.SCHEDULE != Schedule.ALWAYS:
exists = ( exists = (
pwncat.victim.session.query(pwncat.db.Fact.id) get_session()
.query(pwncat.db.Fact.id)
.filter_by( .filter_by(
host_id=pwncat.victim.host.id, type="marker", source=marker_name host_id=pwncat.victim.host.id, type="marker", source=marker_name
) )
@ -114,11 +120,11 @@ class EnumerateModule(BaseModule):
host_id=pwncat.victim.host.id, type=typ, data=data, source=self.name host_id=pwncat.victim.host.id, type=typ, data=data, source=self.name
) )
try: try:
pwncat.victim.session.add(row) get_session().add(row)
pwncat.victim.host.facts.append(row) pwncat.victim.host.facts.append(row)
pwncat.victim.session.commit() get_session().commit()
except sqlalchemy.exc.IntegrityError: except sqlalchemy.exc.IntegrityError:
pwncat.victim.session.rollback() get_session().rollback()
yield Status(data) yield Status(data)
continue continue
@ -140,7 +146,7 @@ class EnumerateModule(BaseModule):
source=marker_name, source=marker_name,
data=None, data=None,
) )
pwncat.victim.session.add(row) get_session().add(row)
pwncat.victim.host.facts.append(row) pwncat.victim.host.facts.append(row)
def enumerate(self): def enumerate(self):

View File

@ -13,6 +13,7 @@ import pwncat.modules
from pwncat import util from pwncat import util
from pwncat.util import console from pwncat.util import console
from pwncat.modules.enumerate import EnumerateModule from pwncat.modules.enumerate import EnumerateModule
from pwncat.db import get_session
def strip_markup(styled_text: str) -> str: def strip_markup(styled_text: str) -> str:
@ -86,7 +87,7 @@ class Module(pwncat.modules.BaseModule):
for module in modules: for module in modules:
yield pwncat.modules.Status(module.name) yield pwncat.modules.Status(module.name)
module.run(progress=self.progress, clear=True) module.run(progress=self.progress, clear=True)
pwncat.victim.session.commit() get_session().commit()
pwncat.victim.reload_host() pwncat.victim.reload_host()
return return

View File

@ -15,6 +15,7 @@ from pwncat.modules import (
PersistError, PersistError,
PersistType, PersistType,
) )
from pwncat.db import get_session
class PersistModule(BaseModule): class PersistModule(BaseModule):
@ -95,7 +96,8 @@ class PersistModule(BaseModule):
# Check if this module has been installed with the same arguments before # Check if this module has been installed with the same arguments before
ident = ( ident = (
pwncat.victim.session.query(pwncat.db.Persistence.id) get_session()
.query(pwncat.db.Persistence.id)
.filter_by(host_id=pwncat.victim.host.id, method=self.name, args=kwargs) .filter_by(host_id=pwncat.victim.host.id, method=self.name, args=kwargs)
.scalar() .scalar()
) )
@ -110,7 +112,7 @@ class PersistModule(BaseModule):
yield result yield result
# Remove from the database # Remove from the database
pwncat.victim.session.query(pwncat.db.Persistence).filter_by( get_session().query(pwncat.db.Persistence).filter_by(
host_id=pwncat.victim.host.id, method=self.name, args=kwargs host_id=pwncat.victim.host.id, method=self.name, args=kwargs
).delete(synchronize_session=False) ).delete(synchronize_session=False)
return return
@ -178,7 +180,7 @@ class PersistModule(BaseModule):
) )
pwncat.victim.host.persistence.append(row) pwncat.victim.host.persistence.append(row)
pwncat.victim.session.commit() get_session().commit()
def install(self, **kwargs): def install(self, **kwargs):
""" """

View File

@ -5,6 +5,7 @@ import pwncat
from pwncat.util import console from pwncat.util import console
from pwncat.modules import BaseModule, Argument, Status, Bool, Result from pwncat.modules import BaseModule, Argument, Status, Bool, Result
import pwncat.modules.persist import pwncat.modules.persist
from pwncat.db import get_session
@dataclasses.dataclass @dataclasses.dataclass
@ -93,11 +94,13 @@ class Module(BaseModule):
""" Execute this module """ """ Execute this module """
if pwncat.victim.host is not None: if pwncat.victim.host is not None:
query = pwncat.victim.session.query(pwncat.db.Persistence).filter_by( query = (
host_id=pwncat.victim.host.id get_session()
.query(pwncat.db.Persistence)
.filter_by(host_id=pwncat.victim.host.id)
) )
else: else:
query = pwncat.victim.session.query(pwncat.db.Persistence) query = get_session().query(pwncat.db.Persistence)
if module is not None: if module is not None:
query = query.filter_by(method=module) query = query.filter_by(method=module)

View File

@ -35,6 +35,7 @@ from pwncat.remote import RemoteService
from pwncat.tamper import TamperManager from pwncat.tamper import TamperManager
from pwncat.util import State, console from pwncat.util import State, console
from pwncat.modules.persist import PersistError, PersistType from pwncat.modules.persist import PersistError, PersistType
from pwncat.db import get_session
def remove_busybox_tamper(): def remove_busybox_tamper():
@ -135,29 +136,12 @@ class Victim:
self.client: Optional[socket.SocketType] = None self.client: Optional[socket.SocketType] = None
# The shell we are running under on the remote host # The shell we are running under on the remote host
self.shell: str = "unknown" self.shell: str = "unknown"
# Database engine
self.engine: Engine = None
# Database session
self.session: Session = None
# The host object as seen by the database # The host object as seen by the database
self.host: pwncat.db.Host = None self.host: pwncat.db.Host = None
# The current user. This is cached while at the `pwncat` prompt # The current user. This is cached while at the `pwncat` prompt
# and reloaded whenever returning from RAW mode. # and reloaded whenever returning from RAW mode.
self.cached_user: str = None self.cached_user: str = None
# The db engine is created here, but likely wrong. This happens
# before a configuration script is loaded, so likely creates a
# in memory db. This needs to happen because other parts of the
# framework assume a db engine exists, and therefore needs this
# reference. Also, in the case a config isn't loaded this
# needs to happen.
self.engine = create_engine(pwncat.config["db"], echo=False)
pwncat.db.Base.metadata.create_all(self.engine)
# Create the session_maker and default session
self.session_maker = sessionmaker(bind=self.engine)
self.session = self.session_maker()
def reconnect( def reconnect(
self, hostid: str, requested_method: str = None, requested_user: str = None self, hostid: str, requested_method: str = None, requested_user: str = None
): ):
@ -176,17 +160,8 @@ class Victim:
will be tried. will be tried.
""" """
# Create the database engine, and then create the schema
# if needed.
self.engine = create_engine(pwncat.config["db"], echo=False)
pwncat.db.Base.metadata.create_all(self.engine)
# Create the session_maker and default session
self.session_maker = sessionmaker(bind=self.engine)
self.session = self.session_maker()
# Load this host from the database # Load this host from the database
self.host = self.session.query(pwncat.db.Host).filter_by(hash=hostid).first() self.host = get_session().query(pwncat.db.Host).filter_by(hash=hostid).first()
if self.host is None: if self.host is None:
raise PersistError(f"invalid host hash") raise PersistError(f"invalid host hash")
@ -235,17 +210,6 @@ class Victim:
:return: None :return: None
""" """
# Create the database engine, and then create the schema
# if needed.
if self.engine is None:
self.engine = create_engine(pwncat.config["db"], echo=False)
pwncat.db.Base.metadata.create_all(self.engine)
# Create the session_maker and default session
if self.session is None:
self.session_maker = sessionmaker(bind=self.engine)
self.session = self.session_maker()
# Initialize the socket connection # Initialize the socket connection
self.client = client self.client = client
@ -315,7 +279,7 @@ class Victim:
# Lookup the remote host in our database. If it's not there, create an entry # Lookup the remote host in our database. If it's not there, create an entry
self.host = ( self.host = (
self.session.query(pwncat.db.Host).filter_by(hash=host_hash).first() get_session().query(pwncat.db.Host).filter_by(hash=host_hash).first()
) )
if self.host is None: if self.host is None:
progress.log(f"new host w/ hash [cyan]{host_hash}[/cyan]") progress.log(f"new host w/ hash [cyan]{host_hash}[/cyan]")
@ -324,9 +288,9 @@ class Victim:
# Probe for system information # Probe for system information
self.probe_host_details(progress, task_id) self.probe_host_details(progress, task_id)
# Add the host to the session # Add the host to the session
self.session.add(self.host) get_session().add(self.host)
# Commit what we know # Commit what we know
self.session.commit() get_session().commit()
# Save the remote host IP address # Save the remote host IP address
self.host.ip = self.client.getpeername()[0] self.host.ip = self.client.getpeername()[0]
@ -554,16 +518,17 @@ class Victim:
# Replace anything we provide in our binary cache with the busybox version # Replace anything we provide in our binary cache with the busybox version
for name in provides: for name in provides:
binary = ( binary = (
self.session.query(pwncat.db.Binary) get_session()
.query(pwncat.db.Binary)
.filter_by(host_id=self.host.id, name=name) .filter_by(host_id=self.host.id, name=name)
.first() .first()
) )
if binary is not None: if binary is not None:
self.session.delete(binary) get_session().delete(binary)
binary = pwncat.db.Binary(name=name, path=f"{busybox_remote_path} {name}") binary = pwncat.db.Binary(name=name, path=f"{busybox_remote_path} {name}")
self.host.binaries.append(binary) self.host.binaries.append(binary)
self.session.commit() get_session().commit()
console.log(f"busybox installed w/ {len(provides)} applets") console.log(f"busybox installed w/ {len(provides)} applets")
@ -572,7 +537,7 @@ class Victim:
operations such as clearing enumeration data. """ operations such as clearing enumeration data. """
self.host = ( self.host = (
self.session.query(pwncat.db.Host).filter_by(id=self.host.id).first() get_session().query(pwncat.db.Host).filter_by(id=self.host.id).first()
) )
def probe_host_details(self, progress: Progress, task_id): def probe_host_details(self, progress: Progress, task_id):
@ -633,7 +598,7 @@ class Victim:
for binary in self.host.binaries: for binary in self.host.binaries:
if self.host.busybox in binary.path: if self.host.busybox in binary.path:
self.session.delete(binary) get_session().delete(binary)
# Did we upload a copy of busybox or was it already installed? # Did we upload a copy of busybox or was it already installed?
if self.host.busybox_uploaded: if self.host.busybox_uploaded:
@ -663,7 +628,8 @@ class Victim:
""" """
binary = ( binary = (
self.session.query(pwncat.db.Binary) get_session()
.query(pwncat.db.Binary)
.filter_by(name=name, host_id=self.host.id) .filter_by(name=name, host_id=self.host.id)
.first() .first()
) )
@ -2067,7 +2033,8 @@ class Victim:
continue continue
line = line.strip().split(":") line = line.strip().split(":")
user = ( user = (
self.session.query(pwncat.db.User) get_session()
.query(pwncat.db.User)
.filter_by(host_id=self.host.id, id=int(line[2]), name=line[0]) .filter_by(host_id=self.host.id, id=int(line[2]), name=line[0])
.first() .first()
) )
@ -2086,7 +2053,7 @@ class Victim:
# Remove users that don't exist anymore # Remove users that don't exist anymore
for user in self.host.users: for user in self.host.users:
if user.name not in current_users: if user.name not in current_users:
self.session.delete(user) get_session().delete(user)
self.host.users.remove(user) self.host.users.remove(user)
with self.open("/etc/group", "r") as filp: with self.open("/etc/group", "r") as filp:
@ -2097,7 +2064,8 @@ class Victim:
line = line.split(":") line = line.split(":")
group = ( group = (
self.session.query(pwncat.db.Group) get_session()
.query(pwncat.db.Group)
.filter_by(host_id=self.host.id, id=int(line[2])) .filter_by(host_id=self.host.id, id=int(line[2]))
.first() .first()
) )
@ -2111,7 +2079,8 @@ class Victim:
for username in line[3].split(","): for username in line[3].split(","):
user = ( user = (
self.session.query(pwncat.db.User) get_session()
.query(pwncat.db.User)
.filter_by(host_id=self.host.id, name=username) .filter_by(host_id=self.host.id, name=username)
.first() .first()
) )
@ -2131,7 +2100,8 @@ class Victim:
continue continue
user = ( user = (
self.session.query(pwncat.db.User) get_session()
.query(pwncat.db.User)
.filter_by(host_id=self.host.id, name=entries[0]) .filter_by(host_id=self.host.id, name=entries[0])
.first() .first()
) )
@ -2146,7 +2116,7 @@ class Victim:
# Reload the host object # Reload the host object
self.host = ( self.host = (
self.session.query(pwncat.db.Host).filter_by(id=self.host.id).first() get_session().query(pwncat.db.Host).filter_by(id=self.host.id).first()
) )
return self.users return self.users

View File

@ -6,6 +6,7 @@ from colorama import Fore
import pwncat import pwncat
from pwncat.util import Access from pwncat.util import Access
from pwncat.db import get_session
class Action(Enum): class Action(Enum):
@ -147,7 +148,7 @@ class TamperManager:
serialized = pickle.dumps(tamper) serialized = pickle.dumps(tamper)
tracker = pwncat.db.Tamper(name=str(tamper), data=serialized) tracker = pwncat.db.Tamper(name=str(tamper), data=serialized)
pwncat.victim.host.tampers.append(tracker) pwncat.victim.host.tampers.append(tracker)
pwncat.victim.session.commit() get_session().commit()
def custom(self, name: str, revert: Optional[Callable] = None): def custom(self, name: str, revert: Optional[Callable] = None):
tamper = LambdaTamper(name, revert) tamper = LambdaTamper(name, revert)
@ -176,10 +177,8 @@ class TamperManager:
It removes the tracking for this tamper. """ It removes the tracking for this tamper. """
tracker = ( tracker = (
pwncat.victim.session.query(pwncat.db.Tamper) get_session().query(pwncat.db.Tamper).filter_by(name=str(tamper)).first()
.filter_by(name=str(tamper))
.first()
) )
if tracker is not None: if tracker is not None:
pwncat.victim.session.delete(tracker) get_session().delete(tracker)
pwncat.victim.session.commit() get_session().commit()