diff --git a/enginetest/queries/variable_queries.go b/enginetest/queries/variable_queries.go index f530e216e3..de81fea119 100644 --- a/enginetest/queries/variable_queries.go +++ b/enginetest/queries/variable_queries.go @@ -22,12 +22,6 @@ import ( ) var VariableQueries = []ScriptTest{ - { - Name: "use string name for foreign_key checks", - SetUpScript: []string{}, - Query: "select @@GLOBAL.unknown", - ExpectedErr: sql.ErrUnknownSystemVariable, - }, { Name: "use string name for foreign_key checks", SetUpScript: []string{}, @@ -649,6 +643,10 @@ var VariableQueries = []ScriptTest{ } var VariableErrorTests = []QueryErrorTest{ + { + Query: "select @@GLOBAL.unknown", + ExpectedErr: sql.ErrUnknownSystemVariable, + }, { Query: "set @@does_not_exist = 100", ExpectedErr: sql.ErrUnknownSystemVariable, diff --git a/enginetest/server_engine_test.go b/enginetest/server_engine_test.go index c95f60b1c3..eb3263826b 100644 --- a/enginetest/server_engine_test.go +++ b/enginetest/server_engine_test.go @@ -6,6 +6,7 @@ import ( "fmt" "math" "net" + "os" "testing" "github.com/dolthub/vitess/go/mysql" @@ -375,3 +376,93 @@ func TestServerPreparedStatements(t *testing.T) { }) } } + +func TestServerVariables(t *testing.T) { + hostname, herr := os.Hostname() + require.NoError(t, herr) + + port, perr := findEmptyPort() + require.NoError(t, perr) + + s, serr := initTestServer(port) + require.NoError(t, serr) + + go s.Start() + defer s.Close() + + tests := []serverScriptTest{ + { + name: "test that config system variables are properly set", + setup: []string{}, + assertions: []serverScriptTestAssertion{ + { + query: "select @@hostname, @@port, @@max_connections, @@net_read_timeout, @@net_write_timeout", + isExec: false, + expectedRows: []any{ + sql.Row{hostname, port, 1, 1, 1}, + }, + checkRows: func(t *testing.T, rows *gosql.Rows, expectedRows []any) (bool, error) { + var resHostname string + var resPort int + var resMaxConnections int + var resNetReadTimeout int + var resNetWriteTimeout int + var rowNum int + for rows.Next() { + if err := rows.Scan(&resHostname, &resPort, &resMaxConnections, &resNetReadTimeout, &resNetWriteTimeout); err != nil { + return false, err + } + if rowNum >= len(expectedRows) { + return false, nil + } + expectedRow := expectedRows[rowNum].(sql.Row) + require.Equal(t, expectedRow[0].(string), resHostname) + require.Equal(t, expectedRow[1].(int), resPort) + } + return true, nil + }, + }, + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + conn, cerr := dbr.Open("mysql", fmt.Sprintf(noUserFmt, address, port), nil) + require.NoError(t, cerr) + defer conn.Close() + commonSetup := []string{ + "create database test_db;", + "use test_db;", + } + commonTeardown := []string{ + "drop database test_db", + } + for _, stmt := range append(commonSetup, test.setup...) { + _, err := conn.Exec(stmt) + require.NoError(t, err) + } + for _, assertion := range test.assertions { + t.Run(assertion.query, func(t *testing.T) { + if assertion.skip { + t.Skip() + } + rows, err := conn.Query(assertion.query, assertion.args...) + if assertion.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + + ok, err := assertion.checkRows(t, rows, assertion.expectedRows) + require.NoError(t, err) + require.True(t, ok) + }) + } + for _, stmt := range append(commonTeardown) { + _, err := conn.Exec(stmt) + require.NoError(t, err) + } + }) + } +} diff --git a/server/server.go b/server/server.go index 205147c1bb..35097b9cec 100644 --- a/server/server.go +++ b/server/server.go @@ -18,6 +18,7 @@ import ( "errors" "fmt" "net" + "strconv" "time" "github.com/dolthub/vitess/go/mysql" @@ -118,15 +119,47 @@ func portInUse(hostPort string) bool { return false } +func getPortOrDefault(cfg mysql.ListenerConfig) int64 { + // TODO read this value from systemVars + defaultPort := int64(3606) + _, port, err := net.SplitHostPort(cfg.Listener.Addr().String()) + if err != nil { + return defaultPort + } + portInt, err := strconv.ParseInt(port, 10, 64) + if err != nil { + return defaultPort + } + return portInt +} + +func updateSystemVariables(cfg mysql.ListenerConfig) error { + port := getPortOrDefault(cfg) + + // TODO: add the rest of the config variables + err := sql.SystemVariables.AssignValues(map[string]interface{}{ + "port": port, + "max_connections": cfg.MaxConns, + "net_read_timeout": cfg.ConnReadTimeout.Seconds(), + "net_write_timeout": cfg.ConnWriteTimeout.Seconds(), + }) + if err != nil { + return err + } + return nil +} + func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handler mysql.Handler, sel ServerEventListener) (*Server, error) { - if cfg.ConnReadTimeout < 0 { - cfg.ConnReadTimeout = 0 + oneSecond := time.Duration(1) * time.Second + // TODO read default values from systemVars + if cfg.ConnReadTimeout < oneSecond { + cfg.ConnReadTimeout = oneSecond * 30 } - if cfg.ConnWriteTimeout < 0 { - cfg.ConnWriteTimeout = 0 + if cfg.ConnWriteTimeout < oneSecond { + cfg.ConnWriteTimeout = oneSecond * 60 } - if cfg.MaxConnections < 0 { - cfg.MaxConnections = 0 + if cfg.MaxConnections < 1 { + cfg.MaxConnections = 151 } for _, opt := range cfg.Options { @@ -172,6 +205,11 @@ func newServerFromHandler(cfg Config, e *sqle.Engine, sm *SessionManager, handle return nil, err } + err = updateSystemVariables(listenerCfg) + if err != nil { + return nil, err + } + return &Server{ Listener: protocolListener, handler: handler, diff --git a/server/server_config.go b/server/server_config.go index 853d622616..d0af2c26ac 100644 --- a/server/server_config.go +++ b/server/server_config.go @@ -109,14 +109,14 @@ func (c Config) NewConfig() (Config, error) { if !ok { return Config{}, sql.ErrUnknownSystemVariable.New("net_write_timeout") } - c.ConnWriteTimeout = time.Duration(timeout) * time.Millisecond + c.ConnWriteTimeout = time.Duration(timeout) * time.Second } if _, val, ok := sql.SystemVariables.GetGlobal("net_read_timeout"); ok { timeout, ok := val.(int64) if !ok { return Config{}, sql.ErrUnknownSystemVariable.New("net_read_timeout") } - c.ConnReadTimeout = time.Duration(timeout) * time.Millisecond + c.ConnReadTimeout = time.Duration(timeout) * time.Second } return c, nil } diff --git a/server/server_config_test.go b/server/server_config_test.go index c1d7b2b2dc..a2a8319a20 100644 --- a/server/server_config_test.go +++ b/server/server_config_test.go @@ -49,14 +49,14 @@ func TestConfigWithDefaults(t *testing.T) { Type: types.NewSystemIntType("net_write_timeout", 1, 9223372036854775807, false), ConfigField: "ConnWriteTimeout", Default: int64(76), - ExpectedCmp: int64(76000000), + ExpectedCmp: int64(76000000000), }, { Name: "net_read_timeout", Scope: sql.SystemVariableScope_Both, Type: types.NewSystemIntType("net_read_timeout", 1, 9223372036854775807, false), ConfigField: "ConnReadTimeout", Default: int64(67), - ExpectedCmp: int64(67000000), + ExpectedCmp: int64(67000000000), }, } diff --git a/server/server_test.go b/server/server_test.go index 0929f5e6de..3d229789ae 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -18,8 +18,8 @@ import ( gsql "github.com/dolthub/go-mysql-server/sql" ) -// TestSeverCustomListener verifies a caller can provide their own net.Conn implementation for the server to use -func TestSeverCustomListener(t *testing.T) { +// TestServerCustomListener verifies a caller can provide their own net.Conn implementation for the server to use +func TestServerCustomListener(t *testing.T) { dbName := "mydb" // create a net.Conn thats based on a golang buffer buffer := 1024 diff --git a/sql/variables/system_variables.go b/sql/variables/system_variables.go index 188d3bc5ec..6cf1431d9f 100644 --- a/sql/variables/system_variables.go +++ b/sql/variables/system_variables.go @@ -17,6 +17,7 @@ package variables import ( "fmt" "math" + "os" "strings" "sync" "time" @@ -187,6 +188,11 @@ func init() { InitSystemVariables() } +func getHostname() string { + hostname, _ := os.Hostname() + return hostname +} + // systemVars is the internal collection of all MySQL system variables according to the following pages: // https://dev.mysql.com/doc/refman/8.0/en/server-system-variables.html // https://dev.mysql.com/doc/refman/8.0/en/replication-options-gtids.html @@ -1009,7 +1015,7 @@ var systemVars = map[string]sql.SystemVariable{ Dynamic: false, SetVarHintApplies: false, Type: types.NewSystemStringType("hostname"), - Default: "", + Default: getHostname(), }, "immediate_server_version": &sql.MysqlSystemVariable{ Name: "immediate_server_version",