diff --git a/internal/cmd/database/dump.go b/internal/cmd/database/dump.go
index d6aa5a1f..8e40cb36 100644
--- a/internal/cmd/database/dump.go
+++ b/internal/cmd/database/dump.go
@@ -254,13 +254,7 @@ func dump(ch *cmdutil.Helper, cmd *cobra.Command, flags *dumpFlags, args []strin
 
 	if flags.shard != "" {
-		if flags.replica {
-			useCmd := fmt.Sprintf("USE `%s/%s@replica`;", dbName, flags.shard)
-			cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
-		} else if flags.rdonly {
-			useCmd := fmt.Sprintf("USE `%s/%s@rdonly`;", dbName, flags.shard)
-			cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
-		} else {
-			useCmd := fmt.Sprintf("USE `%s/%s`;", dbName, flags.shard)
-			cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
-		}
+		// shardUseCommand picks the @replica/@rdonly suffix itself, so one
+		// call covers the replica, rdonly and default cases.
+		useCmd := shardUseCommand(dbName, flags.shard, flags.replica, flags.rdonly)
+		cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
 	}
@@ -392,3 +386,22 @@ func parseColumnIncludes(columns []string) (map[string]map[string]bool, error) {
 
 	return result, nil
 }
+
+// shardUseCommand builds the USE statement that targets one shard of dbName.
+// It appends "@replica" or "@rdonly" when the matching flag is set; replica
+// takes precedence when both are set.
+func shardUseCommand(dbName string, shard string, replica bool, rdonly bool) string {
+	target := fmt.Sprintf("%s/%s", dbName, shard)
+	if replica {
+		target += "@replica"
+	} else if rdonly {
+		target += "@rdonly"
+	}
+	return fmt.Sprintf("USE %s;", quoteIdentifier(target))
+}
+
+// quoteIdentifier wraps identifier in backticks, doubling any embedded
+// backtick so the value cannot terminate the quoted identifier early.
+func quoteIdentifier(identifier string) string {
+	return "`" + strings.ReplaceAll(identifier, "`", "``") + "`"
+}
diff --git a/internal/cmd/database/dump_test.go b/internal/cmd/database/dump_test.go
index e4ae9079..2bc0ff49 100644
--- a/internal/cmd/database/dump_test.go
+++ b/internal/cmd/database/dump_test.go
@@ -100,3 +100,14 @@ func TestParseColumnIncludes(t *testing.T) {
 		})
 	}
 }
