202 lines
5.4 KiB
Python
202 lines
5.4 KiB
Python
"""
PostgreSQL helpers: per-request connections, one-time schema
initialization, and thin query/execute wrappers.
"""
|
|
|
|
import os
|
|
import psycopg2
|
|
import threading
|
|
from psycopg2.extras import RealDictCursor
|
|
|
|
from flask import g
|
|
|
|
from __init__ import app
|
|
|
|
|
|
# Optional SQL schema file, expected to live next to this module.
schema_path = os.path.join(os.path.split(__file__)[0], 'schema.sql')

# Database configuration; every value may be overridden via the environment.
# NOTE(review): the DB_PASSWORD fallback is a hard-coded credential — consider
# requiring it to be set explicitly outside development setups.
db_host = os.environ.get('DB_HOST', 'postgres')
db_port = os.environ.get('DB_PORT', '5432')
db_name = os.environ.get('DB_NAME', 'farmdb')
db_user = os.environ.get('DB_USER', 'farm')
db_password = os.environ.get('DB_PASSWORD', 'asdasdasd')

# Schema-initialization state shared by all threads of this process;
# get() updates the flag only while holding the lock.
_init_started = False
_init_lock = threading.RLock()
|
|
|
|
|
|
def dict_factory(cursor, row):
    """
    Map a result row to a dict keyed by column name (like sqlite3.Row).

    :param cursor: cursor whose ``description`` describes the row's columns
    :param row: sequence of column values, in ``description`` order
    :returns: ``{column_name: value}`` for every described column
    """
    mapped = {}
    for position, column in enumerate(cursor.description):
        # The first element of each description entry is the column name.
        mapped[column[0]] = row[position]
    return mapped
|
|
|
|
|
|
def get(context_bound=True):
    """
    Return a connection to the initialized PostgreSQL database.

    If context_bound=True and a connection is already stored in the current
    application context (``flask.g``), that connection is returned. Otherwise
    a new connection is opened. Reopening the connection for each request has
    some overhead, but it avoids the need for a pool of thread-local
    connections.

    If the database schema needs initialization, get() creates it. Other
    threads calling get() at the same time block on the initialization lock
    until it finishes. If initialization fails, the new connection is closed,
    the error is re-raised, and a later call will attempt initialization again.

    If context_bound=True, the connection is stored in ``flask.g`` and closed
    when the application context is torn down.

    :param context_bound: reuse/store the connection in ``flask.g``
    :returns: a connection to the initialized PostgreSQL database
    :raises psycopg2.Error: if connecting fails; any exception raised by
        _init() is propagated as well
    """
    global _init_started

    if context_bound and 'database' in g:
        return g.database

    database = psycopg2.connect(
        host=db_host,
        port=db_port,
        dbname=db_name,
        user=db_user,
        password=db_password,
    )

    if check_if_initialization_needed(database):
        with _init_lock:
            # Another thread may have completed initialization while this one
            # was waiting for the lock; in that case there is nothing to do.
            if not _init_started:
                try:
                    _init(database)
                except Exception:
                    # Don't leak the connection, and leave the flag unset so
                    # a subsequent call can retry the initialization.
                    database.close()
                    raise
                # Set only after success: the flag means "initialization done".
                _init_started = True

    if context_bound:
        g.database = database

    app.logger.info('DB connection established')
    return database
