Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] SQL Server support in SQL widget #1674

Merged
merged 13 commits into from
Oct 22, 2016
2 changes: 0 additions & 2 deletions Orange/data/sql/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
import Orange.misc
psycopg2 = Orange.misc.import_late_warning("psycopg2")
11 changes: 11 additions & 0 deletions Orange/data/sql/backend/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .base import Backend

try:
from .postgres import Psycopg2Backend
except ImportError:
pass

try:
from .mssql import PymssqlBackend
except ImportError:
pass
238 changes: 238 additions & 0 deletions Orange/data/sql/backend/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
import logging
from contextlib import contextmanager

from Orange.util import Registry

log = logging.getLogger(__name__)


class BackendError(Exception):
pass


class Backend(metaclass=Registry):
"""Base class for SqlTable backends. Implementations should define
all of the methods defined below.
Parameters
----------
connection_params: dict
connection params
"""

display_name = ""

def __init__(self, connection_params):
self.connection_params = connection_params

@classmethod
def available_backends(cls):
"""Return a list of all available backends"""
return cls.registry.values()

# "meta" methods

def list_tables_query(self, schema=None):
"""Return a list of tuples (schema, table_name)
Parameters
----------
schema : Optional[str]
If set, only tables from schema should be listed
Returns
-------
A list of tuples
"""
raise NotImplementedError

def list_tables(self, schema=None):
"""Return a list of tables in database
Parameters
----------
schema : Optional[str]
If set, only tables from given schema will be listed
Returns
-------
A list of TableDesc objects, describing the tables in the database
"""
query = self.list_tables_query(schema)
with self.execute_sql_query(query) as cur:
tables = []
for schema, name in cur.fetchall():
sql = "{}.{}".format(
self.quote_identifier(schema),
self.quote_identifier(name)) if schema else self.quote_identifier(name)
tables.append(TableDesc(name, schema, sql))
return tables

def get_fields(self, table_name):
"""Return a list of field names and metadata in the given table
Parameters
----------
table_name: str
Returns
-------
a list of tuples (field_name, *field_metadata)
both will be passed to create_variable
"""
query = self.create_sql_query(table_name, ["*"], limit=0)
with self.execute_sql_query(query) as cur:
return cur.description

def get_distinct_values(self, field_name, table_name):
"""Return a list of distinct values of field
Parameters
----------
field_name : name of the field
table_name : name of the table or query to search
Returns
-------
List[str] of values
"""
fields = [self.quote_identifier(field_name)]

query = self.create_sql_query(table_name, fields,
group_by=fields, order_by=fields,
limit=21)
with self.execute_sql_query(query) as cur:
values = cur.fetchall()
if len(values) > 20:
return ()
else:
return tuple(str(x[0]) for x in values)

def create_variable(self, field_name, field_metadata,
type_hints, inspect_table=None):
"""Create variable based on field information
Parameters
----------
field_name : str
name do the field
field_metadata : tuple
data to guess field type from
type_hints : Domain
domain with variable templates
inspect_table : Option[str]
name of the table to expect the field values or None
if no inspection is to be performed
Returns
-------
Variable representing the field
"""
raise NotImplementedError

def count_approx(self, query):
"""Return estimated number of rows returned by query.
Parameters
----------
query : str
Returns
-------
Approximate number of rows
"""
raise NotImplementedError

# query related methods

def create_sql_query(
self, table_name, fields, filters=(),
group_by=None, order_by=None, offset=None, limit=None,
use_time_sample=None):
"""Construct an sql query using the provided elements.
Parameters
----------
table_name : str
fields : List[str]
filters : List[str]
group_by: List[str]
order_by: List[str]
offset: int
limit: int
use_time_sample: int
Returns
-------
string containing sql query
"""
raise NotImplementedError

@contextmanager
def execute_sql_query(self, query, params=None):
"""Context manager for execution of sql queries
Usage:
```
with backend.execute_sql_query("SELECT * FROM foo") as cur:
cur.fetch_all()
```
Parameters
----------
query : string
query to be executed
params: tuple
parameters to be passed to the query
Returns
-------
yields a cursor that can be used to access the data
"""
raise NotImplementedError

