commit b7e4613559f0a741935ad375f07f9411c2107bb7 Author: Loren Segal Date: Fri Nov 6 22:07:00 2009 -0500 Add basic encoding aware logic. Convert all data to default external encoding diff --git a/extconf.rb b/extconf.rb index f31633e..9f3f57c 100644 --- a/extconf.rb +++ b/extconf.rb @@ -109,4 +108,6 @@ File.open('error_const.h', 'w') do |f| end end +$CPPFLAGS += " -DRUBY19" if RUBY_VERSION =~ /1.9/ + create_makefile("mysql") diff --git a/mysql.c b/mysql.c index 9c4515e..91570af 100644 --- a/mysql.c +++ b/mysql.c @@ -16,6 +16,32 @@ #define rb_str_set_len(str, length) (RSTRING_LEN(str) = (length)) #endif +#ifdef RUBY19 +#include <ruby/encoding.h> +#define DEFAULT_ENCODING (rb_enc_get(rb_enc_default_external())) +#else +#define DEFAULT_ENCODING NULL +#define rb_enc_str_new(ptr, len, enc) rb_str_new(ptr, len) +#endif + +VALUE +rb_enc_tainted_str_new(const char *ptr, long len) +{ + VALUE str = rb_enc_str_new(ptr, len, DEFAULT_ENCODING); + + OBJ_TAINT(str); + return str; +} + +VALUE +rb_enc_tainted_str_new2(const char *ptr) +{ + VALUE str = rb_enc_str_new(ptr, strlen(ptr), DEFAULT_ENCODING); + + OBJ_TAINT(str); + return str; +} + #ifdef HAVE_MYSQL_H #include #include @@ -180,7 +206,7 @@ static void mysql_raise(MYSQL* m) VALUE e = rb_exc_new2(eMysql, mysql_error(m)); rb_iv_set(e, "errno", INT2FIX(mysql_errno(m))); #if MYSQL_VERSION_ID >= 40101 - rb_iv_set(e, "sqlstate", rb_tainted_str_new2(mysql_sqlstate(m))); + rb_iv_set(e, "sqlstate", rb_enc_tainted_str_new2(mysql_sqlstate(m))); #endif rb_exc_raise(e); } @@ -207,9 +233,9 @@ static VALUE make_field_obj(MYSQL_FIELD* f) if (f == NULL) return Qnil; obj = rb_obj_alloc(cMysqlField); - rb_iv_set(obj, "name", f->name? rb_str_freeze(rb_tainted_str_new2(f->name)): Qnil); - rb_iv_set(obj, "table", f->table? rb_str_freeze(rb_tainted_str_new2(f->table)): Qnil); - rb_iv_set(obj, "def", f->def? rb_str_freeze(rb_tainted_str_new2(f->def)): Qnil); + rb_iv_set(obj, "name", f->name? 
rb_str_freeze(rb_enc_tainted_str_new2(f->name)): Qnil); + rb_iv_set(obj, "table", f->table? rb_str_freeze(rb_enc_tainted_str_new2(f->table)): Qnil); + rb_iv_set(obj, "def", f->def? rb_str_freeze(rb_enc_tainted_str_new2(f->def)): Qnil); rb_iv_set(obj, "type", INT2NUM(f->type)); rb_iv_set(obj, "length", INT2NUM(f->length)); rb_iv_set(obj, "max_length", INT2NUM(f->max_length)); @@ -290,7 +316,7 @@ static VALUE escape_string(VALUE klass, VALUE str) { VALUE ret; Check_Type(str, T_STRING); - ret = rb_str_new(0, (RSTRING_LEN(str))*2+1); + ret = rb_enc_str_new(0, (RSTRING_LEN(str))*2+1, DEFAULT_ENCODING); rb_str_set_len(ret, mysql_escape_string(RSTRING_PTR(ret), RSTRING_PTR(str), RSTRING_LEN(str))); return ret; } @@ -298,7 +324,7 @@ static VALUE escape_string(VALUE klass, VALUE str) /* client_info() */ static VALUE client_info(VALUE klass) { - return rb_tainted_str_new2(mysql_get_client_info()); + return rb_enc_tainted_str_new2(mysql_get_client_info()); } #if MYSQL_VERSION_ID >= 32332 @@ -428,7 +454,7 @@ static VALUE real_escape_string(VALUE obj, VALUE str) MYSQL* m = GetHandler(obj); VALUE ret; Check_Type(str, T_STRING); - ret = rb_str_new(0, (RSTRING_LEN(str))*2+1); + ret = rb_enc_str_new(0, (RSTRING_LEN(str))*2+1, DEFAULT_ENCODING); rb_str_set_len(ret, mysql_real_escape_string(m, RSTRING_PTR(ret), RSTRING_PTR(str), RSTRING_LEN(str))); return ret; } @@ -467,7 +493,7 @@ static VALUE change_user(int argc, VALUE* argv, VALUE obj) /* character_set_name() */ static VALUE character_set_name(VALUE obj) { - return rb_tainted_str_new2(mysql_character_set_name(GetHandler(obj))); + return rb_enc_tainted_str_new2(mysql_character_set_name(GetHandler(obj))); } #endif @@ -532,7 +558,7 @@ static VALUE field_count(VALUE obj) /* host_info() */ static VALUE host_info(VALUE obj) { - return rb_tainted_str_new2(mysql_get_host_info(GetHandler(obj))); + return rb_enc_tainted_str_new2(mysql_get_host_info(GetHandler(obj))); } /* proto_info() */ @@ -544,14 +570,14 @@ static VALUE proto_info(VALUE 
obj) /* server_info() */ static VALUE server_info(VALUE obj) { - return rb_tainted_str_new2(mysql_get_server_info(GetHandler(obj))); + return rb_enc_tainted_str_new2(mysql_get_server_info(GetHandler(obj))); } /* info() */ static VALUE info(VALUE obj) { const char* p = mysql_info(GetHandler(obj)); - return p? rb_tainted_str_new2(p): Qnil; + return p? rb_enc_tainted_str_new2(p): Qnil; } /* insert_id() */ @@ -586,7 +612,7 @@ static VALUE list_dbs(int argc, VALUE* argv, VALUE obj) n = mysql_num_rows(res); ret = rb_ary_new2(n); for (i=0; i", RSTRING_PTR(n)); return s; } @@ -1255,7 +1281,7 @@ static void mysql_stmt_raise(MYSQL_STMT* s) { VALUE e = rb_exc_new2(eMysql, mysql_stmt_error(s)); rb_iv_set(e, "errno", INT2FIX(mysql_stmt_errno(s))); - rb_iv_set(e, "sqlstate", rb_tainted_str_new2(mysql_stmt_sqlstate(s))); + rb_iv_set(e, "sqlstate", rb_enc_tainted_str_new2(mysql_stmt_sqlstate(s))); rb_exc_raise(e); } @@ -1571,7 +1597,7 @@ static VALUE stmt_fetch(VALUE obj) case MYSQL_TYPE_NEWDECIMAL: case MYSQL_TYPE_BIT: #endif - v = rb_tainted_str_new(s->result.bind[i].buffer, s->result.length[i]); + v = rb_enc_tainted_str_new(s->result.bind[i].buffer, s->result.length[i]); break; default: rb_raise(rb_eTypeError, "unknown buffer_type: %d", s->result.bind[i].buffer_type); @@ -1760,7 +1786,7 @@ static VALUE stmt_send_long_data(VALUE obj, VALUE col, VALUE data) static VALUE stmt_sqlstate(VALUE obj) { struct mysql_stmt* s = DATA_PTR(obj); - return rb_tainted_str_new2(mysql_stmt_sqlstate(s->stmt)); + return rb_enc_tainted_str_new2(mysql_stmt_sqlstate(s->stmt)); } /*------------------------------- diff --git a/test.rb b/test.rb index 92151e0..f11939c 100644 --- a/test.rb +++ b/test.rb @@ -1196,7 +1196,9 @@ class TC_MysqlStmt2 < Test::Unit::TestCase @s.execute assert_equal([nil], @s.fetch) assert_equal([""], @s.fetch) - assert_equal(["abc"], @s.fetch) + row = @s.fetch + assert_equal(Encoding.default_external, row[0].encoding) if RUBY_VERSION =~ /1.9/ + assert_equal(["abc"], row) 
assert_equal(["def"], @s.fetch) assert_equal(["abc"], @s.fetch) assert_equal(["def"], @s.fetch) @@ -1233,6 +1235,7 @@ class TC_MysqlStmt2 < Test::Unit::TestCase case c when 0 assert_equal([1,"abc",Mysql::Time.new(1970,12,24,23,59,05)], a) + assert_equal(Encoding.default_external, a[1].encoding) if RUBY_VERSION =~ /1.9/ when 1 assert_equal([2,"def",Mysql::Time.new(2112,9,3,12,34,56)], a) when 2