|
|
|
|
|
|
def check_if_initialization_needed(conn):
    """
    Report whether the database schema still has to be created.

    The check only looks for *any* table in the ``public`` schema; modify it
    to look for specific tables if your application needs a stricter test.

    :param conn: an open psycopg2 connection
    :returns: True if no table exists in ``public`` or the check itself
        failed, False otherwise
    """
    try:
        # The context manager closes the cursor even if the query fails.
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_schema = 'public'
                    LIMIT 1
                );
            """)
            tables_exist = cursor.fetchone()[0]
        return not tables_exist
    except Exception as e:
        app.logger.error(f"Error checking initialization status: {e}")
        # A failed statement leaves the transaction aborted; roll back so the
        # caller can still use this connection (e.g. to run the schema init).
        try:
            conn.rollback()
        except Exception as rollback_error:
            app.logger.error(
                f"Rollback after failed initialization check failed: {rollback_error}"
            )
        return True
|
|
|
|
|
|
def _init(conn):
    """
    Create the database schema and commit it.

    Executes the SQL in ``schema_path`` when that file exists; otherwise
    creates minimal ``flags`` and ``submissions`` tables. The work is
    committed on success and rolled back on failure.

    :param conn: an open psycopg2 connection
    :raises Exception: re-raises whatever error aborted the initialization,
        after rolling the transaction back and logging it
    """
    try:
        # The context manager closes the cursor on both success and failure.
        with conn.cursor() as cursor:
            if os.path.exists(schema_path):
                with open(schema_path, 'r', encoding='utf-8') as f:
                    schema_sql = f.read()
                cursor.execute(schema_sql)
                app.logger.info("Executed schema.sql")
            else:
                # Fallback to basic tables if schema.sql doesn't exist
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS flags (
                        id SERIAL PRIMARY KEY,
                        flag TEXT UNIQUE NOT NULL,
                        team INTEGER NOT NULL,
                        tick INTEGER NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS submissions (
                        id SERIAL PRIMARY KEY,
                        flag TEXT NOT NULL,
                        team INTEGER NOT NULL,
                        tick INTEGER NOT NULL,
                        submitted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                """)

        conn.commit()
        app.logger.info("Database initialized successfully")

    except Exception as e:
        conn.rollback()
        app.logger.error(f"Database initialization failed: {e}")
        raise
|
|
|
|
|
|
def query(sql, args=()):
    """
    Execute a statement and return its rows as dictionaries.

    The transaction is always committed, including for statements that both
    return rows and modify data (e.g. ``INSERT ... RETURNING``). On failure
    it is rolled back before the error propagates, so the request-bound
    connection stays usable.

    :param sql: SQL text with psycopg2 placeholders
    :param args: parameters bound to the placeholders
    :returns: list of ``{column: value}`` dicts if the statement produced a
        result set, otherwise None
    :raises psycopg2.Error: if execution or commit fails
    """
    conn = get()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
            if cursor.description:  # the statement produced a result set
                results = [dict_factory(cursor, row) for row in cursor.fetchall()]
            else:  # INSERT, UPDATE, DELETE without RETURNING
                results = None
        # Commit unconditionally: skipping it for row-returning statements
        # would silently discard writes made by INSERT/UPDATE ... RETURNING
        # when the connection is closed at teardown.
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    return results
|
|
|
|
|
|
def execute(sql, args=()):
    """
    Execute a statement that returns no rows (INSERT, UPDATE, DELETE) and commit.

    On failure the transaction is rolled back before the error propagates,
    so the request-bound connection remains usable for later statements.

    :param sql: SQL text with psycopg2 placeholders
    :param args: parameters bound to the placeholders
    :raises psycopg2.Error: if execution or commit fails
    """
    conn = get()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
        conn.commit()
    except Exception:
        # Without this, every later statement on this connection would fail
        # with "current transaction is aborted".
        conn.rollback()
        raise
|
|
|
|
|
|
def fetch_one(sql, args=()):
    """
    Execute a statement and return its first row as a dictionary.

    The transaction is always committed, so ``INSERT ... RETURNING id``
    style statements persist their changes. On failure it is rolled back
    before the error propagates.

    :param sql: SQL text with psycopg2 placeholders
    :param args: parameters bound to the placeholders
    :returns: ``{column: value}`` for the first row, or None if the statement
        produced no result set or no rows
    :raises psycopg2.Error: if execution or commit fails
    """
    conn = get()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
            row = cursor.fetchone() if cursor.description else None
            result = dict_factory(cursor, row) if row is not None else None
        # Commit unconditionally: row-returning statements may also modify
        # data, and uncommitted work is discarded when the connection closes.
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    return result
|
|
|
|
|
|
@app.teardown_appcontext
def close(_):
    """
    Close the context-bound connection opened by get(), if any.

    The connection is removed from ``flask.g`` before closing, so nothing
    later in the teardown can obtain an already-closed connection via get().
    Any uncommitted work on it is discarded by psycopg2 when it is closed.
    """
    database = g.pop('database', None)
    if database is not None:
        database.close()
|