Files
moded_distructive_farm/server/database.py
2025-10-22 21:57:07 +03:00

202 lines
5.4 KiB
Python

"""
Module with PostgreSQL helpers
"""
import os
import psycopg2
import threading
from psycopg2.extras import RealDictCursor
from flask import g
from __init__ import app
# Location of the optional schema file executed by _init(): schema.sql in the
# same directory as this module.
schema_path = os.path.join(os.path.dirname(__file__), 'schema.sql')

# Database configuration, read from the environment with fallback defaults.
# Every value is a string (db_port included) and is handed to
# psycopg2.connect() unchanged.
db_host = os.environ.get('DB_HOST', 'postgres')
db_port = os.environ.get('DB_PORT', '5432')
db_name = os.environ.get('DB_NAME', 'farmdb')
db_user = os.environ.get('DB_USER', 'farm')
# NOTE(review): hard-coded fallback password; presumably every real deployment
# sets DB_PASSWORD — confirm.
db_password = os.environ.get('DB_PASSWORD', 'asdasdasd')

# Initialization guard used by get(): _init_started records, under
# _init_lock, that this process has already handled _init(), so other threads
# skip it. The lock is reentrant (threading.RLock).
_init_started = False
_init_lock = threading.RLock()
def dict_factory(cursor, row):
    """
    Map one result row to a dict keyed by column name (similar to sqlite3.Row).

    :param cursor: cursor whose ``description`` entries carry the column name
        at index 0
    :param row: sequence of values positionally aligned with
        ``cursor.description``
    :returns: ``{column_name: value}``; if a column name repeats, the value of
        the last such column wins
    """
    mapped = {}
    for position, column in enumerate(cursor.description):
        mapped[column[0]] = row[position]
    return mapped
def get(context_bound=True):
    """
    Return a connection to the (initialized) PostgreSQL database.

    If context_bound=True and a connection is already stored in flask.g for
    the current app context, that connection is returned. Otherwise a new
    connection is opened with the module-level db_* settings. Reopening the
    connection on each request has some overhead, but avoids implementing a
    pool of thread-local connections.

    If check_if_initialization_needed() reports that the schema is missing,
    the first caller to take _init_lock runs _init(). Other threads that also
    saw a missing schema wait on the lock until initialization finishes and
    then skip it. If _init() fails, the new connection is closed, the error is
    re-raised, and _init_started stays False so a later call retries.

    If context_bound=True, the connection is stored in flask.g and closed by
    the teardown_appcontext handler. With context_bound=False the caller owns
    the connection and must close it.

    :param context_bound: reuse/store the connection in flask.g
    :returns: a psycopg2 connection to the initialized PostgreSQL database
    :raises psycopg2.Error: if connecting fails; also re-raises whatever
        _init() raises
    """
    global _init_started

    if context_bound and 'database' in g:
        return g.database

    database = psycopg2.connect(
        host=db_host,
        port=db_port,
        dbname=db_name,
        user=db_user,
        password=db_password,
    )

    if check_if_initialization_needed(database):
        with _init_lock:
            # Re-check under the lock: another thread may have completed the
            # initialization while this one was waiting.
            if not _init_started:
                try:
                    _init(database)
                except Exception:
                    # Don't leak the connection, and leave _init_started False
                    # so that a later call can retry the initialization.
                    database.close()
                    raise
                # Mark as done only after _init() succeeded.
                _init_started = True

    if context_bound:
        g.database = database
    app.logger.info('DB connection established')
    return database
