Skip to content

Commit

Permalink
Merge pull request #1674 from astaric/refactor-sql
Browse files Browse the repository at this point in the history
[ENH] SQL Server support in SQL widget
  • Loading branch information
astaric authored Oct 22, 2016
2 parents c90e602 + a078fe7 commit 3f9686d
Show file tree
Hide file tree
Showing 12 changed files with 683 additions and 326 deletions.
2 changes: 0 additions & 2 deletions Orange/data/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
import Orange.misc
psycopg2 = Orange.misc.import_late_warning("psycopg2")
11 changes: 11 additions & 0 deletions Orange/data/sql/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .base import Backend

try:
from .postgres import Psycopg2Backend
except ImportError:
pass

try:
from .mssql import PymssqlBackend
except ImportError:
pass
238 changes: 238 additions & 0 deletions Orange/data/sql/backend/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import logging
from contextlib import contextmanager

from Orange.util import Registry

log = logging.getLogger(__name__)


class BackendError(Exception):
    """Error reported by an SqlTable backend, e.g. a failed connection."""


class Backend(metaclass=Registry):
    """Common interface of SqlTable backends.

    Concrete backends are expected to override every method below that
    raises NotImplementedError.

    Parameters
    ----------
    connection_params: dict
        parameters describing the database connection
    """

    display_name = ""

    def __init__(self, connection_params):
        self.connection_params = connection_params

    @classmethod
    def available_backends(cls):
        """Return the values stored in the class registry, i.e. the
        backends that are available"""
        return cls.registry.values()

    # "meta" methods

    def list_tables_query(self, schema=None):
        """Return the query used by `list_tables`

        The query is executed by `list_tables`, which reads every
        resulting row as a (schema, table_name) pair.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from schema should be listed

        Returns
        -------
        A query producing (schema, table_name) rows
        """
        raise NotImplementedError

    def list_tables(self, schema=None):
        """Describe the tables found in the database

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from given schema will be listed

        Returns
        -------
        A list of TableDesc objects, one per row returned by the query
        from `list_tables_query`
        """
        with self.execute_sql_query(self.list_tables_query(schema)) as cursor:
            descriptions = []
            for table_schema, table_name in cursor.fetchall():
                if table_schema:
                    # schema-qualified name: both parts are quoted
                    qualified = "{}.{}".format(
                        self.quote_identifier(table_schema),
                        self.quote_identifier(table_name))
                else:
                    qualified = self.quote_identifier(table_name)
                descriptions.append(
                    TableDesc(table_name, table_schema, qualified))
            return descriptions

    def get_fields(self, table_name):
        """Return the description of the fields in the given table

        A query selecting all fields with limit 0 is executed and the
        description of its cursor is returned.

        Parameters
        ----------
        table_name: str

        Returns
        -------
        a list of tuples (field_name, *field_metadata)
        both will be passed to create_variable
        """
        probe = self.create_sql_query(table_name, ["*"], limit=0)
        with self.execute_sql_query(probe) as cursor:
            return cursor.description

    def get_distinct_values(self, field_name, table_name):
        """Return the distinct values of a field, unless there are
        more than 20 of them

        Parameters
        ----------
        field_name : name of the field
        table_name : name of the table or query to search

        Returns
        -------
        Tuple[str] of values, or an empty tuple when the field has more
        than 20 distinct values
        """
        max_values = 20
        quoted = [self.quote_identifier(field_name)]

        # One row more than the maximum is requested, which is enough to
        # tell whether the field has too many distinct values.
        query = self.create_sql_query(
            table_name, quoted, group_by=quoted, order_by=quoted,
            limit=max_values + 1)
        with self.execute_sql_query(query) as cursor:
            rows = cursor.fetchall()

        if len(rows) > max_values:
            return ()
        return tuple(str(row[0]) for row in rows)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        """Create variable based on field information

        Parameters
        ----------
        field_name : str
            name of the field
        field_metadata : tuple
            data to guess field type from
        type_hints : Domain
            domain with variable templates
        inspect_table : Optional[str]
            name of the table to inspect the field values in, or None
            if no inspection is to be performed

        Returns
        -------
        Variable representing the field
        """
        raise NotImplementedError

    def count_approx(self, query):
        """Return estimated number of rows returned by query.

        Parameters
        ----------
        query : str

        Returns
        -------
        Approximate number of rows
        """
        raise NotImplementedError

    # query related methods

    def create_sql_query(
            self, table_name, fields, filters=(),
            group_by=None, order_by=None, offset=None, limit=None,
            use_time_sample=None):
        """Construct an sql query using the provided elements.

        Parameters
        ----------
        table_name : str
        fields : List[str]
        filters : List[str]
        group_by: List[str]
        order_by: List[str]
        offset: int
        limit: int
        use_time_sample: int

        Returns
        -------
        string containing sql query
        """
        raise NotImplementedError

    @contextmanager
    def execute_sql_query(self, query, params=None):
        """Context manager for execution of sql queries

        Usage:
            ```
            with backend.execute_sql_query("SELECT * FROM foo") as cur:
                cur.fetchall()
            ```

        Parameters
        ----------
        query : string
            query to be executed
        params: tuple
            parameters to be passed to the query

        Returns
        -------
        yields a cursor that can be used to access the data
        """
        raise NotImplementedError

    def quote_identifier(self, name):
        """Quote identifier name so it can be safely used in queries

        Parameters
        ----------
        name: str
            name of the identifier

        Returns
        -------
        quoted identifier that can be used in sql queries
        """
        raise NotImplementedError

    def unquote_identifier(self, quoted_name):
        """Remove quotes from identifier name

        Used when sql table name is used in where parameter to
        query special tables

        Parameters
        ----------
        quoted_name : str

        Returns
        -------
        unquoted name
        """
        raise NotImplementedError


