Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add db wrapper #18

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions blitzortung/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@


def create_psycopg2_dummy():
class Dummy(object):
class Dummy:
Binary = MagicMock(name="psycopg2.Binary")
pass

dummy = Dummy()
dummy.pool = Dummy()
dummy.pool.ThreadedConnectionPool = Dummy
dummy.extensions = mock.Mock()
dummy.extras = mock.Mock()
dummy = mock.Mock(name='psycopg2')
# dummy.pool = Dumm()
dummy.pool.ThreadedConnectionPool = Dummy()
# dummy.extensions = mock.Mock()
# dummy.extras = mock.Mock()
return dummy


Expand Down
139 changes: 85 additions & 54 deletions blitzortung/db/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,40 +43,12 @@
from abc import ABCMeta, abstractmethod


class Base(object):
"""
abstract base class for database access objects

creation of database

as user postgres:

createuser -i -D -R -S -W -E -P blitzortung
createdb -E utf8 -O blitzortung blitzortung
createlang plpgsql blitzortung
psql -f /usr/share/postgresql/10/contrib/postgis-2.4/postgis.sql -d blitzortung
psql -f /usr/share/postgresql/12/contrib/postgis-3.0/postgis.sql -d blitzortung
psql -f /usr/share/postgresql/10/contrib/postgis-2.4/spatial_ref_sys.sql -d blitzortung

psql blitzortung

GRANT SELECT ON spatial_ref_sys TO blitzortung;
GRANT SELECT ON geometry_columns TO blitzortung;
GRANT INSERT, DELETE ON geometry_columns TO blitzortung;
CREATE EXTENSION "btree_gist";

"""
__metaclass__ = ABCMeta

DefaultTimezone = datetime.timezone.utc
class DbWrapper:

def __init__(self, db_connection_pool):

self.logger = logging.getLogger(get_logger_name(self.__class__))
self.db_connection_pool = db_connection_pool

self.schema_name = ""
self.table_name = ""
self.logger = logging.getLogger(get_logger_name(self.__class__))

while True:
self.conn = self.db_connection_pool.getconn()
Expand All @@ -92,10 +64,6 @@ def __init__(self, db_connection_pool):
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY, self.conn)
self.conn.set_client_encoding('UTF8')

self.srid = geom.Geometry.DefaultSrid
self.tz = None
self.set_timezone(Base.DefaultTimezone)

cur = None
try:
cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
Expand Down Expand Up @@ -124,6 +92,80 @@ def is_connected(self):
else:
return False

def set_timezone(self, tz):
with self.conn.cursor() as cur:
cur.execute('SET TIME ZONE \'%s\'' % str(tz))

def commit(self):
""" commit pending database transaction """
self.conn.commit()

def rollback(self):
""" rollback pending database transaction """
self.conn.rollback()

def execute(self, sql_statement, parameters=None, factory_method=None, **factory_method_args):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
cursor.execute(sql_statement, parameters)
if factory_method:
method = factory_method(cursor, **factory_method_args)
return method

def execute_single(self, sql_statement, parameters=None, factory_method=None, **factory_method_args):
def single_cursor_factory(cursor):
if cursor.rowcount == 1:
return factory_method(cursor.fetchone(), **factory_method_args)

return self.execute(sql_statement, parameters, single_cursor_factory)

def execute_many(self, sql_statement, parameters=None, factory_method=None, **factory_method_args):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
cursor.execute(sql_statement, parameters)
if factory_method:
for value in cursor:
yield factory_method(value, **factory_method_args)


class Base:
"""
abstract base class for database access objects

creation of database

as user postgres:

createuser -i -D -R -S -W -E -P blitzortung
createdb -E utf8 -O blitzortung blitzortung
createlang plpgsql blitzortung
psql -f /usr/share/postgresql/10/contrib/postgis-2.4/postgis.sql -d blitzortung
psql -f /usr/share/postgresql/12/contrib/postgis-3.0/postgis.sql -d blitzortung
psql -f /usr/share/postgresql/10/contrib/postgis-2.4/spatial_ref_sys.sql -d blitzortung
psql -f /usr/share/postgresql/12/contrib/postgis-3.0/spatial_ref_sys.sql -d blitzortung

psql blitzortung

GRANT SELECT ON spatial_ref_sys TO blitzortung;
GRANT SELECT ON geometry_columns TO blitzortung;
GRANT INSERT, DELETE ON geometry_columns TO blitzortung;
CREATE EXTENSION "btree_gist";

"""
__metaclass__ = ABCMeta

DefaultTimezone = datetime.timezone.utc