def check_if_initialization_needed(conn):
    """
    Return True if the database schema appears to need initialization.

    The check asks information_schema whether *any* table exists in the
    ``public`` schema; it does not look for specific tables. Adjust the query
    if the application needs a stricter check.

    If the check itself fails, the error is logged and the connection's
    transaction is rolled back. A failed statement leaves a psycopg2
    transaction aborted, and _init() would otherwise fail on this same
    connection. True is then returned so the caller attempts initialization.

    :param conn: an open psycopg2 connection
    :returns: True if no table exists in schema ``public`` (or the check
        failed), False otherwise
    """
    try:
        with conn.cursor() as cursor:
            # Check for existence of any table to see if DB is initialized
            cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables
                    WHERE table_schema = 'public'
                    LIMIT 1
                );
            """)
            tables_exist = cursor.fetchone()[0]
        return not tables_exist
    except Exception as e:
        app.logger.error(f"Error checking initialization status: {e}")
        try:
            conn.rollback()
        except psycopg2.Error as rollback_error:
            # The connection is unusable (e.g. already closed); the caller's
            # next statement will surface that, so only record it here.
            app.logger.error(
                f"Rollback after failed initialization check failed: {rollback_error}"
            )
        return True
def _init(conn):
    """
    Create the database schema on ``conn`` and commit it.

    If the file at ``schema_path`` exists, its whole contents are executed
    as a single cursor.execute() call. Otherwise a minimal fallback schema with
    ``flags`` and ``submissions`` tables is created. On any error the
    transaction is rolled back, the error is logged with its traceback, and
    the exception is re-raised.

    :param conn: an open psycopg2 connection
    """
    try:
        # The cursor context manager closes the cursor even if execute() fails.
        with conn.cursor() as cursor:
            if os.path.exists(schema_path):
                # Read explicitly as UTF-8 so the result does not depend on
                # the container/host locale.
                with open(schema_path, 'r', encoding='utf-8') as f:
                    schema_sql = f.read()
                cursor.execute(schema_sql)
                app.logger.info("Executed schema.sql")
            else:
                # Fallback to basic tables if schema.sql doesn't exist
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS flags (
                        id SERIAL PRIMARY KEY,
                        flag TEXT UNIQUE NOT NULL,
                        team INTEGER NOT NULL,
                        tick INTEGER NOT NULL,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                """)
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS submissions (
                        id SERIAL PRIMARY KEY,
                        flag TEXT NOT NULL,
                        team INTEGER NOT NULL,
                        tick INTEGER NOT NULL,
                        submitted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                    );
                """)
        conn.commit()
        app.logger.info("Database initialized successfully")
    except Exception as e:
        conn.rollback()
        app.logger.exception(f"Database initialization failed: {e}")
        raise
def query(sql, args=()):
    """
    Execute ``sql`` on the request-bound connection from get() and commit.

    :param sql: SQL text with psycopg2 placeholders
    :param args: values for the placeholders
    :returns: a list of ``{column_name: value}`` dicts if the statement
        produced a result set (``cursor.description`` is set), otherwise None
    :raises: any psycopg2 error from executing or committing; the transaction
        is rolled back first so the shared connection stays usable
    """
    conn = get()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
            if cursor.description:  # statement returned rows (SELECT, RETURNING)
                results = [dict_factory(cursor, row) for row in cursor.fetchall()]
            else:  # INSERT/UPDATE/DELETE without RETURNING, DDL, ...
                results = None
        conn.commit()
    except Exception:
        # A failed statement aborts the psycopg2 transaction; without a
        # rollback every later statement on this connection would fail too.
        conn.rollback()
        raise
    return results
def execute(sql, args=()):
    """
    Execute a statement whose result is not needed (INSERT, UPDATE, DELETE,
    DDL) on the request-bound connection from get() and commit it.

    :param sql: SQL text with psycopg2 placeholders
    :param args: values for the placeholders
    :raises: any psycopg2 error from executing or committing; the transaction
        is rolled back first so the shared connection stays usable
    """
    conn = get()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
        conn.commit()
    except Exception:
        # A failed statement aborts the psycopg2 transaction; without a
        # rollback every later statement on this connection would fail too.
        conn.rollback()
        raise
def fetch_one(sql, args=()):
    """
    Execute ``sql`` on the request-bound connection from get(), commit, and
    return the first result row.

    :param sql: SQL text with psycopg2 placeholders
    :param args: values for the placeholders
    :returns: the first row as a ``{column_name: value}`` dict, or None if the
        statement produced no result set or no rows
    :raises: any psycopg2 error from executing or committing; the transaction
        is rolled back first so the shared connection stays usable
    """
    conn = get()
    try:
        with conn.cursor() as cursor:
            cursor.execute(sql, args)
            if cursor.description:  # statement returned rows (SELECT, RETURNING)
                row = cursor.fetchone()
                result = dict_factory(cursor, row) if row is not None else None
            else:
                result = None
        conn.commit()
    except Exception:
        # A failed statement aborts the psycopg2 transaction; without a
        # rollback every later statement on this connection would fail too.
        conn.rollback()
        raise
    return result
@app.teardown_appcontext
def close(_):
    """
    Close the connection that get() stored in flask.g, if there is one.

    Registered via app.teardown_appcontext, so Flask calls it when the app
    context is torn down; the argument passed by Flask is unused. The
    connection is removed from g before closing, so nothing later in the
    same context can be handed an already-closed connection by get().
    """
    database = g.pop('database', None)
    if database is not None:
        database.close()