Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions internal/cmd/database/dump.go
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,13 @@ func dump(ch *cmdutil.Helper, cmd *cobra.Command, flags *dumpFlags, args []strin

if flags.shard != "" {
if flags.replica {
useCmd := fmt.Sprintf("USE `%s/%s@replica`;", dbName, flags.shard)
useCmd := shardUseCommand(dbName, flags.shard, flags.replica, flags.rdonly)
cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
} else if flags.rdonly {
useCmd := fmt.Sprintf("USE `%s/%s@rdonly`;", dbName, flags.shard)
useCmd := shardUseCommand(dbName, flags.shard, flags.replica, flags.rdonly)
cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
} else {
useCmd := fmt.Sprintf("USE `%s/%s`;", dbName, flags.shard)
useCmd := shardUseCommand(dbName, flags.shard, flags.replica, flags.rdonly)
cfg.SessionVars = append([]string{useCmd}, cfg.SessionVars...)
}
}
Expand Down Expand Up @@ -392,3 +392,17 @@ func parseColumnIncludes(columns []string) (map[string]map[string]bool, error) {

return result, nil
}

func shardUseCommand(dbName string, shard string, replica bool, rdonly bool) string {
target := fmt.Sprintf("%s/%s", dbName, shard)
if replica {
target += "@replica"
} else if rdonly {
target += "@rdonly"
}
return fmt.Sprintf("USE %s;", quoteIdentifier(target))
}

func quoteIdentifier(identifier string) string {
return "`" + strings.ReplaceAll(identifier, "`", "``") + "`"
}
9 changes: 9 additions & 0 deletions internal/cmd/database/dump_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,12 @@ func TestParseColumnIncludes(t *testing.T) {
})
}
}

func TestShardUseCommand(t *testing.T) {
c := qt.New(t)

c.Assert(shardUseCommand("commerce", "-80", false, false), qt.Equals, "USE `commerce/-80`;")
c.Assert(shardUseCommand("commerce", "-80", true, false), qt.Equals, "USE `commerce/-80@replica`;")
c.Assert(shardUseCommand("commerce", "-80", false, true), qt.Equals, "USE `commerce/-80@rdonly`;")
c.Assert(shardUseCommand("key`space", "sh`ard", false, false), qt.Equals, "USE `key``space/sh``ard`;")
}
32 changes: 20 additions & 12 deletions internal/dumper/dumper.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func writeMetaData(outdir string) error {
}

