diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index e0de7b124..1a71cf633 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -643,39 +643,26 @@ static VALUE rb_mysql_client_abandon_results(VALUE self) { * Query the database with +sql+, with optional +options+. For the possible * options, see @@default_query_options on the Mysql2::Client class. */ -static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { +static VALUE rb_query(VALUE self, VALUE sql, VALUE current) { #ifndef _WIN32 struct async_query_args async_args; #endif struct nogvl_send_query_args args; - int async = 0; - VALUE opts, current; -#ifdef HAVE_RUBY_ENCODING_H - rb_encoding *conn_enc; -#endif GET_CLIENT(self); REQUIRE_CONNECTED(wrapper); args.mysql = wrapper->client; - current = rb_hash_dup(rb_iv_get(self, "@query_options")); RB_GC_GUARD(current); Check_Type(current, T_HASH); rb_iv_set(self, "@current_query_options", current); - if (rb_scan_args(argc, argv, "11", &args.sql, &opts) == 2) { - rb_funcall(current, intern_merge_bang, 1, opts); - - if (rb_hash_aref(current, sym_async) == Qtrue) { - async = 1; - } - } - - Check_Type(args.sql, T_STRING); + Check_Type(sql, T_STRING); #ifdef HAVE_RUBY_ENCODING_H - conn_enc = rb_to_encoding(wrapper->encoding); /* ensure the string is in the encoding the connection is expecting */ - args.sql = rb_str_export_to_enc(args.sql, conn_enc); + args.sql = rb_str_export_to_enc(sql, rb_to_encoding(wrapper->encoding)); +#else + args.sql = sql; #endif args.sql_ptr = StringValuePtr(args.sql); args.sql_len = RSTRING_LEN(args.sql); @@ -686,15 +673,15 @@ static VALUE rb_mysql_client_query(int argc, VALUE * argv, VALUE self) { #ifndef _WIN32 rb_rescue2(do_send_query, (VALUE)&args, disconnect_and_raise, self, rb_eException, (VALUE)0); - if (!async) { + if (rb_hash_aref(current, sym_async) == Qtrue) { + return Qnil; + } else { async_args.fd = wrapper->client->net.fd; async_args.self = self; rb_rescue2(do_query, (VALUE)&async_args, disconnect_and_raise, self, rb_eException, (VALUE)0); return rb_mysql_client_async_result(self); - } else { - return Qnil; } #else do_send_query(&args); @@ -1262,7 +1249,6 @@ void init_mysql2_client() { rb_define_singleton_method(cMysql2Client, "escape", rb_mysql_client_escape, 1); rb_define_method(cMysql2Client, "close", rb_mysql_client_close, 0); - rb_define_method(cMysql2Client, "query", rb_mysql_client_query, -1); rb_define_method(cMysql2Client, "abandon_results!", rb_mysql_client_abandon_results, 0); rb_define_method(cMysql2Client, "escape", rb_mysql_client_real_escape, 1); rb_define_method(cMysql2Client, "info", rb_mysql_client_info, 0); @@ -1297,6 +1283,7 @@ void init_mysql2_client() { rb_define_private_method(cMysql2Client, "ssl_set", set_ssl_options, 5); rb_define_private_method(cMysql2Client, "initialize_ext", initialize_ext, 0); rb_define_private_method(cMysql2Client, "connect", rb_connect, 7); + rb_define_private_method(cMysql2Client, "_query", rb_query, 2); sym_id = ID2SYM(rb_intern("id")); sym_version = ID2SYM(rb_intern("version")); diff --git a/lib/mysql2/client.rb b/lib/mysql2/client.rb index 99a043f9d..764e410ee 100644 --- a/lib/mysql2/client.rb +++ b/lib/mysql2/client.rb @@ -74,6 +74,18 @@ def self.default_query_options @@default_query_options end + if Thread.respond_to?(:handle_interrupt) + def query(sql, options = {}) + Thread.handle_interrupt(Timeout::ExitException => :never) do + _query(sql, @query_options.merge(options)) + end + end + else + def query(sql, options = {}) + _query(sql, @query_options.merge(options)) + end + end + def query_info info = query_info_string return {} unless info diff --git a/spec/mysql2/client_spec.rb b/spec/mysql2/client_spec.rb index 7a0737c52..fec8b0712 100644 --- a/spec/mysql2/client_spec.rb +++ b/spec/mysql2/client_spec.rb @@ -461,59 +461,40 @@ def connect *args }.should raise_error(Mysql2::Error) end - it "should close the connection when an exception is raised" do - begin - Timeout.timeout(1, Timeout::Error) do - @client.query("SELECT sleep(2)") - end - rescue Timeout::Error - end - lambda { - @client.query("SELECT 1") - }.should raise_error(Mysql2::Error, 'closed MySQL connection') + it 'should be impervious to connection-corrupting timeouts ' do + pending('`Thread.handle_interrupt` is not defined') unless Thread.respond_to?(:handle_interrupt) + # attempt to break the connection + expect { Timeout.timeout(0.1) { @client.query('SELECT SLEEP(1)') } }.to raise_error(Timeout::Error) + + # expect the connection to not be broken + expect { @client.query('SELECT 1') }.to_not raise_error end - it "should handle Timeouts without leaving the connection hanging if reconnect is true" do - client = Mysql2::Client.new(DatabaseCredentials['root'].merge(:reconnect => true)) - begin - Timeout.timeout(1, Timeout::Error) do - client.query("SELECT sleep(2)") - end - rescue Timeout::Error + context 'when a non-standard exception class is raised' do + it "should close the connection when an exception is raised" do + expect { Timeout.timeout(0.1, ArgumentError) { @client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError) + expect { @client.query('SELECT 1') }.to raise_error(Mysql2::Error, 'closed MySQL connection') end - lambda { - client.query("SELECT 1") - }.should_not raise_error(Mysql2::Error) - end + it "should handle Timeouts without leaving the connection hanging if reconnect is true" do + client = Mysql2::Client.new(DatabaseCredentials['root'].merge(:reconnect => true)) - it "should handle Timeouts without leaving the connection hanging if reconnect is set to true after construction true" do - client = Mysql2::Client.new(DatabaseCredentials['root']) - begin - Timeout.timeout(1, Timeout::Error) do - client.query("SELECT sleep(2)") - end - rescue Timeout::Error + expect { Timeout.timeout(0.1, ArgumentError) { client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError) + expect { client.query('SELECT 1') }.to_not raise_error end - lambda { - client.query("SELECT 1") - }.should raise_error(Mysql2::Error) + it "should handle Timeouts without leaving the connection hanging if reconnect is set to true after construction true" do + client = Mysql2::Client.new(DatabaseCredentials['root']) - client.reconnect = true + expect { Timeout.timeout(0.1, ArgumentError) { client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError) + expect { client.query('SELECT 1') }.to raise_error(Mysql2::Error) - begin - Timeout.timeout(1, Timeout::Error) do - client.query("SELECT sleep(2)") - end - rescue Timeout::Error - end - - lambda { - client.query("SELECT 1") - }.should_not raise_error(Mysql2::Error) + client.reconnect = true + expect { Timeout.timeout(0.1, ArgumentError) { client.query('SELECT SLEEP(1)') } }.to raise_error(ArgumentError) + expect { client.query('SELECT 1') }.to_not raise_error + end end it "threaded queries should be supported" do