-
-
Notifications
You must be signed in to change notification settings - Fork 1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1674 from astaric/refactor-sql
[ENH] SQL Server support in SQL widget
- Loading branch information
Showing
12 changed files
with
683 additions
and
326 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +0,0 @@ | ||
import Orange.misc | ||
psycopg2 = Orange.misc.import_late_warning("psycopg2") | ||
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
from .base import Backend | ||
|
||
try: | ||
from .postgres import Psycopg2Backend | ||
except ImportError: | ||
pass | ||
|
||
try: | ||
from .mssql import PymssqlBackend | ||
except ImportError: | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,238 @@ | ||
import logging | ||
from contextlib import contextmanager | ||
|
||
from Orange.util import Registry | ||
|
||
log = logging.getLogger(__name__) | ||
|
||
|
||
class BackendError(Exception):
    """Error raised by SqlTable backends."""
|
||
|
||
class Backend(metaclass=Registry):
    """Base class for SqlTable backends.

    Implementations should define all of the methods below that raise
    NotImplementedError.

    Parameters
    ----------
    connection_params : dict
        connection params
    """

    display_name = ""

    def __init__(self, connection_params):
        self.connection_params = connection_params

    @classmethod
    def available_backends(cls):
        """Return all registered backends (the values of ``cls.registry``)."""
        return cls.registry.values()

    # "meta" methods

    def list_tables_query(self, schema=None):
        """Return a query that lists tables as (schema, table_name) rows.

        The result is executed by `list_tables`, which unpacks every
        returned row into a schema and a table name.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from schema should be listed

        Returns
        -------
        The query to execute
        """
        raise NotImplementedError

    def list_tables(self, schema=None):
        """Return a list of tables in database.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from given schema will be listed

        Returns
        -------
        A list of TableDesc objects, describing the tables in the database
        """
        query = self.list_tables_query(schema)
        with self.execute_sql_query(query) as cur:
            tables = []
            # Use a distinct loop name so the `schema` argument is not
            # shadowed by the per-row schema.
            for table_schema, name in cur.fetchall():
                quoted_name = self.quote_identifier(name)
                if table_schema:
                    sql = "{}.{}".format(
                        self.quote_identifier(table_schema), quoted_name)
                else:
                    sql = quoted_name
                tables.append(TableDesc(name, table_schema, sql))
            return tables

    def get_fields(self, table_name):
        """Return a list of field names and metadata in the given table.

        The fields are read from the cursor description of a query that
        selects all columns with ``limit=0``.

        Parameters
        ----------
        table_name : str

        Returns
        -------
        a list of tuples (field_name, *field_metadata)
        both will be passed to create_variable
        """
        query = self.create_sql_query(table_name, ["*"], limit=0)
        with self.execute_sql_query(query) as cur:
            return cur.description

    def get_distinct_values(self, field_name, table_name):
        """Return the distinct values of a field.

        Parameters
        ----------
        field_name : name of the field
        table_name : name of the table or query to search

        Returns
        -------
        Tuple[str] of values, or an empty tuple when the field has more
        than 20 distinct values
        """
        fields = [self.quote_identifier(field_name)]

        # Ask for one row more than the cutoff, so that "more than 20
        # values" can be detected without fetching all of them.
        query = self.create_sql_query(table_name, fields,
                                      group_by=fields, order_by=fields,
                                      limit=21)
        with self.execute_sql_query(query) as cur:
            values = cur.fetchall()
        if len(values) > 20:
            return ()
        return tuple(str(row[0]) for row in values)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        """Create variable based on field information.

        Parameters
        ----------
        field_name : str
            name of the field
        field_metadata : tuple
            data to guess field type from
        type_hints : Domain
            domain with variable templates
        inspect_table : Optional[str]
            name of the table to inspect the field values or None
            if no inspection is to be performed

        Returns
        -------
        Variable representing the field
        """
        raise NotImplementedError

    def count_approx(self, query):
        """Return estimated number of rows returned by query.

        Parameters
        ----------
        query : str

        Returns
        -------
        Approximate number of rows
        """
        raise NotImplementedError

    # query related methods

    def create_sql_query(
            self, table_name, fields, filters=(),
            group_by=None, order_by=None, offset=None, limit=None,
            use_time_sample=None):
        """Construct an sql query using the provided elements.

        Parameters
        ----------
        table_name : str
        fields : List[str]
        filters : List[str]
        group_by : List[str]
        order_by : List[str]
        offset : int
        limit : int
        use_time_sample : int

        Returns
        -------
        string containing sql query
        """
        raise NotImplementedError

    @contextmanager
    def execute_sql_query(self, query, params=None):
        """Context manager for execution of sql queries.

        Usage:
            ```
            with backend.execute_sql_query("SELECT * FROM foo") as cur:
                cur.fetchall()
            ```

        Parameters
        ----------
        query : string
            query to be executed
        params : tuple
            parameters to be passed to the query

        Returns
        -------
        yields a cursor that can be used to access the data
        """
        raise NotImplementedError

    def quote_identifier(self, name):
        """Quote identifier name so it can be safely used in queries.

        Parameters
        ----------
        name : str
            name of the parameter

        Returns
        -------
        quoted parameter that can be used in sql queries
        """
        raise NotImplementedError

    def unquote_identifier(self, quoted_name):
        """Remove quotes from identifier name.

        Used when sql table name is used in where parameter to
        query special tables.

        Parameters
        ----------
        quoted_name : str

        Returns
        -------
        unquoted name
        """
        raise NotImplementedError