class TableDesc:
    """Description of a single database table.

    Stores the table name, the schema it was listed under and the sql
    expression referring to it; converting to str gives the name.
    """

    def __init__(self, name, schema, sql):
        self.name, self.schema, self.sql = name, schema, sql

    def __str__(self):
        return self.name

class ToSql:
    """Callable that returns the sql expression it was created with."""

    def __init__(self, sql):
        # expression handed back unchanged on every call
        self.sql = sql

    def __call__(self):
        return self.sql
118 changes: 118 additions & 0 deletions Orange/data/sql/backend/mssql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from contextlib import contextmanager

import pymssql

from Orange.data import StringVariable, TimeVariable, ContinuousVariable, DiscreteVariable
from Orange.data.sql.backend import Backend
from Orange.data.sql.backend.base import ToSql, BackendError


class PymssqlBackend(Backend):
    """SqlTable backend that talks to Microsoft SQL Server through pymssql.

    Parameters
    ----------
    connection_params : dict
        keyword arguments for ``pymssql.connect``; a ``host`` entry is
        passed on as ``server`` and entries set to None are left out

    Raises
    ------
    BackendError
        when ``pymssql.connect`` raises ``pymssql.Error``
    """

    display_name = "SQL Server"

    def __init__(self, connection_params):
        # Build a cleaned copy instead of editing the caller's dict in
        # place, so the caller can keep using (or reusing) its own dict.
        params = {key: value
                  for key, value in connection_params.items()
                  if value is not None}
        # pymssql names the host "server". Translate only when a host was
        # actually given, so that an explicit "server" entry is not lost.
        if "host" in params:
            params["server"] = params.pop("host")

        super().__init__(params)
        try:
            self.connection = pymssql.connect(**params)
        except pymssql.Error as ex:
            raise BackendError(str(ex)) from ex

    def list_tables_query(self, schema=None):
        """Return the query listing (schema, table_name) of base tables

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from this schema are listed

        Returns
        -------
        string containing sql query
        """
        conditions = ["TABLE_TYPE='BASE TABLE'"]
        if schema:
            # The schema is embedded as a string literal, hence single
            # quotes inside it have to be doubled.
            conditions.append("TABLE_SCHEMA=N'{}'".format(
                str(schema).replace("'", "''")))
        return """
        SELECT [TABLE_SCHEMA], [TABLE_NAME]
        FROM information_schema.tables
        WHERE {}
        ORDER BY [TABLE_NAME]
        """.format(" AND ".join(conditions))

    def quote_identifier(self, name):
        """Wrap name in square brackets, doubling any ``]`` inside it"""
        return "[{}]".format(str(name).replace("]", "]]"))

    def unquote_identifier(self, quoted_name):
        """Strip the enclosing brackets and undo the doubling of ``]``"""
        return quoted_name[1:-1].replace("]]", "]")

    @staticmethod
    def _default_order_by(field):
        """Return an ORDER BY expression derived from a select field"""
        _, separator, alias = field.rpartition(" AS ")
        alias = alias.strip()
        if separator and alias:
            bracketed = alias.startswith("[") and alias.endswith("]")
            # A parenthesis in an unbracketed "alias" means the AS was
            # part of an expression such as CAST(x AS int), not an alias.
            if bracketed or not set("()") & set(alias):
                return alias
        if field.strip() == "*":
            # "*" can not be ordered by; use a constant ordering instead
            return "(SELECT NULL)"
        return field

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None, offset=None, limit=None,
                         use_time_sample=None):
        """Construct a T-SQL query using the provided elements.

        Without an offset the limit is expressed with TOP; with an offset
        the OFFSET ... FETCH clause is used instead.

        Parameters
        ----------
        table_name : str
        fields : List[str]
        filters : List[str]
        group_by: List[str]
        order_by: List[str]
        offset: int
        limit: int
            limit=0 (used by `get_fields`) gives a query with no rows
        use_time_sample: int

        Returns
        -------
        string containing sql query
        """
        sql = ["SELECT"]
        # Compare with None: limit=0 is a legitimate request for no rows
        # and must not be treated as "no limit".
        if limit is not None and not offset:
            sql.extend(["TOP", str(limit)])
        sql.append(', '.join(fields))
        sql.extend(["FROM", table_name])
        if use_time_sample:
            # NOTE(review): system_time(...) looks like a PostgreSQL
            # sampling method - confirm that SQL Server accepts it.
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by:
            sql.extend(["GROUP BY", ", ".join(group_by)])

        if offset and not order_by:
            # OFFSET is only accepted after an ORDER BY clause, so one is
            # derived from the first field when the caller gave none.
            order_by = [self._default_order_by(fields[0])]

        if order_by:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset:
            sql.extend(["OFFSET", str(offset), "ROWS"])
            if limit:
                sql.extend(["FETCH FIRST", str(limit), "ROWS ONLY"])

        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=()):
        """Context manager for execution of sql queries

        The connection is committed when the block is left, whether or
        not it raised.

        Parameters
        ----------
        query : string
            query to be executed
        params: tuple
            parameters to be passed to the query; may be empty or None

        Returns
        -------
        yields a cursor that can be used to access the data
        """
        import logging

        logging.getLogger(__name__).debug("Executing query: %s", query)
        if isinstance(params, list):
            params = tuple(params)
        try:
            with self.connection.cursor() as cur:
                if params:
                    # execute takes all parameters as ONE argument;
                    # unpacking them fails for two or more parameters
                    cur.execute(query, params)
                else:
                    cur.execute(query)
                yield cur
        finally:
            self.connection.commit()

    def create_variable(self, field_name, field_metadata, type_hints,
                        inspect_table=None):
        """Create variable based on field information

        The variable is taken from type_hints when field_name is in it
        and guessed from field_metadata otherwise. Its ``to_sql`` is set
        to the quoted field name; for time variables to the number of
        seconds between 1970-01-01 00:00:00 and the field.

        Parameters
        ----------
        field_name : str
        field_metadata : tuple
            data to guess field type from; the first item is the type code
        type_hints : Domain
            domain with variable templates
        inspect_table : Optional[str]
            passed on to the type guessing

        Returns
        -------
        Variable representing the field
        """
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if var.is_continuous and isinstance(var, TimeVariable):
            var.to_sql = ToSql(
                "DATEDIFF(s, '1970-01-01 00:00:00', {})".format(field_name_q))
        else:  # other continuous, discrete or string
            var.to_sql = ToSql(field_name_q)
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        """Choose the variable type from the pymssql type code

        Numbers and decimals become continuous variables, datetimes
        become time variables with both date and time, anything else a
        string variable.
        """
        from pymssql import STRING, NUMBER, DATETIME, DECIMAL

        type_code, *_ = field_metadata

        if type_code in (NUMBER, DECIMAL):
            return ContinuousVariable(field_name)

        if type_code == DATETIME:
            tv = TimeVariable(field_name)
            tv.have_date = True
            tv.have_time = True
            return tv

        if type_code == STRING and inspect_table:
            # TODO: inspection of distinct values is switched off, so no
            # discrete variable is ever created here at the moment.
            values = []
            if values:
                return DiscreteVariable(field_name, values)

        # Strings and all unrecognised type codes end up here, so that
        # create_variable always receives a variable.
        return StringVariable(field_name)

    def count_approx(self, query):
        """Return estimated number of rows returned by query.

        Not available for SQL Server yet; always raises
        NotImplementedError.
        """
        # TODO: Figure out how to do count estimates on mssql
        raise NotImplementedError
Loading

0 comments on commit 3f9686d

Please sign in to comment.