postgres version
This commit is contained in:
95
server/submit_loop.py
Executable file
95
server/submit_loop.py
Executable file
@@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import importlib
|
||||
import random
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from __init__ import app
|
||||
import database
|
||||
import reloader
|
||||
from models import Flag, FlagStatus, SubmitResult
|
||||
|
||||
|
||||
def get_fair_share(groups, limit):
|
||||
if not groups:
|
||||
return []
|
||||
|
||||
groups = sorted(groups, key=len)
|
||||
places_left = limit
|
||||
group_count = len(groups)
|
||||
fair_share = places_left // group_count
|
||||
|
||||
result = []
|
||||
residuals = []
|
||||
for group in groups:
|
||||
if len(group) <= fair_share:
|
||||
result += group
|
||||
|
||||
places_left -= len(group)
|
||||
group_count -= 1
|
||||
if group_count > 0:
|
||||
fair_share = places_left // group_count
|
||||
# The fair share could have increased because the processed group
|
||||
# had a few elements. Sorting order guarantees that the smaller
|
||||
# groups will be processed first.
|
||||
else:
|
||||
selected = random.sample(group, fair_share + 1)
|
||||
result += selected[:-1]
|
||||
residuals.append(selected[-1])
|
||||
result += random.sample(residuals, min(limit - len(result), len(residuals)))
|
||||
|
||||
random.shuffle(result)
|
||||
return result
|
||||
|
||||
|
||||
def submit_flags(flags, config):
|
||||
module = importlib.import_module('protocols.' + config['SYSTEM_PROTOCOL'])
|
||||
|
||||
try:
|
||||
return list(module.submit_flags(flags, config))
|
||||
except Exception as e:
|
||||
message = '{}: {}'.format(type(e).__name__, str(e))
|
||||
app.logger.exception('Exception on submitting flags')
|
||||
return [SubmitResult(item.flag, FlagStatus.QUEUED, message) for item in flags]
|
||||
|
||||
|
||||
def run_loop():
|
||||
app.logger.info('Starting submit loop')
|
||||
with app.app_context():
|
||||
db = database.get(context_bound=False)
|
||||
cursor = db.cursor()
|
||||
while True:
|
||||
submit_start_time = time.time()
|
||||
|
||||
config = reloader.get_config()
|
||||
|
||||
skip_time = round(submit_start_time - config['FLAG_LIFETIME'])
|
||||
cursor.execute("UPDATE flags SET status = %s WHERE status = %s AND time < %s",
|
||||
(FlagStatus.SKIPPED.name, FlagStatus.QUEUED.name, skip_time))
|
||||
db.commit()
|
||||
|
||||
cursor.execute("SELECT * FROM flags WHERE status = %s", (FlagStatus.QUEUED.name,))
|
||||
queued_flags = [Flag(**item) for item in cursor.fetchall()]
|
||||
|
||||
if queued_flags:
|
||||
grouped_flags = defaultdict(list)
|
||||
for item in queued_flags:
|
||||
grouped_flags[item.sploit, item.team].append(item)
|
||||
flags = get_fair_share(grouped_flags.values(), config['SUBMIT_FLAG_LIMIT'])
|
||||
|
||||
app.logger.debug('Submitting %s flags (out of %s in queue)', len(flags), len(queued_flags))
|
||||
results = submit_flags(flags, config)
|
||||
|
||||
rows = [(item.status.name, item.checksystem_response, item.flag) for item in results]
|
||||
cursor.executemany("UPDATE flags SET status = %s, checksystem_response = %s"
|
||||
"WHERE flag = %s", rows)
|
||||
db.commit()
|
||||
|
||||
submit_spent = time.time() - submit_start_time
|
||||
if config['SUBMIT_PERIOD'] > submit_spent:
|
||||
time.sleep(config['SUBMIT_PERIOD'] - submit_spent)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_loop()
|
||||
Reference in New Issue
Block a user