diff --git a/sqlite3.go b/sqlite3.go old mode 100644 new mode 100755 index 3025a500..78864ddf --- a/sqlite3.go +++ b/sqlite3.go @@ -2282,3 +2282,28 @@ func (rc *SQLiteRows) nextSyncLocked(dest []driver.Value) error { } return nil } + +// SQLiteConnector implements driver.Connector for custom connection handling. +type SQLiteConnector struct { + DSN string + DriverInstance *SQLiteDriver +} + +// Connect implements driver.Connector. +func (c *SQLiteConnector) Connect(ctx context.Context) (driver.Conn, error) { + // Context is ignored for now, as SQLiteDriver.Open does not use it. + return c.DriverInstance.Open(c.DSN) +} + +// Driver returns the underlying driver. +func (c *SQLiteConnector) Driver() driver.Driver { + return c.DriverInstance +} + +// NewConnector returns a new SQLiteConnector. +func NewConnector(dsn string) *SQLiteConnector { + return &SQLiteConnector{ + DSN: dsn, + DriverInstance: &SQLiteDriver{}, + } +} diff --git a/sqlite3_test.go b/sqlite3_test.go old mode 100644 new mode 100755 index 94de7386..9a1f6b9e --- a/sqlite3_test.go +++ b/sqlite3_test.go @@ -10,6 +10,7 @@ package sqlite3 import ( "bytes" + "context" "database/sql" "database/sql/driver" "errors" @@ -2586,3 +2587,57 @@ func benchmarkQueryParallel(b *testing.B) { } }) } + +func TestSQLiteConnector(t *testing.T) { + connector := NewConnector(":memory:") + db := sql.OpenDB(connector) + defer db.Close() + + ctx := context.Background() + if err := db.PingContext(ctx); err != nil { + t.Fatalf("PingContext failed: %v", err) + } + + // Test basic query + _, err := db.ExecContext(ctx, "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)") + if err != nil { + t.Fatalf("CREATE TABLE failed: %v", err) + } + + _, err = db.ExecContext(ctx, "INSERT INTO test (name) VALUES (?)", "Alice") + if err != nil { + t.Fatalf("INSERT failed: %v", err) + } + + var name string + err = db.QueryRowContext(ctx, "SELECT name FROM test WHERE id = 1").Scan(&name) + if err != nil { + t.Fatalf("SELECT failed: %v", err) + } + if name != "Alice" { + t.Errorf("Expected name 'Alice', got '%s'", name) + } +} + +func TestSQLiteConnectorContextCancellation(t *testing.T) { + connector := NewConnector(":memory:") + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + conn, err := connector.Connect(ctx) + if err == nil { + conn.Close() + t.Error("Expected error on canceled context, got nil") + } +} + +func TestSQLiteConnectorDriver(t *testing.T) { + driver := &SQLiteDriver{} + connector := &SQLiteConnector{ + DSN: ":memory:", + DriverInstance: driver, + } + if connector.Driver() != driver { + t.Errorf("Driver() returned %v, want %v", connector.Driver(), driver) + } +}