diff --git a/doc/us/manual.html b/doc/us/manual.html index d93c4c9..7754314 100644 --- a/doc/us/manual.html +++ b/doc/us/manual.html @@ -219,8 +219,10 @@

Methods

the operation could not be performed or when it is not implemented. -
conn:execute(statement)
-
Executes the given SQL statement.
+
conn:execute(statement[,...])
+
Executes the given SQL statement. As in traditional prepared statements, + additional parameters can be used to avoid SQL injections. Although this is only + supported by sqlite3, postgres and mysql drivers.
Returns: a cursor object if there are results, or the number of rows affected by the command otherwise.
diff --git a/src/ls_mysql.c b/src/ls_mysql.c index 1e01c3d..48d959d 100644 --- a/src/ls_mysql.c +++ b/src/ls_mysql.c @@ -10,6 +10,7 @@ #include #include #include +#include #ifdef WIN32 #include @@ -71,13 +72,25 @@ typedef struct { } conn_data; typedef struct { - short closed; - int conn; /* reference to connection */ - int numcols; /* number of columns */ - int colnames, coltypes; /* reference to column information tables */ - MYSQL_RES *my_res; + short closed; + int conn; /* reference to connection */ + int numcols; /* number of columns */ + int colnames, coltypes; /* reference to column information tables */ + MYSQL_RES *my_res; + MYSQL_STMT *stmt; + MYSQL_BIND *params; /* bound to result columns */ + unsigned long *real_lengths; /* params[i].length will point to these real_lengths */ + bool *nulls; /* buffer for is_null */ + bool *errors; /* buffer for error */ } cur_data; +typedef union { + double number; + size_t size; + long long int longlong; + char c; +} column_data; + LUASQL_API int luaopen_luasql_mysql (lua_State *L); @@ -117,11 +130,22 @@ static cur_data *getcursor (lua_State *L) { /* ** Push the value of #i field of #tuple row. */ -static void pushvalue (lua_State *L, void *row, long int len) { - if (row == NULL) - lua_pushnil (L); - else - lua_pushlstring (L, row, len); +static void pushvalue (lua_State *L, cur_data *cur, int i) { + if (cur->nulls[i]) { + lua_pushnil(L); + cur->nulls[i] = 0; + } else { + /* error flags are set whenever lengths differ, but we resize only when real lengths are bigger */ + if (cur->errors[i]) { + if (cur->real_lengths[i] > cur->params[i].buffer_length) { + cur->params[i].buffer = realloc(cur->params[i].buffer, cur->real_lengths[i]); + cur->params[i].buffer_length = cur->real_lengths[i]; + } + mysql_stmt_fetch_column(cur->stmt, &cur->params[i], i, 0); + cur->errors[i] = 0; + } + lua_pushlstring(L, cur->params[i].buffer, cur->real_lengths[i]); + } } @@ -185,38 +209,40 @@ static void create_colinfo (lua_State *L, cur_data *cur) { ** Closes the cursos and nullify all structure fields. */ static void cur_nullify (lua_State *L, cur_data *cur) { + int i; /* Nullify structure fields. */ cur->closed = 1; mysql_free_result(cur->my_res); + mysql_stmt_close(cur->stmt); + for (i = 0; i < cur->numcols; i++) { + free(cur->params[i].buffer); + cur->params[i].buffer = NULL; + } luaL_unref (L, LUA_REGISTRYINDEX, cur->conn); luaL_unref (L, LUA_REGISTRYINDEX, cur->colnames); luaL_unref (L, LUA_REGISTRYINDEX, cur->coltypes); } - + /* ** Get another row of the given cursor. */ static int cur_fetch (lua_State *L) { cur_data *cur = getcursor (L); - MYSQL_RES *res = cur->my_res; - unsigned long *lengths; - MYSQL_ROW row = mysql_fetch_row(res); - if (row == NULL) { - cur_nullify (L, cur); + int r = mysql_stmt_fetch(cur->stmt); + if (r && r != MYSQL_DATA_TRUNCATED) { + cur_nullify(L, cur); lua_pushnil(L); /* no more results */ return 1; } - lengths = mysql_fetch_lengths(res); - if (lua_istable (L, 2)) { const char *opts = luaL_optstring (L, 3, "n"); if (strchr (opts, 'n') != NULL) { /* Copy values to numerical indices */ int i; for (i = 0; i < cur->numcols; i++) { - pushvalue (L, row[i], lengths[i]); - lua_rawseti (L, 2, i+1); + pushvalue(L, cur, i); + lua_rawseti(L, 2, i+1); } } if (strchr (opts, 'a') != NULL) { @@ -231,7 +257,7 @@ static int cur_fetch (lua_State *L) { lua_rawgeti(L, -1, i+1); /* push the field name */ /* Actually push the value */ - pushvalue (L, row[i], lengths[i]); + pushvalue (L, cur, i); lua_rawset (L, 2); } /* lua_pop(L, 1); Pops colnames table. Not needed */ @@ -243,7 +269,7 @@ static int cur_fetch (lua_State *L) { int i; luaL_checkstack (L, cur->numcols, LUASQL_PREFIX"too many columns"); for (i = 0; i < cur->numcols; i++) - pushvalue (L, row[i], lengths[i]); + pushvalue (L, cur, i); return cur->numcols; /* return #numcols values */ } } @@ -317,7 +343,7 @@ static int cur_getcoltypes (lua_State *L) { ** Push the number of rows. */ static int cur_numrows (lua_State *L) { - lua_pushinteger (L, (lua_Number)mysql_num_rows (getcursor(L)->my_res)); + lua_pushinteger (L, (lua_Number)mysql_stmt_num_rows (getcursor(L)->stmt)); return 1; } @@ -325,9 +351,12 @@ static int cur_numrows (lua_State *L) { /* ** Create a new Cursor object and push it on top of the stack. */ -static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) { - cur_data *cur = (cur_data *)lua_newuserdata(L, sizeof(cur_data)); +static int create_cursor (lua_State *L, int conn, MYSQL_STMT *stmt, MYSQL_RES *result, int cols) { + int i; + size_t memsize = sizeof (cur_data) + cols * (sizeof (MYSQL_BIND) + sizeof (unsigned long) + 2 * sizeof (bool)); + cur_data *cur = (cur_data *)lua_newuserdata(L, memsize); luasql_setmeta (L, LUASQL_CURSOR_MYSQL); + memset(cur, 0, memsize); /* fill in structure */ cur->closed = 0; @@ -336,8 +365,28 @@ static int create_cursor (lua_State *L, int conn, MYSQL_RES *result, int cols) { cur->colnames = LUA_NOREF; cur->coltypes = LUA_NOREF; cur->my_res = result; + cur->stmt = stmt; lua_pushvalue (L, conn); cur->conn = luaL_ref (L, LUA_REGISTRYINDEX); + cur->params = (MYSQL_BIND *)(sizeof (cur_data) + (char *)cur); /* after cur */ + cur->real_lengths = (unsigned long *)(cols * sizeof (MYSQL_BIND) + (char *)cur->params); + cur->nulls = (bool*)(cols * sizeof (unsigned long) + (char *)cur->real_lengths); + cur->errors = (bool*)(cols * sizeof (bool) + (char *)cur->nulls); + for (i = 0; i < cols; i++) { + cur->params[i].buffer_type = MYSQL_TYPE_STRING; + cur->params[i].buffer = malloc(0); + cur->params[i].buffer_length = 0; + cur->params[i].length = &cur->real_lengths[i]; + /* old versions use my_bool, newer use bool. There's no simple way to detect it */ + cur->params[i].is_null = (void*)&cur->nulls[i]; + cur->params[i].error = (void*)&cur->errors[i]; + } + + if (mysql_stmt_bind_result(cur->stmt, cur->params)) { + int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt)); + cur_nullify(L, cur); + return n; + } return 1; } @@ -417,27 +466,103 @@ static int conn_execute (lua_State *L) { conn_data *conn = getconnection (L); size_t st_len; const char *statement = luaL_checklstring (L, 2, &st_len); - if (mysql_real_query(conn->my_conn, statement, st_len)) - /* error executing query */ - return luasql_failmsg(L, "error executing query. MySQL: ", mysql_error(conn->my_conn)); - else - { - MYSQL_RES *res = mysql_store_result(conn->my_conn); - unsigned int num_cols = mysql_field_count(conn->my_conn); - - if (res) { /* tuples returned */ - return create_cursor (L, 1, res, num_cols); - } - else { /* mysql_use_result() returned nothing; should it have? */ - if(num_cols == 0) { /* no tuples returned */ - /* query does not return data (it was not a SELECT) */ - lua_pushinteger(L, mysql_affected_rows(conn->my_conn)); - return 1; - } - else /* mysql_use_result() should have returned data */ - return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_error(conn->my_conn)); + int i, nparams = lua_gettop(L); + MYSQL_STMT * stmt; + MYSQL_BIND * params; + column_data * params_data; + MYSQL_RES * res; + unsigned int num_cols; + + stmt = mysql_stmt_init(conn->my_conn); + if (stmt == NULL) + return luasql_failmsg(L, "error executing query (stmt_init). MySQL: ", mysql_error(conn->my_conn)); + if (mysql_stmt_prepare(stmt, statement, st_len)) { + int n = luasql_failmsg(L, "error executing query (stmt_prepare). MySQL: ", mysql_stmt_error(stmt)); + mysql_stmt_close(stmt); + return n; + } + if (nparams - 2 != mysql_stmt_param_count(stmt)) { + mysql_stmt_close(stmt); + return luasql_faildirect(L, "error executing query. Invalid parameter count"); + } + params = calloc(sizeof (MYSQL_BIND), nparams - 2); + params_data = calloc(sizeof (column_data), nparams - 2); + for (i = 3; i <= nparams; i++) { + switch (lua_type(L, i)) { + case LUA_TNIL: + params[i-3].buffer_type = MYSQL_TYPE_NULL; + break; + case LUA_TBOOLEAN: + params_data[i-3].c = lua_toboolean(L, i); + params[i-3].buffer_type = MYSQL_TYPE_TINY; + params[i-3].buffer = ¶ms_data[i-3].c; + params[i-3].buffer_length = sizeof (char); + break; + case LUA_TNUMBER: +#ifdef LUA_INT_TYPE + if (lua_isinteger(L, i)) { + params_data[i-3].longlong = lua_tointeger(L, i); + params[i-3].buffer_type = MYSQL_TYPE_LONGLONG; + params[i-3].buffer = ¶ms_data[i-3].longlong; + params[i-3].buffer_length = sizeof (long long int); + break; + } +#endif + params_data[i-3].number = lua_tonumber(L, i); + params[i-3].buffer_type = MYSQL_TYPE_DOUBLE; + params[i-3].buffer = ¶ms_data[i-3].number; + params[i-3].buffer_length = sizeof (double); + break; + case LUA_TSTRING: + params[i-3].buffer_type = MYSQL_TYPE_STRING; + params[i-3].buffer = (char*)lua_tolstring(L, i, ¶ms_data[i-3].size); + params[i-3].buffer_length = params_data[i-3].size; + params[i-3].length = ¶ms_data[i-3].size; + break; + default: + free(params); + free(params_data); + mysql_stmt_close(stmt); + return luasql_faildirect(L, "error executing query. Invalid parameter type"); } } + if (mysql_stmt_bind_param(stmt, params)) { + int n = luasql_failmsg(L, "error executing query (stmt_bind_param). MySQL: ", mysql_stmt_error(stmt)); + free(params); + free(params_data); + mysql_stmt_close(stmt); + return n; + } + if (mysql_stmt_execute(stmt)) { + int n = luasql_failmsg(L, "error executing query (stmt_execute). MySQL: ", mysql_stmt_error(stmt)); + free(params); + free(params_data); + mysql_stmt_close(stmt); + return n; + } + free(params); + free(params_data); + if (mysql_stmt_store_result(stmt)) { + int n = luasql_failmsg(L, "error executing query (stmt_store_result). MySQL: ", mysql_stmt_error(stmt)); + mysql_stmt_close(stmt); + return n; + } + + res = mysql_stmt_result_metadata(stmt); + num_cols = mysql_stmt_field_count(stmt); + if (res) { /* tuples returned */ + return create_cursor (L, 1, stmt, res, num_cols); + } + + if(num_cols == 0) { /* no tuples returned */ + /* query does not return data (it was not a SELECT) */ + lua_pushinteger(L, mysql_stmt_affected_rows(stmt)); + mysql_stmt_close(stmt); + return 1; + } else { /* mysql_use_result() should have returned data */ + mysql_stmt_close(stmt); + return luasql_failmsg(L, "error retrieving result. MySQL: ", mysql_stmt_error(stmt)); + } } diff --git a/src/ls_postgres.c b/src/ls_postgres.c index dd97ea7..85162dc 100644 --- a/src/ls_postgres.c +++ b/src/ls_postgres.c @@ -414,7 +414,20 @@ static int conn_escape (lua_State *L) { static int conn_execute (lua_State *L) { conn_data *conn = getconnection (L); const char *statement = luaL_checkstring (L, 2); - PGresult *res = PQexec(conn->pg_conn, statement); + int nparams = lua_gettop(L); + PGresult *res; + if (nparams > 2) { + int i; + const char ** values = malloc(sizeof (char *) * (nparams - 2)); + for (i = 3; i <= nparams; i++) + values[i - 3] = lua_tostring(L, i); + res = PQexecParams(conn->pg_conn, statement, nparams - 2, NULL, values, NULL, NULL, 0); + free(values); + } + else { + /* for multiple statements support */ + res = PQexec(conn->pg_conn, statement); + } if (res && PQresultStatus(res)==PGRES_COMMAND_OK) { /* no tuples returned */ lua_pushnumber(L, atof(PQcmdTuples(res))); diff --git a/src/ls_sqlite3.c b/src/ls_sqlite3.c index 33672ca..a9b9f2b 100644 --- a/src/ls_sqlite3.c +++ b/src/ls_sqlite3.c @@ -379,6 +379,7 @@ static int conn_escape(lua_State *L) */ static int conn_execute(lua_State *L) { + int i; conn_data *conn = getconnection(L); const char *statement = luaL_checkstring(L, 2); int res; @@ -398,6 +399,52 @@ static int conn_execute(lua_State *L) return luasql_faildirect(L, errmsg); } + /* bind any additional arguments to the statement */ + numcols = lua_gettop(L); + for (i = 3; i <= numcols; i++) + { + const char * buffer; + size_t size; + switch (lua_type(L, i)) { + case LUA_TNIL: + res = sqlite3_bind_null(vm, i - 2); + break; + + case LUA_TBOOLEAN: + case LUA_TNUMBER: +#ifdef LUA_INT_TYPE + if (lua_isnumber(L, i) && !lua_isinteger(L, i)) + { + res = sqlite3_bind_double(vm, i - 2, lua_tonumber(L, i)); + } + else + { + res = sqlite3_bind_int64(vm, i - 2, lua_tointeger(L, i)); + } +#else + res = sqlite3_bind_double(vm, i - 2, lua_tonumber(L, i)); +#endif + break; + + case LUA_TSTRING: + buffer = lua_tolstring(L, i, &size); + res = sqlite3_bind_text(vm, i - 2, buffer, size, SQLITE_TRANSIENT); + break; + + default: + sqlite3_finalize(vm); + return luaL_error(L, LUASQL_PREFIX"Invalid type for execute parameter %d", i - 2); + } + + /* handle errors */ + if (res != SQLITE_OK) + { + errmsg = sqlite3_errmsg(conn->sql_conn); + sqlite3_finalize(vm); + return luaL_error(L, LUASQL_PREFIX"Error binding parameter %d: %s", i - 2, errmsg); + } + } + /* process first result to retrive query information and type */ res = sqlite3_step(vm); numcols = sqlite3_column_count(vm); diff --git a/src/luasql.c b/src/luasql.c index ed1cae7..ecf463a 100644 --- a/src/luasql.c +++ b/src/luasql.c @@ -131,3 +131,35 @@ LUASQL_API void luasql_set_info (lua_State *L) { lua_pushliteral (L, "LuaSQL 2.3.5 (for "LUA_VERSION")"); lua_settable (L, -3); } + +/* +** Execute an SQL statement from a string. +** Return a Cursor object if the statement is a query, otherwise +** return the number of tuples affected by the statement. +** It's nothing more than a C implementation of: +** function conn:execute(sql, ...) +** local stmt, msg = conn:prepare(sql) +** if stmt == nil then return nil, msg end +** return stmt:execute(...) +** end +*/ +LUASQL_API int luasql_conn_execute (lua_State *L) { + // stack: conn sql ... + lua_getfield(L, 1, "prepare"); + lua_pushvalue(L, 1); + lua_pushvalue(L, 2); + // stack: conn sql ... conn.prepare conn sql + lua_call(L, 2, 2); + // stack: conn sql ... stmt msg + if (lua_isnil(L, -2)) + return 2; + lua_pop(L, 1); + lua_replace(L, 2); + // stack: conn stmt ... + lua_getfield(L, 2, "execute"); + lua_replace(L, 1); + // stack: stmt.execute stmt ... + lua_call(L, lua_gettop(L)-1, LUA_MULTRET); + // stack: cur msg? + return lua_gettop(L); +} diff --git a/src/luasql.h b/src/luasql.h index 345bf57..d7fe655 100644 --- a/src/luasql.h +++ b/src/luasql.h @@ -29,6 +29,7 @@ LUASQL_API int luasql_failmsg (lua_State *L, const char *err, const char *m); LUASQL_API int luasql_createmeta (lua_State *L, const char *name, const luaL_Reg *methods); LUASQL_API void luasql_setmeta (lua_State *L, const char *name); LUASQL_API void luasql_set_info (lua_State *L); +LUASQL_API int luasql_conn_execute (lua_State *L); #if !defined LUA_VERSION_NUM || LUA_VERSION_NUM==501 void luaL_setfuncs (lua_State *L, const luaL_Reg *l, int nup);