postgres version
This commit is contained in:
201
server/database.py
Normal file
201
server/database.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Module with PostgreSQL helpers
|
||||
"""
|
||||
|
||||
import os
|
||||
import psycopg2
|
||||
import threading
|
||||
from psycopg2.extras import RealDictCursor
|
||||
|
||||
from flask import g
|
||||
|
||||
from __init__ import app
|
||||
|
||||
|
||||
schema_path = os.path.join(os.path.dirname(__file__), 'schema.sql')
|
||||
|
||||
# Database configuration
|
||||
db_host = os.getenv('DB_HOST', 'postgres')
|
||||
db_port = os.getenv('DB_PORT', '5432')
|
||||
db_name = os.getenv('DB_NAME', 'farmdb')
|
||||
db_user = os.getenv('DB_USER', 'farm')
|
||||
db_password = os.getenv('DB_PASSWORD', 'asdasdasd')
|
||||
|
||||
_init_started = False
|
||||
_init_lock = threading.RLock()
|
||||
|
||||
|
||||
def dict_factory(cursor, row):
|
||||
"""
|
||||
Convert database row to dictionary similar to sqlite3.Row
|
||||
"""
|
||||
return {col[0]: row[i] for i, col in enumerate(cursor.description)}
|
||||
|
||||
|
||||
def get(context_bound=True):
|
||||
"""
|
||||
If there is no opened connection to the PostgreSQL database in the context
|
||||
of the current request or if context_bound=False, get() opens a new
|
||||
connection to the PostgreSQL database. Reopening the connection on each request
|
||||
may have some overhead, but allows to avoid implementing a pool of
|
||||
thread-local connections.
|
||||
|
||||
If the database schema needs initialization, get() creates and initializes it.
|
||||
If get() is called from other threads at this time, they will wait
|
||||
for the end of the initialization.
|
||||
|
||||
If context_bound=True, the connection will be closed after
|
||||
request handling (when the context will be destroyed).
|
||||
|
||||
:returns: a connection to the initialized PostgreSQL database
|
||||
"""
|
||||
|
||||
global _init_started
|
||||
|
||||
if context_bound and 'database' in g:
|
||||
return g.database
|
||||
|
||||
# Connect to PostgreSQL database
|
||||
database = psycopg2.connect(
|
||||
host=db_host,
|
||||
port=db_port,
|
||||
dbname=db_name,
|
||||
user=db_user,
|
||||
password=db_password
|
||||
)
|
||||
|
||||
# Check if initialization is needed
|
||||
need_init = check_if_initialization_needed(database)
|
||||
|
||||
if need_init:
|
||||
with _init_lock:
|
||||
if not _init_started:
|
||||
_init_started = True
|
||||
_init(database)
|
||||
|
||||
if context_bound:
|
||||
g.database = database
|
||||
|
||||
app.logger.info('DB connection established')
|
||||
return database
|
||||
|
||||
|
||||
def check_if_initialization_needed(conn):
|
||||
"""
|
||||
Check if database needs initialization by looking for specific tables.
|
||||
Modify this based on your application's requirements.
|
||||
"""
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
# Check for existence of any table to see if DB is initialized
|
||||
cursor.execute("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
LIMIT 1
|
||||
);
|
||||
""")
|
||||
tables_exist = cursor.fetchone()[0]
|
||||
cursor.close()
|
||||
return not tables_exist
|
||||
except Exception as e:
|
||||
app.logger.error(f"Error checking initialization status: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def _init(conn):
|
||||
"""
|
||||
Initialize the database schema and any required data.
|
||||
"""
|
||||
try:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Read and execute schema.sql if it exists
|
||||
if os.path.exists(schema_path):
|
||||
with open(schema_path, 'r') as f:
|
||||
schema_sql = f.read()
|
||||
cursor.execute(schema_sql)
|
||||
app.logger.info("Executed schema.sql")
|
||||
else:
|
||||
# Fallback to basic tables if schema.sql doesn't exist
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS flags (
|
||||
id SERIAL PRIMARY KEY,
|
||||
flag TEXT UNIQUE NOT NULL,
|
||||
team INTEGER NOT NULL,
|
||||
tick INTEGER NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
""")
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS submissions (
|
||||
id SERIAL PRIMARY KEY,
|
||||
flag TEXT NOT NULL,
|
||||
team INTEGER NOT NULL,
|
||||
tick INTEGER NOT NULL,
|
||||
submitted_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
""")
|
||||
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
app.logger.info("Database initialized successfully")
|
||||
|
||||
except Exception as e:
|
||||
conn.rollback()
|
||||
app.logger.error(f"Database initialization failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def query(sql, args=()):
|
||||
"""
|
||||
Execute a query and return results as dictionaries
|
||||
"""
|
||||
conn = get()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, args)
|
||||
|
||||
if cursor.description: # If it's a SELECT query
|
||||
results = [dict_factory(cursor, row) for row in cursor.fetchall()]
|
||||
else: # For INSERT, UPDATE, DELETE
|
||||
results = None
|
||||
conn.commit()
|
||||
|
||||
cursor.close()
|
||||
return results
|
||||
|
||||
|
||||
def execute(sql, args=()):
|
||||
"""
|
||||
Execute a query that doesn't return results (INSERT, UPDATE, DELETE)
|
||||
"""
|
||||
conn = get()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, args)
|
||||
conn.commit()
|
||||
cursor.close()
|
||||
|
||||
|
||||
def fetch_one(sql, args=()):
|
||||
"""
|
||||
Execute a query and return first result as dictionary
|
||||
"""
|
||||
conn = get()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute(sql, args)
|
||||
|
||||
if cursor.description: # If it's a SELECT query
|
||||
row = cursor.fetchone()
|
||||
result = dict_factory(cursor, row) if row else None
|
||||
else:
|
||||
result = None
|
||||
conn.commit()
|
||||
|
||||
cursor.close()
|
||||
return result
|
||||
|
||||
|
||||
@app.teardown_appcontext
|
||||
def close(_):
|
||||
if 'database' in g:
|
||||
g.database.close()
|
||||
Reference in New Issue
Block a user