Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic support for prepared statements in postgres, sqlite3 and mysql. #99

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/us/manual.html
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,10 @@ <h4>Methods</h4>
the operation could not be performed or when it is not implemented.</dd>


<dt><a name="conn_execute"></a><strong><code>conn:execute(statement)</code></strong></dt>
<dd>Executes the given SQL <code>statement</code>.<br/>
<dt><a name="conn_execute"></a><strong><code>conn:execute(statement[,...])</code></strong></dt>
<dd>Executes the given SQL <code>statement</code>. As in traditional prepared statements,
additional parameters can be used to avoid SQL injections. Although this is only
supported by sqlite3, postgres and mysql drivers.<br/>
Returns: a <a href="#cursor_object">cursor object</a>
if there are results, or the number of rows affected by the command otherwise.</dd>

Expand Down
213 changes: 169 additions & 44 deletions src/ls_mysql.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <stdlib.h>
#include <string.h>
#include <ctype.h>
#include <stdbool.h>

#ifdef WIN32
#include <winsock2.h>
Expand Down Expand Up @@ -71,13 +72,25 @@ typedef struct {
} conn_data;

typedef struct {
short closed;
int conn; /* reference to connection */
int numcols; /* number of columns */
int colnames, coltypes; /* reference to column information tables */
MYSQL_RES *my_res;
short closed;
int conn; /* reference to connection */
int numcols; /* number of columns */
int colnames, coltypes; /* reference to column information tables */
MYSQL_RES *my_res;
MYSQL_STMT *stmt;
MYSQL_BIND *params; /* bound to result columns */
unsigned long *real_lengths; /* params[i].length will point to these real_lengths */
bool *nulls; /* buffer for is_null */
bool *errors; /* buffer for error */
} cur_data;

typedef union {
double number;
size_t size;
long long int longlong;
char c;
} column_data;

LUASQL_API int luaopen_luasql_mysql (lua_State *L);