def quote_identifier(self, name):
"""Quote identifier name so it can be safely used in queries
Parameters
----------
name: str
name of the parameter
Returns
-------
quoted parameter that can be used in sql queries
"""
raise NotImplementedError

def unquote_identifier(self, quoted_name):
"""Remove quotes from identifier name
Used when sql table name is used in where parameter to
query special tables
Parameters
----------
quoted_name : str
Returns
-------
unquoted name
"""
raise NotImplementedError


class TableDesc:
def __init__(self, name, schema, sql):
self.name = name
self.schema = schema
self.sql = sql

def __str__(self):
return self.name

class ToSql:
def __init__(self, sql):
self.sql = sql

def __call__(self):
return self.sql
118 changes: 118 additions & 0 deletions Orange/data/sql/backend/mssql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from contextlib import contextmanager

import pymssql

from Orange.data import StringVariable, TimeVariable, ContinuousVariable, DiscreteVariable
from Orange.data.sql.backend import Backend
from Orange.data.sql.backend.base import ToSql, BackendError


class PymssqlBackend(Backend):
display_name = "SQL Server"

def __init__(self, connection_params):
connection_params["server"] = connection_params.pop("host", None)

for key in list(connection_params):
if connection_params[key] is None:
del connection_params[key]

super().__init__(connection_params)
try:
self.connection = pymssql.connect(**connection_params)
except pymssql.Error as ex:
raise BackendError(str(ex)) from ex

def list_tables_query(self, schema=None):
return """
SELECT [TABLE_SCHEMA], [TABLE_NAME]
FROM information_schema.tables
WHERE TABLE_TYPE='BASE TABLE'
ORDER BY [TABLE_NAME]
"""

def quote_identifier(self, name):
return "[{}]".format(name)

def unquote_identifier(self, quoted_name):
return quoted_name[1:-1]

def create_sql_query(self, table_name, fields, filters=(),
group_by=None, order_by=None, offset=None, limit=None,
use_time_sample=None):
sql = ["SELECT"]
if limit and not offset:
sql.extend(["TOP", str(limit)])
sql.append(', '.join(fields))
sql.extend(["FROM", table_name])
if use_time_sample:
sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
if filters:
sql.extend(["WHERE", " AND ".join(filters)])
if group_by:
sql.extend(["GROUP BY", ", ".join(group_by)])

if offset and not order_by:
order_by = fields[0].split("AS")[1:]

if order_by:
sql.extend(["ORDER BY", ",".join(order_by)])
if offset:
sql.extend(["OFFSET", str(offset), "ROWS"])
if limit:
sql.extend(["FETCH FIRST", str(limit), "ROWS ONLY"])

return " ".join(sql)

@contextmanager
def execute_sql_query(self, query, params=()):
print(query)
try:
with self.connection.cursor() as cur:
cur.execute(query, *params)
yield cur
finally:
self.connection.commit()

def create_variable(self, field_name, field_metadata, type_hints, inspect_table=None):
if field_name in type_hints:
var = type_hints[field_name]
else:
var = self._guess_variable(field_name, field_metadata,
inspect_table)

field_name_q = self.quote_identifier(field_name)
if var.is_continuous:
if isinstance(var, TimeVariable):
var.to_sql = ToSql("DATEDIFF(s, '1970-01-01 00:00:00', {})".format(field_name_q))
else:
var.to_sql = ToSql(field_name_q)
else: # discrete or string
var.to_sql = ToSql(field_name_q)
return var

def _guess_variable(self, field_name, field_metadata, inspect_table):
from pymssql import STRING, NUMBER, DATETIME, DECIMAL

type_code, *rest = field_metadata

if type_code in (NUMBER, DECIMAL):
return ContinuousVariable(field_name)

if type_code == DATETIME:
tv = TimeVariable(field_name)
tv.have_date = True
tv.have_time = True
return tv

if type_code == STRING:
if inspect_table:
values = [] #self._get_distinct_values(field_name, inspect_table)
if values:
return DiscreteVariable(field_name, values)

return StringVariable(field_name)

def count_approx(self, query):
# TODO: Figure out how to do count estimates on mssql
raise NotImplementedError
Loading