diff --git a/blitzortung/db/__init__.py b/blitzortung/db/__init__.py index 30aa284..c6bf5cd 100644 --- a/blitzortung/db/__init__.py +++ b/blitzortung/db/__init__.py @@ -26,15 +26,15 @@ def create_psycopg2_dummy(): - class Dummy(object): + class Dummy: Binary = MagicMock(name="psycopg2.Binary") pass - dummy = Dummy() - dummy.pool = Dummy() - dummy.pool.ThreadedConnectionPool = Dummy - dummy.extensions = mock.Mock() - dummy.extras = mock.Mock() + dummy = mock.Mock(name='psycopg2') +# dummy.pool = Dumm() + dummy.pool.ThreadedConnectionPool = Dummy() +# dummy.extensions = mock.Mock() +# dummy.extras = mock.Mock() return dummy diff --git a/blitzortung/db/table.py b/blitzortung/db/table.py index 5c93d96..f285cc7 100644 --- a/blitzortung/db/table.py +++ b/blitzortung/db/table.py @@ -43,40 +43,12 @@ from abc import ABCMeta, abstractmethod -class Base(object): - """ - abstract base class for database access objects - - creation of database - - as user postgres: - - createuser -i -D -R -S -W -E -P blitzortung - createdb -E utf8 -O blitzortung blitzortung - createlang plpgsql blitzortung - psql -f /usr/share/postgresql/10/contrib/postgis-2.4/postgis.sql -d blitzortung - psql -f /usr/share/postgresql/12/contrib/postgis-3.0/postgis.sql -d blitzortung - psql -f /usr/share/postgresql/10/contrib/postgis-2.4/spatial_ref_sys.sql -d blitzortung - - psql blitzortung - - GRANT SELECT ON spatial_ref_sys TO blitzortung; - GRANT SELECT ON geometry_columns TO blitzortung; - GRANT INSERT, DELETE ON geometry_columns TO blitzortung; - CREATE EXTENSION "btree_gist"; - - """ - __metaclass__ = ABCMeta - - DefaultTimezone = datetime.timezone.utc +class DbWrapper: def __init__(self, db_connection_pool): - - self.logger = logging.getLogger(get_logger_name(self.__class__)) self.db_connection_pool = db_connection_pool - self.schema_name = "" - self.table_name = "" + self.logger = logging.getLogger(get_logger_name(self.__class__)) while True: self.conn = self.db_connection_pool.getconn() @@ -92,10 +64,6 @@ def __init__(self, db_connection_pool): psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY, self.conn) self.conn.set_client_encoding('UTF8') - self.srid = geom.Geometry.DefaultSrid - self.tz = None - self.set_timezone(Base.DefaultTimezone) - cur = None try: cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) @@ -124,6 +92,80 @@ def is_connected(self): else: return False + def set_timezone(self, tz): + with self.conn.cursor() as cur: + cur.execute('SET TIME ZONE \'%s\'' % str(tz)) + + def commit(self): + """ commit pending database transaction """ + self.conn.commit() + + def rollback(self): + """ rollback pending database transaction """ + self.conn.rollback() + + def execute(self, sql_statement, parameters=None, factory_method=None, **factory_method_args): + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: + cursor.execute(sql_statement, parameters) + if factory_method: + method = factory_method(cursor, **factory_method_args) + return method + + def execute_single(self, sql_statement, parameters=None, factory_method=None, **factory_method_args): + def single_cursor_factory(cursor): + if cursor.rowcount == 1: + return factory_method(cursor.fetchone(), **factory_method_args) + + return self.execute(sql_statement, parameters, single_cursor_factory) + + def execute_many(self, sql_statement, parameters=None, factory_method=None, **factory_method_args): + with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: + cursor.execute(sql_statement, parameters) + if factory_method: + for value in cursor: + yield factory_method(value, **factory_method_args) + + +class Base: + """ + abstract base class for database access objects + + creation of database + + as user postgres: + + createuser -i -D -R -S -W -E -P blitzortung + createdb -E utf8 -O blitzortung blitzortung + createlang plpgsql blitzortung + psql -f /usr/share/postgresql/10/contrib/postgis-2.4/postgis.sql -d blitzortung + psql -f /usr/share/postgresql/12/contrib/postgis-3.0/postgis.sql -d blitzortung + psql -f /usr/share/postgresql/10/contrib/postgis-2.4/spatial_ref_sys.sql -d blitzortung + psql -f /usr/share/postgresql/12/contrib/postgis-3.0/spatial_ref_sys.sql -d blitzortung + + psql blitzortung + + GRANT SELECT ON spatial_ref_sys TO blitzortung; + GRANT SELECT ON geometry_columns TO blitzortung; + GRANT INSERT, DELETE ON geometry_columns TO blitzortung; + CREATE EXTENSION "btree_gist"; + + """ + __metaclass__ = ABCMeta + + DefaultTimezone = datetime.timezone.utc + + def __init__(self, db_wrapper: DbWrapper): + self.db_wrapper = db_wrapper + + self.logger = logging.getLogger(get_logger_name(self.__class__)) + + self.schema_name = "" + self.table_name = "" + + self.srid = geom.Geometry.DefaultSrid + self.tz = None + self.set_timezone(Base.DefaultTimezone) + @property def full_table_name(self): if self.schema_name: @@ -142,8 +184,7 @@ def get_timezone(self): def set_timezone(self, tz): self.tz = tz - with self.conn.cursor() as cur: - cur.execute('SET TIME ZONE \'%s\'' % str(self.tz)) + self.db_wrapper.set_timezone(tz) def fix_timezone(self, timestamp): return timestamp.astimezone(self.tz) if timestamp else None @@ -157,11 +198,11 @@ def from_timezone_to_bare_utc(time_with_tz): def commit(self): """ commit pending database transaction """ - self.conn.commit() + self.db_wrapper.commit() def rollback(self): """ rollback pending database transaction """ - self.conn.rollback() + self.db_wrapper.rollback() @abstractmethod def insert(self, *args): @@ -172,25 +213,13 @@ def select(self, **kwargs): pass def execute(self, sql_statement, parameters=None, factory_method=None, **factory_method_args): - with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: - cursor.execute(sql_statement, parameters) - if factory_method: - method = factory_method(cursor, **factory_method_args) - return method + return self.db_wrapper.execute(sql_statement, parameters, factory_method, **factory_method_args) def execute_single(self, sql_statement, parameters=None, factory_method=None, **factory_method_args): - def single_cursor_factory(cursor): - if cursor.rowcount == 1: - return factory_method(cursor.fetchone(), **factory_method_args) - - return self.execute(sql_statement, parameters, single_cursor_factory) + return self.db_wrapper.execute_single(sql_statement, parameters, factory_method, **factory_method_args) def execute_many(self, sql_statement, parameters=None, factory_method=None, **factory_method_args): - with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor: - cursor.execute(sql_statement, parameters) - if factory_method: - for value in cursor: - yield factory_method(value, **factory_method_args) + return self.db_wrapper.execute_many(sql_statement, parameters, factory_method, **factory_method_args) class Strike(Base): @@ -210,10 +239,12 @@ class Strike(Base): ALTER TABLE strikes ADD COLUMN stationcount SMALLINT; CREATE INDEX strikes_timestamp ON strikes USING btree("timestamp"); + CREATE INDEX strikes_timestamp_geog ON strikes USING gist("timestamp", geog); + + not really required in standard mode: CREATE INDEX strikes_region_timestamp_nanoseconds ON strikes USING btree(region, "timestamp", nanoseconds); CREATE INDEX strikes_id_timestamp ON strikes USING btree(id, "timestamp"); CREATE INDEX strikes_geog ON strikes USING gist(geog); - CREATE INDEX strikes_timestamp_geog ON strikes USING gist("timestamp", geog); CREATE INDEX strikes_id_timestamp_geog ON strikes USING gist(id, "timestamp", geog); empty the table with the following commands: diff --git a/tests/db/test_db_table.py b/tests/db/test_db_table.py index 4fdc17b..3fe2b2d 100644 --- a/tests/db/test_db_table.py +++ b/tests/db/test_db_table.py @@ -24,6 +24,8 @@ except ImportError: from backports.zoneinfo import ZoneInfo +from blitzortung import builder + try: import psycopg2 except ImportError as e: @@ -31,6 +33,8 @@ psycopg2 = create_psycopg2_dummy() +from blitzortung.db import query_builder, mapper + from assertpy import assert_that from mock import Mock, call @@ -140,3 +144,36 @@ def test_rollback(self): self.base.rollback() self.connection.rollback.assert_called_once_with() + + +class StrikeForTest(blitzortung.db.table.Strike): + def __init__(self, db_connection_pool): + super().__init__(db_connection_pool, query_builder.Strike(), mapper.Strike(builder.Strike())) + + def create_object_instance(self, result): + return result + + def insert(self, *args): + return args + + def select(self, *args): + return args + + +class TestStrike: + def setup_method(self): + self.connection_pool = Mock() + self.connection = self.connection_pool.getconn() + self.cursor = self.connection.cursor() + + psycopg2.extensions = Mock() + + self.cursor.__enter__ = Mock(return_value=self.cursor) + self.cursor.__exit__ = Mock(return_value=False) + + self.base = StrikeForTest(self.connection_pool) + + def test_latest_time(self): + result = self.base.get_latest_time(2) + + assert_that(result).is_equal_to("asdf")