diff -r 3da452f4a3ac roundup/backends/back_postgresql.py --- a/roundup/backends/back_postgresql.py Sat Dec 23 21:59:08 2023 -0500 +++ b/roundup/backends/back_postgresql.py Tue Dec 26 22:19:23 2023 -0500 @@ -9,6 +9,7 @@ import logging import os +import re import shutil import time @@ -47,28 +48,40 @@ del d['read_default_file'] return d +def db_schema_split(database_name): + ''' Split database_name into database and schema parts''' + if '.' in database_name: + return database_name.split ('.') + return [database_name, ''] def db_create(config): """Clear all database contents and drop database itself""" - command = ("CREATE DATABASE \"%s\" WITH ENCODING='UNICODE'" % - get_database_name(config)) - if config.RDBMS_TEMPLATE: - command = command + " TEMPLATE=%s" % config.RDBMS_TEMPLATE - logging.getLogger('roundup.hyperdb').info(command) - db_command(config, command) - + db_name, schema_name = db_schema_split(config.RDBMS_NAME) + if not schema_name: + command = "CREATE DATABASE \"%s\" WITH ENCODING='UNICODE'" % db_name + if config.RDBMS_TEMPLATE: + command = command + " TEMPLATE=%s" % config.RDBMS_TEMPLATE + logging.getLogger('roundup.hyperdb').info(command) + db_command(config, command) + else: + command = "CREATE SCHEMA \"%s\" AUTHORIZATION \"%s\"" % (schema_name, config.RDBMS_USER) + logging.getLogger('roundup.hyperdb').info(command) + db_command(config, command, db_name) def db_nuke(config): - """Clear all database contents and drop database itself""" - command = 'DROP DATABASE "%s"' % get_database_name(config) - - logging.getLogger('roundup.hyperdb').info(command) - db_command(config, command) - + """Drop the database (and all its contents) or the schema.""" + db_name, schema_name = db_schema_split(config.RDBMS_NAME) + if not schema_name: + command = 'DROP DATABASE "%s"'% db_name + logging.getLogger('roundup.hyperdb').info(command) + db_command(config, command) + else: + command = 'DROP SCHEMA "%s" CASCADE' % schema_name + logging.getLogger('roundup.hyperdb').info(command) + db_command(config, command, db_name) if os.path.exists(config.DATABASE): shutil.rmtree(config.DATABASE) - def get_database_name(config): '''Get database name using config.RDBMS_NAME or config.RDBMS_SERVICE. @@ -124,14 +137,16 @@ before "template1" seems to have been used, so we fall back to it. Compare to issue2550543. ''' - template1 = connection_dict(config) + template1 = connection_dict(config, 'database') + db_name, schema_name = db_schema_split(template1['database']) template1['database'] = database try: conn = psycopg2.connect(**template1) except psycopg2.OperationalError as message: - if str(message).find('database "postgres" does not exist') >= 0: - return db_command(config, command, database='template1') + if not schema_name: + if re.search(r'database ".+" does not exist', str(message)): + return db_command(config, command, database='template1') raise hyperdb.DatabaseError(message) conn.set_isolation_level(0) @@ -142,17 +157,17 @@ return finally: conn.close() - raise RuntimeError('10 attempts to create database failed when running: %s' % command) + raise RuntimeError('10 attempts to create database or schema failed when running: %s' % command) -def pg_command(cursor, command): +def pg_command(cursor, command, args=()): '''Execute the postgresql command, which may be blocked by some other user connecting to the database, and return a true value if it succeeds. If there is a concurrent update, retry the command. ''' try: - cursor.execute(command) + cursor.execute(command, args) except psycopg2.DatabaseError as err: response = str(err).split('\n')[0] if "FATAL" not in response: @@ -164,19 +179,32 @@ if msg in response: time.sleep(0.1) return 0 - raise RuntimeError(response) + raise RuntimeError(response, command, args) return 1 def db_exists(config): - """Check if database already exists""" + """Check if database or schema already exists""" db = connection_dict(config, 'database') + db_name, schema_name = db_schema_split(db['database']) + if schema_name: + db['database'] = db_name try: conn = psycopg2.connect(**db) - conn.close() - return 1 + if not schema_name: + conn.close() + return 1 except Exception: return 0 + # will have a non-false value here; otherwise one + # of the above returns would have returned. + # Get a count of the number of schemas named (either 0 or 1). + command = "SELECT COUNT(*) FROM information_schema.schemata WHERE schema_name = %s" + cursor = conn.cursor() + pg_command(cursor, command, (schema_name,)) + count = cursor.fetchall()[0][0] + conn.close() + return count # 'count' will be 0 or 1. class Sessions(sessions_rdbms.Sessions): @@ -225,6 +253,10 @@ def sql_open_connection(self): db = connection_dict(self.config, 'database') + db_name, schema_name = db_schema_split (db['database']) + if schema_name: + db['database'] = db_name + # database option always present: log it if not null if db['database']: logging.getLogger('roundup.hyperdb').info( @@ -242,6 +274,11 @@ lvl = isolation_levels[self.config.RDBMS_ISOLATION_LEVEL] conn.set_isolation_level(lvl) + if schema_name: + self.sql ('SET search_path TO %s' % schema_name, cursor=cursor) + # Commit is required so that a subsequent rollback + # will not also rollback the search_path change. + self.sql ('COMMIT', cursor=cursor) return (conn, cursor) def sql_new_cursor(self, name='default', conn=None, *args, **kw):