Expand Down Expand Up @@ -117,11 +130,22 @@ static cur_data *getcursor (lua_State *L) {
/*
** Push the value of #i field of #tuple row.
*/
static void pushvalue (lua_State *L, void *row, long int len) {
if (row == NULL)
lua_pushnil (L);
else
lua_pushlstring (L, row, len);
static void pushvalue (lua_State *L, cur_data *cur, int i) {
if (cur->nulls[i]) {
lua_pushnil(L);
cur->nulls[i] = 0;
} else {
/* error flags are set whenever lengths differ, but we resize only when real lengths are bigger */
if (cur->errors[i]) {
if (cur->real_lengths[i] > cur->params[i].buffer_length) {
cur->params[i].buffer = realloc(cur->params[i].buffer, cur->real_lengths[i]);
cur->params[i].buffer_length = cur->real_lengths[i];
}
mysql_stmt_fetch_column(cur->stmt, &cur->params[i], i, 0);
cur->errors[i] = 0;
}
lua_pushlstring(L, cur->params[i].buffer, cur->real_lengths[i]);
}
}


Expand Down Expand Up @@ -185,38 +209,40 @@ static void create_colinfo (lua_State *L, cur_data *cur) {
** Closes the cursos and nullify all structure fields.
*/
static void cur_nullify (lua_State *L, cur_data *cur) {
int i;
/* Nullify structure fields. */
cur->closed = 1;
mysql_free_result(cur->my_res);
mysql_stmt_close(cur->stmt);
for (i = 0; i < cur->numcols; i++) {
free(cur->params[i].buffer);
cur->params[i].buffer = NULL;
}
luaL_unref (L, LUA_REGISTRYINDEX, cur->conn);
luaL_unref (L, LUA_REGISTRYINDEX, cur->colnames);
luaL_unref (L, LUA_REGISTRYINDEX, cur->coltypes);
}


/*
** Get another row of the given cursor.
*/
static int cur_fetch (lua_State *L) {
cur_data *cur = getcursor (L);
MYSQL_RES *res = cur->my_res;
unsigned long *lengths;
MYSQL_ROW row = mysql_fetch_row(res);
if (row == NULL) {
cur_nullify (L, cur);
int r = mysql_stmt_fetch(cur->stmt);
if (r && r != MYSQL_DATA_TRUNCATED) {
cur_nullify(L, cur);
lua_pushnil(L); /* no more results */
return 1;
}
lengths = mysql_fetch_lengths(res);

if (lua_istable (L, 2)) {
const char *opts = luaL_optstring (L, 3, "n");
if (strchr (opts, 'n') != NULL) {
/* Copy values to numerical indices */
int i;
for (i = 0; i < cur->numcols; i++) {
pushvalue (L, row[i], lengths[i]);
lua_rawseti (L, 2, i+1);
pushvalue(L, cur, i);
lua_rawseti(L, 2, i+1);
}
}
if (strchr (opts, 'a') != NULL) {
Expand All @@ -231,7 +257,7 @@ static int cur_fetch (lua_State *L) {
lua_rawgeti(L, -1, i+1); /* push the field name */

/* Actually push the value */
pushvalue (L, row[i], lengths[i]);
pushvalue (L, cur, i);
lua_rawset (L, 2);
}
/* lua_pop(L, 1); Pops colnames table. Not needed */
Expand All @@ -243,7 +269,7 @@ static int cur_fetch (lua_State *L) {
int i;
luaL_checkstack (L, cur->numcols, LUASQL_PREFIX"too many columns");
for (i = 0; i < cur->numcols; i++)
pushvalue (L, row[i], lengths[i]);
pushvalue (L, cur, i);
return cur->numcols; /* return #numcols values */
}
}
Expand Down Expand Up @@ -317,17 +343,20 @@ static int cur_getcoltypes (lua_State *L) {
** Push the number of rows.
*/
static int cur_numrows (lua_State *L) {
lua_pushinteger (L, (lua_Number)mysql_num_rows (getcursor(L)->my_res));
lua_pushinteger (L, (lua_Number)mysql_stmt_num_rows (getcursor(L)->stmt));
return 1;
}


/*
** Create a new Cursor object and push it on top of the stack.
*/
static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) {
cur_data *cur = (cur_data *)lua_newuserdata(L, sizeof(cur_data));
static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *result, int cols) {
int i;
size_t memsize = sizeof (cur_data) + cols * (sizeof (MYSQL_BIND) + sizeof (unsigned long) + 2 * sizeof (bool));
cur_data *cur = (cur_data *)lua_newuserdata(L, memsize);
luasql_setmeta (L, LUASQL_CURSOR_MYSQL);
memset(cur, 0, memsize);

/* fill in structure */
cur->closed = 0;
Expand All @@ -336,8 +365,28 @@ static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) {
cur->colnames = LUA_NOREF;
cur->coltypes = LUA_NOREF;
cur->my_res = result;
cur->stmt = stmt;
lua_pushvalue (L, conn);
cur->conn = luaL_ref (L, LUA_REGISTRYINDEX);
cur->params = (MYSQL_BIND *)(sizeof (cur_data) + (char *)cur); /* after cur */
cur->real_lengths = (unsigned long *)(cols * sizeof (MYSQL_BIND) + (char *)cur->params);
cur->nulls = (bool*)(cols * sizeof (unsigned long) + (char *)cur->real_lengths);
cur->errors = (bool*)(cols * sizeof (bool) + (char *)cur->nulls);
for (i = 0; i < cols; i++) {
cur->params[i].buffer_type = MYSQL_TYPE_STRING;
cur->params[i].buffer = malloc(0);
cur->params[i].buffer_length = 0;
cur->params[i].length = &cur->real_lengths[i];
/* old versions use my_bool, newer use bool. There's no simple way to detect it */
cur->params[i].is_null = (void*)&cur->nulls[i];
cur->params[i].error = (void*)&cur->errors[i];
}

if (mysql_stmt_bind_result(cur->stmt, cur->params)) {
int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt));
cur_nullify(L, cur);
return n;
}

return 1;
}
Expand Down Expand Up @@ -417,27 +466,103 @@ static int conn_execute (lua_State *L) {
conn_data *conn = getconnection (L);
size_t st_len;
const char *statement = luaL_checklstring (L, 2, &st_len);
if (mysql_real_query(conn->my_conn, statement, st_len))
/* error executing query */
return luasql_failmsg(L, "error executing query. MySQL: ", mysql_error(conn->my_conn));
else
{
MYSQL_RES *res = mysql_store_result(conn->my_conn);
unsigned int num_cols = mysql_field_count(conn->my_conn);

if (res) { /* tuples returned */
return create_cursor (L, 1, res, num_cols);
}
else { /* mysql_use_result() returned nothing; should it have? */
if(num_cols == 0) { /* no tuples returned */
/* query does not return data (it was not a SELECT) */
lua_pushinteger(L, mysql_affected_rows(conn->my_conn));
return 1;
}
else /* mysql_use_result() should have returned data */
return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_error(conn->my_conn));
int i, nparams = lua_gettop(L);
MYSQL_STMT * stmt;
MYSQL_BIND * params;
column_data * params_data;
MYSQL_RES * res;
unsigned int num_cols;

stmt = mysql_stmt_init(conn->my_conn);
if (stmt == NULL)
return luasql_failmsg(L, "error executing query (stmt_init). MySQL: ", mysql_error(conn->my_conn));
if (mysql_stmt_prepare(stmt, statement, st_len)) {
int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt));
mysql_stmt_close(stmt);
return n;
}
if (nparams - 2 != mysql_stmt_param_count(stmt)) {
mysql_stmt_close(stmt);
return luasql_faildirect(L, "error executing query. Invalid parameter count");
}
params = calloc(sizeof (MYSQL_BIND), nparams - 2);
params_data = calloc(sizeof (column_data), nparams - 2);
for (i = 3; i <= nparams; i++) {
switch (lua_type(L, i)) {
case LUA_TNIL:
params[i-3].buffer_type = MYSQL_TYPE_NULL;
break;
case LUA_TBOOLEAN:
params_data[i-3].c = lua_toboolean(L, i);
params[i-3].buffer_type = MYSQL_TYPE_TINY;
params[i-3].buffer = &params_data[i-3].c;
params[i-3].buffer_length = sizeof (char);
break;
case LUA_TNUMBER:
#ifdef LUA_INT_TYPE
if (lua_isinteger(L, i)) {
params_data[i-3].longlong = lua_tointeger(L, i);
params[i-3].buffer_type = MYSQL_TYPE_LONGLONG;
params[i-3].buffer = &params_data[i-3].longlong;
params[i-3].buffer_length = sizeof (long long int);
break;
}
#endif
params_data[i-3].number = lua_tonumber(L, i);
params[i-3].buffer_type = MYSQL_TYPE_DOUBLE;
params[i-3].buffer = &params_data[i-3].number;
params[i-3].buffer_length = sizeof (double);
break;
case LUA_TSTRING:
params[i-3].buffer_type = MYSQL_TYPE_STRING;
params[i-3].buffer = (char*)lua_tolstring(L, i, &params_data[i-3].size);
params[i-3].buffer_length = params_data[i-3].size;
params[i-3].length = &params_data[i-3].size;
break;
default:
free(params);
free(params_data);
mysql_stmt_close(stmt);
return luasql_faildirect(L, "error executing query. Invalid parameter type");
}
}
if (mysql_stmt_bind_param(stmt, params)) {
int n = luasql_failmsg(L, "error executing query (stmt_bind_param). MySQL: ", mysql_stmt_error(stmt));
free(params);
free(params_data);
mysql_stmt_close(stmt);
return n;
}
if (mysql_stmt_execute(stmt)) {
int n = luasql_failmsg(L, "error executing query (stmt_execute). MySQL: ", mysql_stmt_error(stmt));
free(params);
free(params_data);
mysql_stmt_close(stmt);
return n;
}
free(params);
free(params_data);
if (mysql_stmt_store_result(stmt)) {
int n = luasql_failmsg(L, "error executing query (stmt_store_result). MySQL: ", mysql_stmt_error(stmt));
mysql_stmt_close(stmt);
return n;
}

res = mysql_stmt_result_metadata(stmt);
num_cols = mysql_stmt_field_count(stmt);
if (res) { /* tuples returned */
return create_cursor (L, 1, stmt, res, num_cols);
}

if(num_cols == 0) { /* no tuples returned */
/* query does not return data (it was not a SELECT) */
lua_pushinteger(L, mysql_stmt_affected_rows(stmt));
mysql_stmt_close(stmt);
return 1;
} else { /* mysql_use_result() should have returned data */
mysql_stmt_close(stmt);
return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_stmt_error(stmt));
}
}


Expand Down
15 changes: 14 additions & 1 deletion src/ls_postgres.c
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,20 @@ static int conn_escape (lua_State *L) {
static int conn_execute (lua_State *L) {
conn_data *conn = getconnection (L);
const char *statement = luaL_checkstring (L, 2);
PGresult *res = PQexec(conn->pg_conn, statement);
int nparams = lua_gettop(L);
PGresult *res;
if (nparams > 2) {
int i;
const char ** values = malloc(sizeof (char *) * (nparams - 2));
for (i = 3; i <= nparams; i++)
values[i - 3] = lua_tostring(L, i);
res = PQexecParams(conn->pg_conn, statement, nparams - 2, NULL, values, NULL, NULL, 0);
free(values);
}
else {
/* for multiple statements support */
res = PQexec(conn->pg_conn, statement);
}
if (res && PQresultStatus(res)==PGRES_COMMAND_OK) {
/* no tuples returned */
lua_pushnumber(L, atof(PQcmdTuples(res)));
Expand Down
Loading