def __init__(self, db_wrapper: DbWrapper):
self.db_wrapper = db_wrapper

self.logger = logging.getLogger(get_logger_name(self.__class__))

self.schema_name = ""
self.table_name = ""

self.srid = geom.Geometry.DefaultSrid
self.tz = None
self.set_timezone(Base.DefaultTimezone)

@property
def full_table_name(self):
if self.schema_name:
Expand All @@ -142,8 +184,7 @@ def get_timezone(self):

def set_timezone(self, tz):
self.tz = tz
with self.conn.cursor() as cur:
cur.execute('SET TIME ZONE \'%s\'' % str(self.tz))
self.db_wrapper.set_timezone(tz)

def fix_timezone(self, timestamp):
return timestamp.astimezone(self.tz) if timestamp else None
Expand All @@ -157,11 +198,11 @@ def from_timezone_to_bare_utc(time_with_tz):

def commit(self):
""" commit pending database transaction """
self.conn.commit()
self.db_wrapper.commit()

def rollback(self):
""" rollback pending database transaction """
self.conn.rollback()
self.db_wrapper.rollback()

@abstractmethod
def insert(self, *args):
Expand All @@ -172,25 +213,13 @@ def select(self, **kwargs):
pass

def execute(self, sql_statement, parameters=None, factory_method=None, **factory_method_args):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
cursor.execute(sql_statement, parameters)
if factory_method:
method = factory_method(cursor, **factory_method_args)
return method
return self.db_wrapper.execute(sql_statement, parameters, factory_method, **factory_method_args)

def execute_single(self, sql_statement, parameters=None, factory_method=None, **factory_method_args):
def single_cursor_factory(cursor):
if cursor.rowcount == 1:
return factory_method(cursor.fetchone(), **factory_method_args)

return self.execute(sql_statement, parameters, single_cursor_factory)
return self.db_wrapper.execute_single(sql_statement, parameters, factory_method, **factory_method_args)

def execute_many(self, sql_statement, parameters=None, factory_method=None, **factory_method_args):
with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cursor:
cursor.execute(sql_statement, parameters)
if factory_method:
for value in cursor:
yield factory_method(value, **factory_method_args)
return self.db_wrapper.execute_many(sql_statement, parameters, factory_method, **factory_method_args)


class Strike(Base):
Expand All @@ -210,10 +239,12 @@ class Strike(Base):
ALTER TABLE strikes ADD COLUMN stationcount SMALLINT;

CREATE INDEX strikes_timestamp ON strikes USING btree("timestamp");
CREATE INDEX strikes_timestamp_geog ON strikes USING gist("timestamp", geog);

not really required in standard mode:
CREATE INDEX strikes_region_timestamp_nanoseconds ON strikes USING btree(region, "timestamp", nanoseconds);
CREATE INDEX strikes_id_timestamp ON strikes USING btree(id, "timestamp");
CREATE INDEX strikes_geog ON strikes USING gist(geog);
CREATE INDEX strikes_timestamp_geog ON strikes USING gist("timestamp", geog);
CREATE INDEX strikes_id_timestamp_geog ON strikes USING gist(id, "timestamp", geog);

empty the table with the following commands:
Expand Down
37 changes: 37 additions & 0 deletions tests/db/test_db_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
except ImportError:
from backports.zoneinfo import ZoneInfo

from blitzortung import builder

try:
import psycopg2
except ImportError as e:
from blitzortung.db import create_psycopg2_dummy

psycopg2 = create_psycopg2_dummy()

from blitzortung.db import query_builder, mapper

from assertpy import assert_that
from mock import Mock, call

Expand Down Expand Up @@ -140,3 +144,36 @@ def test_rollback(self):
self.base.rollback()

self.connection.rollback.assert_called_once_with()


class StrikeForTest(blitzortung.db.table.Strike):
def __init__(self, db_connection_pool):
super().__init__(db_connection_pool, query_builder.Strike(), mapper.Strike(builder.Strike()))

def create_object_instance(self, result):
return result

def insert(self, *args):
return args

def select(self, *args):
return args


class TestStrike:
def setup_method(self):
self.connection_pool = Mock()
self.connection = self.connection_pool.getconn()
self.cursor = self.connection.cursor()

psycopg2.extensions = Mock()

self.cursor.__enter__ = Mock(return_value=self.cursor)
self.cursor.__exit__ = Mock(return_value=False)

self.base = StrikeForTest(self.connection_pool)

def test_latest_time(self):
result = self.base.get_latest_time(2)

assert_that(result).is_equal_to("asdf")