|
||
|
||
class TableDesc:
    """Description of a single database table.

    Parameters
    ----------
    name : str
        table name
    schema : Optional[str]
        schema the table belongs to
    sql : str
        text used to refer to the table in sql queries
    """

    def __init__(self, name, schema, sql):
        self.name = name
        self.schema = schema
        self.sql = sql

    def __str__(self):
        return self.name

    def __repr__(self):
        return "{}(name={!r}, schema={!r}, sql={!r})".format(
            type(self).__name__, self.name, self.schema, self.sql)
|
||
class ToSql:
    """Callable that returns the sql text it was constructed with.

    Parameters
    ----------
    sql : str
        sql text returned when the instance is called
    """

    def __init__(self, sql):
        self.sql = sql

    def __call__(self):
        return self.sql

    def __repr__(self):
        return "{}({!r})".format(type(self).__name__, self.sql)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
from contextlib import contextmanager | ||
|
||
import pymssql | ||
|
||
from Orange.data import StringVariable, TimeVariable, ContinuousVariable, DiscreteVariable | ||
from Orange.data.sql.backend import Backend | ||
from Orange.data.sql.backend.base import ToSql, BackendError | ||
|
||
|
||
class PymssqlBackend(Backend):
    """SqlTable backend that connects to SQL Server through pymssql."""

    display_name = "SQL Server"

    def __init__(self, connection_params):
        """Open a pymssql connection.

        Parameters
        ----------
        connection_params : dict
            keyword arguments for ``pymssql.connect``; a "host" entry is
            passed on as "server" and entries set to None are dropped.
            The dict itself is not modified.

        Raises
        ------
        BackendError
            if pymssql fails to connect
        """
        # Work on a copy: mutating the caller's dict meant that reusing it
        # for a second backend silently lost the server name.
        params = dict(connection_params)
        host = params.pop("host", None)
        if host is not None:
            # Only override "server" when a host was actually provided, so
            # an explicitly passed "server" is not discarded.
            params["server"] = host
        params = {key: value for key, value in params.items()
                  if value is not None}

        super().__init__(params)
        try:
            self.connection = pymssql.connect(**params)
        except pymssql.Error as ex:
            raise BackendError(str(ex)) from ex

    def list_tables_query(self, schema=None):
        """Return a query listing (schema, table_name) of all base tables.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from this schema are listed
        """
        schema_clause = ""
        if schema:
            # The query is executed without parameters, so the schema is
            # embedded as a string literal with single quotes doubled.
            schema_clause = "AND TABLE_SCHEMA = '{}'".format(
                schema.replace("'", "''"))
        return """
        SELECT [TABLE_SCHEMA], [TABLE_NAME]
          FROM information_schema.tables
         WHERE TABLE_TYPE='BASE TABLE'
               {}
      ORDER BY [TABLE_NAME]
        """.format(schema_clause)

    def quote_identifier(self, name):
        """Wrap name in square brackets, doubling any closing bracket."""
        return "[{}]".format(name.replace("]", "]]"))

    def unquote_identifier(self, quoted_name):
        """Reverse `quote_identifier`: strip brackets, undouble ``]]``."""
        return quoted_name[1:-1].replace("]]", "]")

    @staticmethod
    def _default_order_by(fields):
        """Return an ORDER BY list to use when paging without an order."""
        first = fields[0]
        _, separator, alias = first.rpartition(" AS ")
        if separator:
            # "expression AS alias": order by the alias
            return [alias.strip()]
        if first.strip() == "*":
            # "*" cannot be ordered by; use a constant ordering instead
            return ["(SELECT NULL)"]
        return [first]

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None, offset=None, limit=None,
                         use_time_sample=None):
        """Construct an sql query using the provided elements.

        Parameters
        ----------
        table_name : str
        fields : List[str]
        filters : List[str]
        group_by : List[str]
        order_by : List[str]
        offset : int
        limit : int
            ``limit=0`` produces a query that returns no rows
        use_time_sample : int

        Returns
        -------
        string containing sql query
        """
        sql = ["SELECT"]
        # Without an offset the limit is expressed with TOP. Comparing with
        # None keeps limit=0 ("TOP 0"), which is used to read only the
        # column metadata of a table.
        if limit is not None and not offset:
            sql.extend(["TOP", str(limit)])
        sql.append(', '.join(fields))
        sql.extend(["FROM", table_name])
        if use_time_sample:
            # NOTE(review): system_time looks like a PostgreSQL sampling
            # method -- confirm that SQL Server accepts this clause.
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by:
            sql.extend(["GROUP BY", ", ".join(group_by)])

        # OFFSET is part of the ORDER BY clause, so an ordering has to be
        # provided whenever an offset is requested.
        if offset and not order_by:
            order_by = self._default_order_by(fields)

        if order_by:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset:
            sql.extend(["OFFSET", str(offset), "ROWS"])
            # FETCH is only emitted for a positive limit
            if limit:
                sql.extend(["FETCH FIRST", str(limit), "ROWS ONLY"])

        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=()):
        """Context manager for execution of sql queries.

        Parameters
        ----------
        query : string
            query to be executed
        params : tuple
            parameters to be passed to the query; None or an empty
            tuple executes the query without parameters

        Returns
        -------
        yields a cursor that can be used to access the data
        """
        import logging

        # Debug logging replaces a leftover print of every query.
        logging.getLogger(__name__).debug("Executing query: %s", query)
        if isinstance(params, list):
            params = tuple(params)
        try:
            with self.connection.cursor() as cur:
                if params:
                    # Parameters are passed as one sequence; unpacking them
                    # into positional arguments fails for two or more.
                    cur.execute(query, params)
                else:
                    cur.execute(query)
                yield cur
        finally:
            self.connection.commit()

    def create_variable(self, field_name, field_metadata, type_hints,
                        inspect_table=None):
        """Create variable based on field information.

        A variable found in type_hints is used as is; otherwise the type
        is guessed from field_metadata. The variable's ``to_sql`` is set
        to the sql expression that selects its values.

        Parameters
        ----------
        field_name : str
            name of the field
        field_metadata : tuple
            data to guess field type from
        type_hints : Domain
            domain with variable templates
        inspect_table : Optional[str]
            name of the table to inspect the field values or None
            if no inspection is to be performed

        Returns
        -------
        Variable representing the field
        """
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if var.is_continuous and isinstance(var, TimeVariable):
            # time values are selected as seconds since 1970-01-01
            var.to_sql = ToSql(
                "DATEDIFF(s, '1970-01-01 00:00:00', {})".format(field_name_q))
        else:
            # other continuous, discrete or string
            var.to_sql = ToSql(field_name_q)
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        """Guess the variable type from the type code in field_metadata."""
        type_code, *_ = field_metadata

        if type_code in (pymssql.NUMBER, pymssql.DECIMAL):
            return ContinuousVariable(field_name)

        if type_code == pymssql.DATETIME:
            tv = TimeVariable(field_name)
            tv.have_date = True
            tv.have_time = True
            return tv

        if type_code == pymssql.STRING and inspect_table:
            # TODO: inspection of distinct values is disabled for now, so
            # string fields never become DiscreteVariable.
            values = []
            if values:
                return DiscreteVariable(field_name, values)

        return StringVariable(field_name)

    def count_approx(self, query):
        """Return estimated number of rows returned by query.

        Not implemented for SQL Server yet.
        """
        # TODO: Figure out how to do count estimates on mssql
        raise NotImplementedError
Oops, something went wrong.