func (d *Dumper) dumpTableSchema(conn *Connection, database string, table string, views map[string]bool) error {
qr, err := conn.Fetch(fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", database, table))
qr, err := conn.Fetch(fmt.Sprintf("SHOW CREATE TABLE %s.%s", quoteIdentifier(database), quoteIdentifier(table)))
if err != nil {
return err
}
Expand Down Expand Up @@ -293,7 +293,7 @@ func (d *Dumper) dumpTable(ctx context.Context, conn *Connection, database strin
return err
}

cursor, err := conn.StreamFetch(fmt.Sprintf("SELECT %s FROM `%s`.`%s` %s", strings.Join(dumpCtx.selfields, ", "), database, table, dumpCtx.where))
cursor, err := conn.StreamFetch(fmt.Sprintf("SELECT %s FROM %s.%s %s", strings.Join(dumpCtx.selfields, ", "), quoteIdentifier(database), quoteIdentifier(table), dumpCtx.where))
if err != nil {
return err
}
Expand Down Expand Up @@ -387,9 +387,9 @@ func (d *Dumper) tableDumpContext(conn *Connection, table string) (*dumpContext,
ctx.fieldNames = append(ctx.fieldNames, name)
replacement, ok := d.cfg.Selects[table][name]
if ok {
ctx.selfields = append(ctx.selfields, fmt.Sprintf("%s AS `%s`", replacement, name))
ctx.selfields = append(ctx.selfields, fmt.Sprintf("%s AS %s", replacement, quoteIdentifier(name)))
} else {
ctx.selfields = append(ctx.selfields, fmt.Sprintf("`%s`", name))
ctx.selfields = append(ctx.selfields, quoteIdentifier(name))
}
}

Expand All @@ -406,7 +406,7 @@ func (d *Dumper) tableDumpContext(conn *Connection, table string) (*dumpContext,
}

func (d *Dumper) allTables(conn *Connection, database string) ([]string, error) {
qr, err := conn.Fetch(fmt.Sprintf("SHOW TABLES FROM `%s`", database))
qr, err := conn.Fetch(fmt.Sprintf("SHOW TABLES FROM %s", quoteIdentifier(database)))
if err != nil {
return nil, err
}
Expand All @@ -419,12 +419,12 @@ func (d *Dumper) allTables(conn *Connection, database string) ([]string, error)
}

func (d *Dumper) allViews(conn *Connection, database string) (map[string]bool, error) {
query := `SELECT TABLE_NAME
FROM information_schema.TABLES
WHERE TABLE_SCHEMA LIKE '%s'
AND TABLE_TYPE = 'VIEW'
`
qr, err := conn.Fetch(fmt.Sprintf(query, database))
query := "SELECT TABLE_NAME \n" +
"\t\t\t FROM information_schema.TABLES \n" +
"\t\t\t WHERE TABLE_SCHEMA LIKE %s \n" +
"\t\t\t AND TABLE_TYPE = 'VIEW'\n" +
"\t\t\t"
qr, err := conn.Fetch(fmt.Sprintf(query, quoteStringLiteral(database)))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -466,7 +466,7 @@ func (d *Dumper) filterDatabases(conn *Connection, filter *regexp.Regexp, invert

// dumpableFieldNames returns a slice that contains valid field names for the dump.
func (d *Dumper) dumpableFieldNames(conn *Connection, table string) ([]string, error) {
qr, err := conn.Fetch(fmt.Sprintf("SHOW FIELDS FROM `%s`", table))
qr, err := conn.Fetch(fmt.Sprintf("SHOW FIELDS FROM %s", quoteIdentifier(table)))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -515,6 +515,14 @@ func writeFile(file string, data string) error {
return nil
}

func quoteIdentifier(identifier string) string {
return "`" + strings.ReplaceAll(identifier, "`", "``") + "`"
}

func quoteStringLiteral(s string) string {
return "'" + strings.ReplaceAll(s, "'", "''") + "'"
}

// escapeBytes used to escape the literal byte.
// See https://dev.mysql.com/doc/refman/5.7/en/string-literals.html
// for more information on how to escape string literals in MySQL.
Expand Down
118 changes: 118 additions & 0 deletions internal/dumper/dumper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1574,6 +1574,124 @@ func TestEscapeBytes(t *testing.T) {
}
}

func TestQuoteIdentifier(t *testing.T) {
c := qt.New(t)

c.Assert(quoteIdentifier("simple"), qt.Equals, "`simple`")
c.Assert(quoteIdentifier("customer_report`$probe"), qt.Equals, "`customer_report``$probe`")
c.Assert(quoteIdentifier("display`name"), qt.Equals, "`display``name`")
}

func TestQuoteStringLiteral(t *testing.T) {
c := qt.New(t)

c.Assert(quoteStringLiteral("simple"), qt.Equals, "'simple'")
c.Assert(quoteStringLiteral("test'db"), qt.Equals, "'test''db'")
}

func TestDumperEscapesDiscoveredIdentifiers(t *testing.T) {
c := qt.New(t)

log := xlog.NewStdLog(xlog.Level(xlog.INFO))
fakedbs := driver.NewTestHandler(log)
server, err := driver.MockMysqlServer(log, fakedbs)
c.Assert(err, qt.IsNil)
c.Cleanup(func() { server.Close() })

address := server.Addr()
database := "test`db"
table := "customer_report`$probe"
column := "display`name"

selectResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: column, Type: querypb.Type_VARCHAR},
},
Rows: [][]sqltypes.Value{
{sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("ok"))},
},
}

schemaResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "Table", Type: querypb.Type_VARCHAR},
{Name: "Create Table", Type: querypb.Type_VARCHAR},
},
Rows: [][]sqltypes.Value{
{
sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(table)),
sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte("CREATE TABLE `customer_report``$probe` (`display``name` varchar(255)) ENGINE=InnoDB")),
},
},
}

tablesResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "Tables_in_test`db", Type: querypb.Type_VARCHAR},
},
Rows: [][]sqltypes.Value{
{sqltypes.MakeTrusted(querypb.Type_VARCHAR, []byte(table))},
},
}

viewsResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "TABLE_NAME", Type: querypb.Type_VARCHAR},
},
Rows: [][]sqltypes.Value{},
}

