diff --git a/doc/us/manual.html b/doc/us/manual.html
index d93c4c9..7754314 100644
--- a/doc/us/manual.html
+++ b/doc/us/manual.html
@@ -219,8 +219,10 @@
Methods
the operation could not be performed or when it is not implemented.
- conn:execute(statement)
- Executes the given SQL statement
.
+ conn:execute(statement[,...])
+ Executes the given SQL statement
. As in traditional prepared statements,
+ additional parameters can be used to avoid SQL injections. Although this is only
+ supported by sqlite3, postgres and mysql drivers.
Returns: a cursor object
if there are results, or the number of rows affected by the command otherwise.
diff --git a/src/ls_mysql.c b/src/ls_mysql.c
index 1e01c3d..48d959d 100644
--- a/src/ls_mysql.c
+++ b/src/ls_mysql.c
@@ -10,6 +10,7 @@
#include
#include
#include
+#include
#ifdef WIN32
#include
@@ -71,13 +72,25 @@ typedef struct {
} conn_data;
typedef struct {
- short closed;
- int conn; /* reference to connection */
- int numcols; /* number of columns */
- int colnames, coltypes; /* reference to column information tables */
- MYSQL_RES *my_res;
+ short closed;
+ int conn; /* reference to connection */
+ int numcols; /* number of columns */
+ int colnames, coltypes; /* reference to column information tables */
+ MYSQL_RES *my_res;
+ MYSQL_STMT *stmt;
+ MYSQL_BIND *params; /* bound to result columns */
+ unsigned long *real_lengths; /* params[i].length will point to these real_lengths */
+ bool *nulls; /* buffer for is_null */
+ bool *errors; /* buffer for error */
} cur_data;
+typedef union {
+ double number;
+ size_t size;
+ long long int longlong;
+ char c;
+} column_data;
+
LUASQL_API int luaopen_luasql_mysql (lua_State *L);
@@ -117,11 +130,22 @@ static cur_data *getcursor (lua_State *L) {
/*
** Push the value of #i field of #tuple row.
*/
-static void pushvalue (lua_State *L, void *row, long int len) {
- if (row == NULL)
- lua_pushnil (L);
- else
- lua_pushlstring (L, row, len);
+static void pushvalue (lua_State *L, cur_data *cur, int i) {
+ if (cur->nulls[i]) {
+ lua_pushnil(L);
+ cur->nulls[i] = 0;
+ } else {
+ /* error flags are set whenever lengths differ, but we resize only when real lengths are bigger */
+ if (cur->errors[i]) {
+ if (cur->real_lengths[i] > cur->params[i].buffer_length) {
+ cur->params[i].buffer = realloc(cur->params[i].buffer, cur->real_lengths[i]);
+ cur->params[i].buffer_length = cur->real_lengths[i];
+ }
+ mysql_stmt_fetch_column(cur->stmt, &cur->params[i], i, 0);
+ cur->errors[i] = 0;
+ }
+ lua_pushlstring(L, cur->params[i].buffer, cur->real_lengths[i]);
+ }
}
@@ -185,38 +209,40 @@ static void create_colinfo (lua_State *L, cur_data *cur) {
** Closes the cursos and nullify all structure fields.
*/
static void cur_nullify (lua_State *L, cur_data *cur) {
+ int i;
/* Nullify structure fields. */
cur->closed = 1;
mysql_free_result(cur->my_res);
+ mysql_stmt_close(cur->stmt);
+ for (i = 0; i < cur->numcols; i++) {
+ free(cur->params[i].buffer);
+ cur->params[i].buffer = NULL;
+ }
luaL_unref (L, LUA_REGISTRYINDEX, cur->conn);
luaL_unref (L, LUA_REGISTRYINDEX, cur->colnames);
luaL_unref (L, LUA_REGISTRYINDEX, cur->coltypes);
}
-
+
/*
** Get another row of the given cursor.
*/
static int cur_fetch (lua_State *L) {
cur_data *cur = getcursor (L);
- MYSQL_RES *res = cur->my_res;
- unsigned long *lengths;
- MYSQL_ROW row = mysql_fetch_row(res);
- if (row == NULL) {
- cur_nullify (L, cur);
+ int r = mysql_stmt_fetch(cur->stmt);
+ if (r && r != MYSQL_DATA_TRUNCATED) {
+ cur_nullify(L, cur);
lua_pushnil(L); /* no more results */
return 1;
}
- lengths = mysql_fetch_lengths(res);
-
if (lua_istable (L, 2)) {
const char *opts = luaL_optstring (L, 3, "n");
if (strchr (opts, 'n') != NULL) {
/* Copy values to numerical indices */
int i;
for (i = 0; i < cur->numcols; i++) {
- pushvalue (L, row[i], lengths[i]);
- lua_rawseti (L, 2, i+1);
+ pushvalue(L, cur, i);
+ lua_rawseti(L, 2, i+1);
}
}
if (strchr (opts, 'a') != NULL) {
@@ -231,7 +257,7 @@ static int cur_fetch (lua_State *L) {
lua_rawgeti(L, -1, i+1); /* push the field name */
/* Actually push the value */
- pushvalue (L, row[i], lengths[i]);
+ pushvalue (L, cur, i);
lua_rawset (L, 2);
}
/* lua_pop(L, 1); Pops colnames table. Not needed */
@@ -243,7 +269,7 @@ static int cur_fetch (lua_State *L) {
int i;
luaL_checkstack (L, cur->numcols, LUASQL_PREFIX"too many columns");
for (i = 0; i < cur->numcols; i++)
- pushvalue (L, row[i], lengths[i]);
+ pushvalue (L, cur, i);
return cur->numcols; /* return #numcols values */
}
}
@@ -317,7 +343,7 @@ static int cur_getcoltypes (lua_State *L) {
** Push the number of rows.
*/
static int cur_numrows (lua_State *L) {
- lua_pushinteger (L, (lua_Number)mysql_num_rows (getcursor(L)->my_res));
+ lua_pushinteger (L, (lua_Number)mysql_stmt_num_rows (getcursor(L)->stmt));
return 1;
}
@@ -325,9 +351,12 @@ static int cur_numrows (lua_State *L) {
/*
** Create a new Cursor object and push it on top of the stack.
*/
-static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) {
- cur_data *cur = (cur_data *)lua_newuserdata(L, sizeof(cur_data));
+static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *result, int cols) {
+ int i;
+ size_t memsize = sizeof (cur_data) + cols * (sizeof (MYSQL_BIND) + sizeof (unsigned long) + 2 * sizeof (bool));
+ cur_data *cur = (cur_data *)lua_newuserdata(L, memsize);
luasql_setmeta (L, LUASQL_CURSOR_MYSQL);
+ memset(cur, 0, memsize);
/* fill in structure */
cur->closed = 0;
@@ -336,8 +365,28 @@ static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) {
cur->colnames = LUA_NOREF;
cur->coltypes = LUA_NOREF;
cur->my_res = result;
+ cur->stmt = stmt;
lua_pushvalue (L, conn);
cur->conn = luaL_ref (L, LUA_REGISTRYINDEX);
+ cur->params = (MYSQL_BIND *)(sizeof (cur_data) + (char *)cur); /* after cur */
+ cur->real_lengths = (unsigned long *)(cols * sizeof (MYSQL_BIND) + (char *)cur->params);
+ cur->nulls = (bool*)(cols * sizeof (unsigned long) + (char *)cur->real_lengths);
+ cur->errors = (bool*)(cols * sizeof (bool) + (char *)cur->nulls);
+ for (i = 0; i < cols; i++) {
+ cur->params[i].buffer_type = MYSQL_TYPE_STRING;
+ cur->params[i].buffer = malloc(0);
+ cur->params[i].buffer_length = 0;
+ cur->params[i].length = &cur->real_lengths[i];
+ /* old versions use my_bool, newer use bool. There's no simple way to detect it */
+ cur->params[i].is_null = (void*)&cur->nulls[i];
+ cur->params[i].error = (void*)&cur->errors[i];
+ }
+
+ if (mysql_stmt_bind_result(cur->stmt, cur->params)) {
+ int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt));
+ cur_nullify(L, cur);
+ return n;
+ }
return 1;
}
@@ -417,27 +466,103 @@ static int conn_execute (lua_State *L) {
conn_data *conn = getconnection (L);
size_t st_len;
const char *statement = luaL_checklstring (L, 2, &st_len);
- if (mysql_real_query(conn->my_conn, statement, st_len))
- /* error executing query */
- return luasql_failmsg(L, "error executing query. MySQL: ", mysql_error(conn->my_conn));
- else
- {
- MYSQL_RES *res = mysql_store_result(conn->my_conn);
- unsigned int num_cols = mysql_field_count(conn->my_conn);
-
- if (res) { /* tuples returned */
- return create_cursor (L, 1, res, num_cols);
- }
- else { /* mysql_use_result() returned nothing; should it have? */
- if(num_cols == 0) { /* no tuples returned */
- /* query does not return data (it was not a SELECT) */
- lua_pushinteger(L, mysql_affected_rows(conn->my_conn));
- return 1;
- }
- else /* mysql_use_result() should have returned data */
- return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_error(conn->my_conn));
+ int i, nparams = lua_gettop(L);
+ MYSQL_STMT * stmt;
+ MYSQL_BIND * params;
+ column_data * params_data;
+ MYSQL_RES * res;
+ unsigned int num_cols;
+
+ stmt = mysql_stmt_init(conn->my_conn);
+ if (stmt == NULL)
+ return luasql_failmsg(L, "error executing query (stmt_init). MySQL: ", mysql_error(conn->my_conn));
+ if (mysql_stmt_prepare(stmt, statement, st_len)) {
+ int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt));
+ mysql_stmt_close(stmt);
+ return n;
+ }
+ if (nparams - 2 != mysql_stmt_param_count(stmt)) {
+ mysql_stmt_close(stmt);
+ return luasql_faildirect(L, "error executing query. Invalid parameter count");
+ }
+ params = calloc(sizeof (MYSQL_BIND), nparams - 2);
+ params_data = calloc(sizeof (column_data), nparams - 2);
+ for (i = 3; i <= nparams; i++) {
+ switch (lua_type(L, i)) {
+ case LUA_TNIL:
+ params[i-3].buffer_type = MYSQL_TYPE_NULL;
+ break;
+ case LUA_TBOOLEAN:
+ params_data[i-3].c = lua_toboolean(L, i);
+ params[i-3].buffer_type = MYSQL_TYPE_TINY;
+ params[i-3].buffer = ¶ms_data[i-3].c;
+ params[i-3].buffer_length = sizeof (char);
+ break;
+ case LUA_TNUMBER:
+#ifdef LUA_INT_TYPE
+ if (lua_isinteger(L, i)) {
+ params_data[i-3].longlong = lua_tointeger(L, i);
+ params[i-3].buffer_type = MYSQL_TYPE_LONGLONG;
+ params[i-3].buffer = ¶ms_data[i-3].longlong;
+ params[i-3].buffer_length = sizeof (long long int);
+ break;
+ }
+#endif
+ params_data[i-3].number = lua_tonumber(L, i);
+ params[i-3].buffer_type = MYSQL_TYPE_DOUBLE;
+ params[i-3].buffer = ¶ms_data[i-3].number;
+ params[i-3].buffer_length = sizeof (double);
+ break;
+ case LUA_TSTRING:
+ params[i-3].buffer_type = MYSQL_TYPE_STRING;
+ params[i-3].buffer = (char*)lua_tolstring(L, i, ¶ms_data[i-3].size);
+ params[i-3].buffer_length = params_data[i-3].size;
+ params[i-3].length = ¶ms_data[i-3].size;
+ break;
+ default:
+ free(params);
+ free(params_data);
+ mysql_stmt_close(stmt);
+ return luasql_faildirect(L, "error executing query. Invalid parameter type");
}
}
+ if (mysql_stmt_bind_param(stmt, params)) {
+ int n = luasql_failmsg(L, "error executing query (stmt_bind_param). MySQL: ", mysql_stmt_error(stmt));
+ free(params);
+ free(params_data);
+ mysql_stmt_close(stmt);
+ return n;
+ }
+ if (mysql_stmt_execute(stmt)) {
+ int n = luasql_failmsg(L, "error executing query (stmt_execute). MySQL: ", mysql_stmt_error(stmt));
+ free(params);
+ free(params_data);
+ mysql_stmt_close(stmt);
+ return n;
+ }
+ free(params);
+ free(params_data);
+ if (mysql_stmt_store_result(stmt)) {
+ int n = luasql_failmsg(L, "error executing query (stmt_store_result). MySQL: ", mysql_stmt_error(stmt));
+ mysql_stmt_close(stmt);
+ return n;
+ }
+
+ res = mysql_stmt_result_metadata(stmt);
+ num_cols = mysql_stmt_field_count(stmt);
+ if (res) { /* tuples returned */
+ return create_cursor (L, 1, stmt, res, num_cols);
+ }
+
+ if(num_cols == 0) { /* no tuples returned */
+ /* query does not return data (it was not a SELECT) */
+ lua_pushinteger(L, mysql_stmt_affected_rows(stmt));
+ mysql_stmt_close(stmt);
+ return 1;
+ } else { /* mysql_use_result() should have returned data */
+ mysql_stmt_close(stmt);
+ return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_stmt_error(stmt));
+ }
}
diff --git a/src/ls_postgres.c b/src/ls_postgres.c
index dd97ea7..85162dc 100644
--- a/src/ls_postgres.c
+++ b/src/ls_postgres.c
@@ -414,7 +414,20 @@ static int conn_escape (lua_State *L) {
static int conn_execute (lua_State *L) {
conn_data *conn = getconnection (L);
const char *statement = luaL_checkstring (L, 2);
- PGresult *res = PQexec(conn->pg_conn, statement);
+ int nparams = lua_gettop(L);
+ PGresult *res;
+ if (nparams > 2) {
+ int i;
+ const char ** values = malloc(sizeof (char *) * (nparams - 2));
+ for (i = 3; i <= nparams; i++)
+ values[i - 3] = lua_tostring(L, i);
+ res = PQexecParams(conn->pg_conn, statement, nparams - 2, NULL, values, NULL, NULL, 0);
+ free(values);
+ }
+ else {
+ /* for multiple statements support */
+ res = PQexec(conn->pg_conn, statement);
+ }
if (res && PQresultStatus(res)==PGRES_COMMAND_OK) {
/* no tuples returned */
lua_pushnumber(L, atof(PQcmdTuples(res)));
diff --git a/src/ls_sqlite3.c b/src/ls_sqlite3.c
index 33672ca..a9b9f2b 100644
--- a/src/ls_sqlite3.c
+++ b/src/ls_sqlite3.c
@@ -379,6 +379,7 @@ static int conn_escape(lua_State *L)
*/
static int conn_execute(lua_State *L)
{
+ int i;
conn_data *conn = getconnection(L);
const char *statement = luaL_checkstring(L, 2);
int res;
@@ -398,6 +399,52 @@ static int conn_execute(lua_State *L)
return luasql_faildirect(L, errmsg);
}
+ /* bind any additional arguments to the statement */
+ numcols = lua_gettop(L);
+ for (i = 3; i <= numcols; i++)
+ {
+ const char * buffer;
+ size_t size;
+ switch (lua_type(L, i)) {
+ case LUA_TNIL:
+ res = sqlite3_bind_null(vm, i - 2);
+ break;
+
+ case LUA_TBOOLEAN:
+ case LUA_TNUMBER:
+#ifdef LUA_INT_TYPE
+ if (lua_isnumber(L, i) && !lua_isinteger(L, i))
+ {
+ res = sqlite3_bind_double(vm, i - 2, lua_tonumber(L, i));
+ }
+ else
+ {
+ res = sqlite3_bind_int64(vm, i - 2, lua_tointeger(L, i));
+ }
+#else
+ res = sqlite3_bind_double(vm, i - 2, lua_tonumber(L, i));
+#endif
+ break;
+
+ case LUA_TSTRING:
+ buffer = lua_tolstring(L, i, &size);
+ res = sqlite3_bind_text(vm, i - 2, buffer, size, SQLITE_TRANSIENT);
+ break;
+
+ default:
+ sqlite3_finalize(vm);
+ return luaL_error(L, LUASQL_PREFIX"Invalid type for execute parameter %d", i - 2);
+ }
+
+ /* handle errors */
+ if (res != SQLITE_OK)
+ {
+ errmsg = sqlite3_errmsg(conn->sql_conn);
+ sqlite3_finalize(vm);
+ return luaL_error(L, LUASQL_PREFIX"Error binding parameter %d: %s", i - 2, errmsg);
+ }
+ }
+
/* process first result to retrive query information and type */
res = sqlite3_step(vm);
numcols = sqlite3_column_count(vm);
diff --git a/src/luasql.c b/src/luasql.c
index ed1cae7..ecf463a 100644
--- a/src/luasql.c
+++ b/src/luasql.c
@@ -131,3 +131,35 @@ LUASQL_API void luasql_set_info (lua_State *L) {
lua_pushliteral (L, "LuaSQL 2.3.5 (for "LUA_VERSION")");
lua_settable (L, -3);
}
+
+/*
+** Execute an SQL statement from a string.
+** Return a Cursor object if the statement is a query, otherwise
+** return the number of tuples affected by the statement.
+** It's nothing more than a C implementation of:
+** function conn:execute(sql, ...)
+** local stmt, msg = conn:prepare(sql)
+** if stmt == nil then return nil, msg end
+** return stmt:execute(...)
+** end
+*/
+LUASQL_API int luasql_conn_execute (lua_State *L) {
+ // stack: conn sql ...
+ lua_getfield(L, 1, "prepare");
+ lua_pushvalue(L, 1);
+ lua_pushvalue(L, 2);
+ // stack: conn sql ... conn.prepare conn sql
+ lua_call(L, 2, 2);
+ // stack: conn sql ... stmt msg
+ if (lua_isnil(L, -2))
+ return 2;
+ lua_pop(L, 1);
+ lua_replace(L, 2);
+ // stack: conn stmt ...
+ lua_getfield(L, 2, "execute");
+ lua_replace(L, 1);
+ // stack: stmt.execute stmt ...
+ lua_call(L, lua_gettop(L)-1, LUA_MULTRET);
+ // stack: cur msg?
+ return lua_gettop(L);
+}
diff --git a/src/luasql.h b/src/luasql.h
index 345bf57..d7fe655 100644
--- a/src/luasql.h
+++ b/src/luasql.h
@@ -29,6 +29,7 @@ LUASQL_API int luasql_failmsg (lua_State *L, const char *err, const char *m);
LUASQL_API int luasql_createmeta (lua_State *L, const char *name, const luaL_Reg *methods);
LUASQL_API void luasql_setmeta (lua_State *L, const char *name);
LUASQL_API void luasql_set_info (lua_State *L);
+LUASQL_API int luasql_conn_execute (lua_State *L);
#if !defined LUA_VERSION_NUM || LUA_VERSION_NUM==501
void luaL_setfuncs (lua_State *L, const luaL_Reg *l, int nup);