+
+// TestShardUseCommand pins the generated USE statement for each tablet type.
+func TestShardUseCommand(t *testing.T) {
+	c := qt.New(t)
+
+	c.Assert(shardUseCommand("commerce", "-80", false, false), qt.Equals, "USE `commerce/-80`;")
+	c.Assert(shardUseCommand("commerce", "-80", true, false), qt.Equals, "USE `commerce/-80@replica`;")
+	c.Assert(shardUseCommand("commerce", "-80", false, true), qt.Equals, "USE `commerce/-80@rdonly`;")
+	c.Assert(shardUseCommand("commerce", "-80", true, true), qt.Equals, "USE `commerce/-80@replica`;")
+	c.Assert(shardUseCommand("key`space", "sh`ard", false, false), qt.Equals, "USE `key``space/sh``ard`;")
+}
diff --git a/internal/dumper/dumper.go b/internal/dumper/dumper.go
index 273d9eaa..45aeb779 100644
--- a/internal/dumper/dumper.go
+++ b/internal/dumper/dumper.go
@@ -249,7 +249,7 @@ func writeMetaData(outdir string) error {
 }
 
 func (d *Dumper) dumpTableSchema(conn *Connection, database string, table string, views map[string]bool) error {
-	qr, err := conn.Fetch(fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", database, table))
+	qr, err := conn.Fetch(fmt.Sprintf("SHOW CREATE TABLE %s.%s", quoteIdentifier(database), quoteIdentifier(table)))
 	if err != nil {
 		return err
 	}
@@ -293,7 +293,7 @@ func (d *Dumper) dumpTable(ctx context.Context, conn *Connection, database strin
 		return err
 	}
 
-	cursor, err := conn.StreamFetch(fmt.Sprintf("SELECT %s FROM `%s`.`%s` %s", strings.Join(dumpCtx.selfields, ", "), database, table, dumpCtx.where))
+	cursor, err := conn.StreamFetch(fmt.Sprintf("SELECT %s FROM %s.%s %s", strings.Join(dumpCtx.selfields, ", "), quoteIdentifier(database), quoteIdentifier(table), dumpCtx.where))
 	if err != nil {
 		return err
 	}
@@ -387,9 +387,9 @@ func (d *Dumper) tableDumpContext(conn *Connection, table string) (*dumpContext,
 		ctx.fieldNames = append(ctx.fieldNames, name)
 		replacement, ok := d.cfg.Selects[table][name]
 		if ok {
-			ctx.selfields = append(ctx.selfields, fmt.Sprintf("%s AS `%s`", replacement, name))
+			ctx.selfields = append(ctx.selfields, fmt.Sprintf("%s AS %s", replacement, quoteIdentifier(name)))
 		} else {
-			ctx.selfields = append(ctx.selfields, fmt.Sprintf("`%s`", name))
+			ctx.selfields = append(ctx.selfields, quoteIdentifier(name))
 		}
 	}
 
@@ -406,7 +406,7 @@ func (d *Dumper) tableDumpContext(conn *Connection, table string) (*dumpContext,
 }
 
 func (d *Dumper) allTables(conn *Connection, database string) ([]string, error) {
-	qr, err := conn.Fetch(fmt.Sprintf("SHOW TABLES FROM `%s`", database))
+	qr, err := conn.Fetch(fmt.Sprintf("SHOW TABLES FROM %s", quoteIdentifier(database)))
 	if err != nil {
 		return nil, err
 	}
@@ -419,12 +419,12 @@ func (d *Dumper) allTables(conn *Connection, database string) ([]string, error)
 }
 
 func (d *Dumper) allViews(conn *Connection, database string) (map[string]bool, error) {
-	query := `SELECT TABLE_NAME 
-			 FROM information_schema.TABLES 
-			 WHERE TABLE_SCHEMA LIKE '%s' 
-			 AND TABLE_TYPE = 'VIEW'
-			`
-	qr, err := conn.Fetch(fmt.Sprintf(query, database))
+	query := "SELECT TABLE_NAME \n" +
+		"\t\t\t FROM information_schema.TABLES \n" +
+		"\t\t\t WHERE TABLE_SCHEMA LIKE %s \n" +
+		"\t\t\t AND TABLE_TYPE = 'VIEW'\n" +
+		"\t\t\t"
+	qr, err := conn.Fetch(fmt.Sprintf(query, quoteStringLiteral(database)))
 	if err != nil {
 		return nil, err
 	}
@@ -466,7 +466,7 @@ func (d *Dumper) filterDatabases(conn *Connection, filter *regexp.Regexp, invert
 
 // dumpableFieldNames returns a slice that contains valid field names for the dump.
 func (d *Dumper) dumpableFieldNames(conn *Connection, table string) ([]string, error) {
-	qr, err := conn.Fetch(fmt.Sprintf("SHOW FIELDS FROM `%s`", table))
+	qr, err := conn.Fetch(fmt.Sprintf("SHOW FIELDS FROM %s", quoteIdentifier(table)))
 	if err != nil {
 		return nil, err
 	}
@@ -515,6 +515,20 @@ func writeFile(file string, data string) error {
 	return nil
 }
 
+// quoteIdentifier wraps identifier in backticks, doubling any embedded
+// backtick so the value cannot terminate the quoted identifier early.
+func quoteIdentifier(identifier string) string {
+	return "`" + strings.ReplaceAll(identifier, "`", "``") + "`"
+}
+
+// quoteStringLiteral renders s as a single-quoted SQL string literal.
+// Backslashes are doubled as well as single quotes: MySQL treats backslash as
+// an escape character inside string literals unless NO_BACKSLASH_ESCAPES is
+// set, so an unescaped `\` before a quote would let s break out of the literal.
+func quoteStringLiteral(s string) string {
+	return "'" + strings.NewReplacer(`\`, `\\`, "'", "''").Replace(s) + "'"
+}
+
 // escapeBytes used to escape the literal byte.
 // See https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
 // for more information on how to escape string literals in MySQL.
diff --git a/internal/dumper/dumper_test.go b/internal/dumper/dumper_test.go
index 5ae7d2f3..41b03c77 100644
--- a/internal/dumper/dumper_test.go
+++ b/internal/dumper/dumper_test.go
@@ -1574,6 +1574,131 @@ func TestEscapeBytes(t *testing.T) {
 	}
 }
 
+// TestQuoteIdentifier checks that names are wrapped in backticks and that
+// embedded backticks are doubled.
+func TestQuoteIdentifier(t *testing.T) {
+	c := qt.New(t)
+
+	c.Assert(quoteIdentifier("simple"), qt.Equals, "`simple`")
+	c.Assert(quoteIdentifier("customer_report`$probe"), qt.Equals, "`customer_report``$probe`")
+	c.Assert(quoteIdentifier("display`name"), qt.Equals, "`display``name`")
+}
+
+// TestQuoteStringLiteral checks that values are wrapped in single quotes and
+// that embedded single quotes are doubled.
+func TestQuoteStringLiteral(t *testing.T) {
+	c := qt.New(t)
+
+	c.Assert(quoteStringLiteral("simple"), qt.Equals, "'simple'")
+	c.Assert(quoteStringLiteral("test'db"), qt.Equals, "'test''db'")
+}
+
+// TestDumperEscapesDiscoveredIdentifiers runs a dump against a mock server
+// whose database, table and column names contain backticks, then asserts that
+// the issued queries and the written INSERT use the doubled-backtick form.
+func TestDumperEscapesDiscoveredIdentifiers(t *testing.T) {
+	c := qt.New(t)
+
+	log := xlog.NewStdLog(xlog.Level(xlog.INFO))
+	fakedbs := driver.NewTestHandler(log)
+	server, err := driver.MockMysqlServer(log, fakedbs)
+	c.Assert(err, qt.IsNil)
+	c.Cleanup(func() { server.Close() })
+
+	address := server.Addr()
+	database := "test`db"
+	table := "customer_report`$probe"
+	column := "display`name"
+
+	selectResult := &sqltypes.Result{
+		Fields: []*querypb.Field{
+			{Name: column, Type: querypb.Type_VARCHAR},
+		},
+		Rows: [][]sqltypes.Value{
+			{sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("ok"))},
+		},
+	}
+
+	schemaResult := &sqltypes.Result{
+		Fields: []*querypb.Field{
+			{Name: "Table", Type: querypb.Type_VARCHAR},
+			{Name: "Create Table", Type: querypb.Type_VARCHAR},
+		},
+		Rows: [][]sqltypes.Value{
+			{
+				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(table)),
+				sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("CREATE TABLE `customer_report``$probe` (`display``name` varchar(255)) ENGINE=InnoDB")),
+			},
+		},
+	}
+
+	tablesResult := &sqltypes.Result{
+		Fields: []*querypb.Field{
+			{Name: "Tables_in_test`db", Type: querypb.Type_VARCHAR},
+		},
+		Rows: [][]sqltypes.Value{
+			{sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(table))},
+		},
+	}
+
+	viewsResult := &sqltypes.Result{
+		Fields: []*querypb.Field{
+			{Name: "TABLE_NAME", Type: querypb.Type_VARCHAR},
+		},
+		Rows: [][]sqltypes.Value{},
+	}
+
+	fieldsResult := &sqltypes.Result{
+		Fields: []*querypb.Field{
+			{Name: "Field", Type: querypb.Type_VARCHAR},
+			{Name: "Type", Type: querypb.Type_VARCHAR},
+			{Name: "Null", Type: querypb.Type_VARCHAR},
+			{Name: "Key", Type: querypb.Type_VARCHAR},
+			{Name: "Default", Type: querypb.Type_VARCHAR},
+			{Name: "Extra", Type: querypb.Type_VARCHAR},
+		},
+		Rows: [][]sqltypes.Value{
+			testRow(column, ""),
+		},
+	}
+
+	fakedbs.AddQueryPattern("use .*", &sqltypes.Result{})
+	fakedbs.AddQueryPattern("show create table .*", schemaResult)
+	fakedbs.AddQueryPattern("show tables from .*", tablesResult)
+	fakedbs.AddQueryPattern("select table_name .* from information_schema.tables .*", viewsResult)
+	fakedbs.AddQueryPattern("show fields from .*", fieldsResult)
+	fakedbs.AddQueryPattern("select .* from .*", selectResult)
+	fakedbs.AddQueryPattern("set .*", &sqltypes.Result{})
+
+	cfg := &Config{
+		Database:      database,
+		Outdir:        c.TempDir(),
+		User:          "mock",
+		Password:      "mock",
+		Address:       address,
+		ChunksizeInMB: 1,
+		Threads:       16,
+		StmtSize:      10000,
+		IntervalMs:    500,
+		SessionVars:   []string{"SET @@radon_streaming_fetch='ON'"},
+	}
+
+	d, err := NewDumper(cfg)
+	c.Assert(err, qt.IsNil)
+
+	err = d.Run(context.Background())
+	c.Assert(err, qt.IsNil)
+
+	c.Assert(fakedbs.GetQueryCalledNum("SHOW TABLES FROM `test``db`"), qt.Equals, 1)
+	c.Assert(fakedbs.GetQueryCalledNum("SHOW CREATE TABLE `test``db`.`customer_report``$probe`"), qt.Equals, 1)
+	c.Assert(fakedbs.GetQueryCalledNum("SHOW FIELDS FROM `customer_report``$probe`"), qt.Equals, 1)
+	c.Assert(fakedbs.GetQueryCalledNum("SELECT `display``name` FROM `test``db`.`customer_report``$probe` "), qt.Equals, 1)
+
+	dat, err := os.ReadFile(cfg.Outdir + "/" + database + "." + table + ".00001.sql")
+	c.Assert(err, qt.IsNil)
+	c.Assert(string(dat), qt.Contains, "INSERT INTO `customer_report``$probe`(`display``name`) VALUES")
+}
+
 func TestDumperColumnIncludes(t *testing.T) {
 	c := qt.New(t)
 
diff --git a/internal/dumper/loader.go b/internal/dumper/loader.go
index 9d9bf626..bda9dd65 100644
--- a/internal/dumper/loader.go
+++ b/internal/dumper/loader.go
@@ -270,7 +270,7 @@ func (l *Loader) restoreTableSchema(overwrite bool, tables []string, conn *Conne
 		name := strings.TrimSuffix(base, schemaSuffix)
 		db := l.databaseNameFromFilename(name)
 		tbl := strings.Split(name, ".")[1]
-		name = fmt.Sprintf("`%v`.`%v`", db, tbl)
+		name = fmt.Sprintf("%s.%s", quoteIdentifier(db), quoteIdentifier(tbl))
 
 		l.log.Info(
 			"working table",
@@ -278,7 +278,7 @@ func (l *Loader) restoreTableSchema(overwrite bool, tables []string, conn *Conne
 			zap.String("table ", tbl),
 		)
 
-		err := conn.Execute(fmt.Sprintf("USE `%s`", db))
+		err := conn.Execute(fmt.Sprintf("USE %s", quoteIdentifier(db)))
 		if err != nil {
 			return err
 		}
@@ -386,7 +386,7 @@ func (l *Loader) restoreViews(overwrite bool, views []string, conn *Connection)
 		name := strings.TrimSuffix(base, viewSuffix)
 		db := strings.Split(name, ".")[0]
 		view := strings.Split(name, ".")[1]
-		name = fmt.Sprintf("`%v`.`%v`", db, view)
+		name = fmt.Sprintf("%s.%s", quoteIdentifier(db), quoteIdentifier(view))
 
 		l.log.Info(
 			"working view",
@@ -394,7 +394,7 @@ func (l *Loader) restoreViews(overwrite bool, views []string, conn *Connection)
 			zap.String("view ", view),
 		)
 
-		err := conn.Execute(fmt.Sprintf("USE `%s`", db))
+		err := conn.Execute(fmt.Sprintf("USE %s", quoteIdentifier(db)))
 		if err != nil {
 			return err
 		}
@@ -483,7 +483,7 @@ func (l *Loader) restoreTable(ctx context.Context, table string, conn *Connectio
 		zap.Int("thread_conn_id", conn.ID),
 	)
 
-	err := conn.Execute(fmt.Sprintf("USE `%s`", db))
+	err := conn.Execute(fmt.Sprintf("USE %s", quoteIdentifier(db)))
 	if err != nil {
 		return 0, err
 	}
diff --git a/internal/dumper/loader_test.go b/internal/dumper/loader_test.go
index 4182434f..2cbafbf5 100644
--- a/internal/dumper/loader_test.go
+++ b/internal/dumper/loader_test.go
@@ -231,3 +231,143 @@ func TestRestoreTableSchema_DropTableCalledOnce(t *testing.T) {
 	err = loader.restoreTableSchema(cfg.OverwriteTables, []string{schemaFile}, conn)
 	c.Assert(err, qt.IsNil, qt.Commentf("DROP TABLE should be called exactly once. If called multiple times, this test will fail."))
 }
+
+// TestRestoreTableSchema_EscapesGeneratedIdentifiers restores a schema file
+// whose database and table names contain backticks; the mock server is given
+// only the doubled-backtick forms of the USE and DROP TABLE statements.
+func TestRestoreTableSchema_EscapesGeneratedIdentifiers(t *testing.T) {
+	c := qt.New(t)
+
+	log := xlog.NewStdLog(xlog.Level(xlog.ERROR))
+	fakedbs := driver.NewTestHandler(log)
+	server, err := driver.MockMysqlServer(log, fakedbs)
+	c.Assert(err, qt.IsNil)
+	defer server.Close()
+
+	address := server.Addr()
+	fakedbs.AddQuery("USE `test``db`", &sqltypes.Result{})
+	fakedbs.AddQuery("SET FOREIGN_KEY_CHECKS=0", &sqltypes.Result{})
+	fakedbs.AddQuery("DROP TABLE IF EXISTS `test``db`.`customer_report``$probe`", &sqltypes.Result{})
+	fakedbs.AddQuery("CREATE TABLE `customer_report``$probe` (\n id INT PRIMARY KEY\n)", &sqltypes.Result{})
+
+	tempDir := c.TempDir()
+	schemaFile := tempDir + "/test`db.customer_report`$probe-schema.sql"
+	schemaContent := "CREATE TABLE `customer_report``$probe` (\n id INT PRIMARY KEY\n);"
+	err = os.WriteFile(schemaFile, []byte(schemaContent), 0644)
+	c.Assert(err, qt.IsNil)
+
+	cfg := &Config{
+		Database:        "test`db",
+		Outdir:          tempDir,
+		User:            "mock",
+		Password:        "mock",
+		Threads:         1,
+		Address:         address,
+		IntervalMs:      500,
+		OverwriteTables: true,
+	}
+	loader, err := NewLoader(cfg)
+	c.Assert(err, qt.IsNil)
+
+	pool, err := NewPool(loader.log, cfg.Threads, cfg.Address, cfg.User, cfg.Password, cfg.SessionVars, "")
+	c.Assert(err, qt.IsNil)
+	defer pool.Close()
+
+	conn := pool.Get()
+	defer pool.Put(conn)
+
+	err = loader.restoreTableSchema(cfg.OverwriteTables, []string{schemaFile}, conn)
+	c.Assert(err, qt.IsNil)
+}
+
+// TestRestoreViews_EscapesGeneratedIdentifiers restores a view file whose
+// database and view names contain backticks; the mock server is given only
+// the doubled-backtick forms of the USE and DROP VIEW statements.
+func TestRestoreViews_EscapesGeneratedIdentifiers(t *testing.T) {
+	c := qt.New(t)
+
+	log := xlog.NewStdLog(xlog.Level(xlog.ERROR))
+	fakedbs := driver.NewTestHandler(log)
+	server, err := driver.MockMysqlServer(log, fakedbs)
+	c.Assert(err, qt.IsNil)
+	defer server.Close()
+
+	address := server.Addr()
+	fakedbs.AddQuery("USE `test``db`", &sqltypes.Result{})
+	fakedbs.AddQuery("SET FOREIGN_KEY_CHECKS=0", &sqltypes.Result{})
+	fakedbs.AddQuery("DROP VIEW IF EXISTS `test``db`.`report``view`", &sqltypes.Result{})
+	fakedbs.AddQueryPattern("create view .*", &sqltypes.Result{})
+
+	tempDir := c.TempDir()
+	viewFile := tempDir + "/test`db.report`view-schema-view.sql"
+	viewContent := "CREATE VIEW `report``view` AS SELECT 1;"
+	err = os.WriteFile(viewFile, []byte(viewContent), 0644)
+	c.Assert(err, qt.IsNil)
+
+	cfg := &Config{
+		Database:        "test`db",
+		Outdir:          tempDir,
+		User:            "mock",
+		Password:        "mock",
+		Threads:         1,
+		Address:         address,
+		IntervalMs:      500,
+		OverwriteTables: true,
+	}
+	loader, err := NewLoader(cfg)
+	c.Assert(err, qt.IsNil)
+
+	pool, err := NewPool(loader.log, cfg.Threads, cfg.Address, cfg.User, cfg.Password, cfg.SessionVars, "")
+	c.Assert(err, qt.IsNil)
+	defer pool.Close()
+
+	conn := pool.Get()
+	defer pool.Put(conn)
+
+	err = loader.restoreViews(cfg.OverwriteTables, []string{viewFile}, conn)
+	c.Assert(err, qt.IsNil)
+}
+
+// TestRestoreTable_EscapesUseDatabase restores a data file for a database
+// whose name contains a backtick; the mock server is given only the
+// doubled-backtick form of the USE statement.
+func TestRestoreTable_EscapesUseDatabase(t *testing.T) {
+	c := qt.New(t)
+
+	log := xlog.NewStdLog(xlog.Level(xlog.ERROR))
+	fakedbs := driver.NewTestHandler(log)
+	server, err := driver.MockMysqlServer(log, fakedbs)
+	c.Assert(err, qt.IsNil)
+	defer server.Close()
+
+	address := server.Addr()
+	fakedbs.AddQuery("USE `test``db`", &sqltypes.Result{})
+	fakedbs.AddQuery("SET FOREIGN_KEY_CHECKS=0", &sqltypes.Result{})
+	fakedbs.AddQuery("INSERT INTO `customer_report``$probe` VALUES (1)", &sqltypes.Result{})
+
+	tempDir := c.TempDir()
+	dataFile := tempDir + "/test`db.customer_report`$probe.00001.sql"
+	dataContent := "INSERT INTO `customer_report``$probe` VALUES (1);"
+	err = os.WriteFile(dataFile, []byte(dataContent), 0644)
+	c.Assert(err, qt.IsNil)
+
+	cfg := &Config{
+		Database:     "test`db",
+		Outdir:       tempDir,
+		User:         "mock",
+		Password:     "mock",
+		Threads:      1,
+		Address:      address,
+		IntervalMs:   500,
+		MaxQuerySize: 1024,
+	}
+	loader, err := NewLoader(cfg)
+	c.Assert(err, qt.IsNil)
+
+	pool, err := NewPool(loader.log, cfg.Threads, cfg.Address, cfg.User, cfg.Password, cfg.SessionVars, "")
+	c.Assert(err, qt.IsNil)
+	defer pool.Close()
+
+	conn := pool.Get()
+	defer pool.Put(conn)
+
+	_, err = loader.restoreTable(context.Background(), dataFile, conn)
+	c.Assert(err, qt.IsNil)
+}
diff --git a/internal/dumper/sql_writer.go b/internal/dumper/sql_writer.go
index 541e2f29..3ab6222a 100644
--- a/internal/dumper/sql_writer.go
+++ b/internal/dumper/sql_writer.go
@@ -21,7 +21,7 @@ type sqlWriter struct {
 func newSQLWriter(cfg *Config, table string) *sqlWriter {
 	return &sqlWriter{
 		cfg:     cfg,
-		table:   table,
+		table:   quoteIdentifier(table),
 		rows:    make([]string, 0, 256),
 		inserts: make([]string, 0, 256),
 	}
@@ -30,7 +30,7 @@ func newSQLWriter(cfg *Config, table string) *sqlWriter {
 func (w *sqlWriter) Initialize(fieldNames []string) error {
 	w.fields = make([]string, len(fieldNames))
 	for i, name := range fieldNames {
-		w.fields[i] = fmt.Sprintf("`%s`", name)
+		w.fields[i] = quoteIdentifier(name)
 	}
 	return nil
 }
@@ -58,7 +58,7 @@ func (w *sqlWriter) WriteRow(row []sqltypes.Value) (int, error) {
 	w.chunkbytes += rowBytes
 
 	if w.stmtsize >= w.cfg.StmtSize {
-		insertone := fmt.Sprintf("INSERT INTO `%s`(%s) VALUES\n%s", w.table, strings.Join(w.fields, ","), strings.Join(w.rows, ",\n"))
+		insertone := fmt.Sprintf("INSERT INTO %s(%s) VALUES\n%s", w.table, strings.Join(w.fields, ","), strings.Join(w.rows, ",\n"))
 		w.inserts = append(w.inserts, insertone)
 		w.rows = w.rows[:0]
 		w.stmtsize = 0
@@ -87,7 +87,7 @@ func (w *sqlWriter) Flush(outdir, database, table string, fileNo int) error {
 func (w *sqlWriter) Close(outdir, database, table string, fileNo int) error {
 	if w.chunkbytes > 0 {
 		if len(w.rows) > 0 {
-			insertone := fmt.Sprintf("INSERT INTO `%s`(%s) VALUES\n%s", w.table, strings.Join(w.fields, ","), strings.Join(w.rows, ",\n"))
+			insertone := fmt.Sprintf("INSERT INTO %s(%s) VALUES\n%s", w.table, strings.Join(w.fields, ","), strings.Join(w.rows, ",\n"))
 			w.inserts = append(w.inserts, insertone)
 		}
 		return w.Flush(outdir, database, table, fileNo)