fieldsResult := &sqltypes.Result{
Fields: []*querypb.Field{
{Name: "Field", Type: querypb.Type_VARCHAR},
{Name: "Type", Type: querypb.Type_VARCHAR},
{Name: "Null", Type: querypb.Type_VARCHAR},
{Name: "Key", Type: querypb.Type_VARCHAR},
{Name: "Default", Type: querypb.Type_VARCHAR},
{Name: "Extra", Type: querypb.Type_VARCHAR},
},
Rows: [][]sqltypes.Value{
testRow(column, ""),
},
}

fakedbs.AddQueryPattern("use .*", &sqltypes.Result{})
fakedbs.AddQueryPattern("show create table .*", schemaResult)
fakedbs.AddQueryPattern("show tables from .*", tablesResult)
fakedbs.AddQueryPattern("select table_name .* from information_schema.tables .*", viewsResult)
fakedbs.AddQueryPattern("show fields from .*", fieldsResult)
fakedbs.AddQueryPattern("select .* from .*", selectResult)
fakedbs.AddQueryPattern("set .*", &sqltypes.Result{})

cfg := &Config{
Database: database,
Outdir: c.TempDir(),
User: "mock",
Password: "mock",
Address: address,
ChunksizeInMB: 1,
Threads: 16,
StmtSize: 10000,
IntervalMs: 500,
SessionVars: []string{"SET @@radon_streaming_fetch='ON'"},
}

d, err := NewDumper(cfg)
c.Assert(err, qt.IsNil)

err = d.Run(context.Background())
c.Assert(err, qt.IsNil)

c.Assert(fakedbs.GetQueryCalledNum("SHOW TABLES FROM `test``db`"), qt.Equals, 1)
c.Assert(fakedbs.GetQueryCalledNum("SHOW CREATE TABLE `test``db`.`customer_report``$probe`"), qt.Equals, 1)
c.Assert(fakedbs.GetQueryCalledNum("SHOW FIELDS FROM `customer_report``$probe`"), qt.Equals, 1)
c.Assert(fakedbs.GetQueryCalledNum("SELECT `display``name` FROM `test``db`.`customer_report``$probe` "), qt.Equals, 1)

dat, err := os.ReadFile(cfg.Outdir + "/" + database + "." + table + ".00001.sql")
c.Assert(err, qt.IsNil)
c.Assert(string(dat), qt.Contains, "INSERT INTO `customer_report``$probe`(`display``name`) VALUES")
}

func TestDumperColumnIncludes(t *testing.T) {
c := qt.New(t)

Expand Down
10 changes: 5 additions & 5 deletions internal/dumper/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,15 +270,15 @@ func (l *Loader) restoreTableSchema(overwrite bool, tables []string, conn *Conne
name := strings.TrimSuffix(base, schemaSuffix)
db := l.databaseNameFromFilename(name)
tbl := strings.Split(name, ".")[1]
name = fmt.Sprintf("`%v`.`%v`", db, tbl)
name = fmt.Sprintf("%s.%s", quoteIdentifier(db), quoteIdentifier(tbl))

l.log.Info(
"working table",
zap.String("database", db),
zap.String("table ", tbl),
)

err := conn.Execute(fmt.Sprintf("USE `%s`", db))
err := conn.Execute(fmt.Sprintf("USE %s", quoteIdentifier(db)))
if err != nil {
return err
}
Expand Down Expand Up @@ -386,15 +386,15 @@ func (l *Loader) restoreViews(overwrite bool, views []string, conn *Connection)
name := strings.TrimSuffix(base, viewSuffix)
db := strings.Split(name, ".")[0]
view := strings.Split(name, ".")[1]
name = fmt.Sprintf("`%v`.`%v`", db, view)
name = fmt.Sprintf("%s.%s", quoteIdentifier(db), quoteIdentifier(view))

l.log.Info(
"working view",
zap.String("database", db),
zap.String("view ", view),
)

err := conn.Execute(fmt.Sprintf("USE `%s`", db))
err := conn.Execute(fmt.Sprintf("USE %s", quoteIdentifier(db)))
if err != nil {
return err
}
Expand Down Expand Up @@ -483,7 +483,7 @@ func (l *Loader) restoreTable(ctx context.Context, table string, conn *Connectio
zap.Int("thread_conn_id", conn.ID),
)

err := conn.Execute(fmt.Sprintf("USE `%s`", db))
err := conn.Execute(fmt.Sprintf("USE %s", quoteIdentifier(db)))
if err != nil {
return 0, err
}
Expand Down
Loading