diff --git a/go/logic/hooks.go b/go/logic/hooks.go index 0ff296d..2543f8e 100644 --- a/go/logic/hooks.go +++ b/go/logic/hooks.go @@ -7,6 +7,7 @@ package logic import ( "fmt" + "io" "os" "os/exec" "path/filepath" @@ -34,18 +35,16 @@ const ( type HooksExecutor struct { migrationContext *base.MigrationContext + writer io.Writer } func NewHooksExecutor(migrationContext *base.MigrationContext) *HooksExecutor { return &HooksExecutor{ migrationContext: migrationContext, + writer: os.Stderr, } } -func (this *HooksExecutor) initHooks() error { - return nil -} - func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) []string { env := os.Environ() env = append(env, fmt.Sprintf("GH_OST_DATABASE_NAME=%s", this.migrationContext.DatabaseName)) @@ -76,13 +75,13 @@ func (this *HooksExecutor) applyEnvironmentVariables(extraVariables ...string) [ } // executeHook executes a command, and sets relevant environment variables -// combined output & error are printed to gh-ost's standard error. +// combined output & error are printed to the configured writer. func (this *HooksExecutor) executeHook(hook string, extraVariables ...string) error { cmd := exec.Command(hook) cmd.Env = this.applyEnvironmentVariables(extraVariables...) combinedOutput, err := cmd.CombinedOutput() - fmt.Fprintln(os.Stderr, string(combinedOutput)) + fmt.Fprintln(this.writer, string(combinedOutput)) return log.Errore(err) } diff --git a/go/logic/hooks_test.go b/go/logic/hooks_test.go new file mode 100644 index 0000000..3b28afe --- /dev/null +++ b/go/logic/hooks_test.go @@ -0,0 +1,113 @@ +/* + Copyright 2022 GitHub Inc. + See https://github.com/github/gh-ost/blob/master/LICENSE +*/ + +package logic + +import ( + "bufio" + "bytes" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/openark/golib/tests" + + "github.com/github/gh-ost/go/base" +) + +func TestHooksExecutorExecuteHooks(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.AlterStatement = "ENGINE=InnoDB" + migrationContext.DatabaseName = "test" + migrationContext.Hostname = "test.example.com" + migrationContext.OriginalTableName = "tablename" + migrationContext.RowsDeltaEstimate = 1 + migrationContext.RowsEstimate = 122 + migrationContext.TotalRowsCopied = 123456 + migrationContext.SetETADuration(time.Minute) + migrationContext.SetProgressPct(50) + hooksExecutor := NewHooksExecutor(migrationContext) + + writeTmpHookFunc := func(testName, hookName, script string) (path string, err error) { + if path, err = os.MkdirTemp("", testName); err != nil { + return path, err + } + err = os.WriteFile(filepath.Join(path, hookName), []byte(script), 0777) + return path, err + } + + t.Run("does-not-exist", func(t *testing.T) { + migrationContext.HooksPath = "/does/not/exist" + tests.S(t).ExpectNil(hooksExecutor.executeHooks("test-hook")) + }) + + t.Run("failed", func(t *testing.T) { + var err error + if migrationContext.HooksPath, err = writeTmpHookFunc( + "TestHooksExecutorExecuteHooks-failed", + "failed-hook", + "#!/bin/sh\nexit 1", + ); err != nil { + panic(err) + } + defer os.RemoveAll(migrationContext.HooksPath) + tests.S(t).ExpectNotNil(hooksExecutor.executeHooks("failed-hook")) + }) + + t.Run("success", func(t *testing.T) { + var err error + if migrationContext.HooksPath, err = writeTmpHookFunc( + "TestHooksExecutorExecuteHooks-success", + "success-hook", + "#!/bin/sh\nenv", + ); err != nil { + panic(err) + } + defer os.RemoveAll(migrationContext.HooksPath) + + var buf bytes.Buffer + hooksExecutor.writer = &buf + tests.S(t).ExpectNil(hooksExecutor.executeHooks("success-hook", "TEST="+t.Name())) + + scanner := bufio.NewScanner(&buf) + for scanner.Scan() { + split := strings.SplitN(scanner.Text(), "=", 2) + switch split[0] { + case "GH_OST_COPIED_ROWS": + copiedRows, _ := strconv.ParseInt(split[1], 10, 64) + tests.S(t).ExpectEquals(copiedRows, migrationContext.TotalRowsCopied) + case "GH_OST_DATABASE_NAME": + tests.S(t).ExpectEquals(split[1], migrationContext.DatabaseName) + case "GH_OST_DDL": + tests.S(t).ExpectEquals(split[1], migrationContext.AlterStatement) + case "GH_OST_DRY_RUN": + tests.S(t).ExpectEquals(split[1], "false") + case "GH_OST_ESTIMATED_ROWS": + estimatedRows, _ := strconv.ParseInt(split[1], 10, 64) + tests.S(t).ExpectEquals(estimatedRows, int64(123)) + case "GH_OST_ETA_SECONDS": + etaSeconds, _ := strconv.ParseInt(split[1], 10, 64) + tests.S(t).ExpectEquals(etaSeconds, int64(60)) + case "GH_OST_EXECUTING_HOST": + tests.S(t).ExpectEquals(split[1], migrationContext.Hostname) + case "GH_OST_GHOST_TABLE_NAME": + tests.S(t).ExpectEquals(split[1], fmt.Sprintf("_%s_gho", migrationContext.OriginalTableName)) + case "GH_OST_OLD_TABLE_NAME": + tests.S(t).ExpectEquals(split[1], fmt.Sprintf("_%s_del", migrationContext.OriginalTableName)) + case "GH_OST_PROGRESS": + progress, _ := strconv.ParseFloat(split[1], 64) + tests.S(t).ExpectEquals(progress, 50.0) + case "GH_OST_TABLE_NAME": + tests.S(t).ExpectEquals(split[1], migrationContext.OriginalTableName) + case "TEST": + tests.S(t).ExpectEquals(split[1], t.Name()) + } + } + }) +} diff --git a/go/logic/migrator.go b/go/logic/migrator.go index 865814d..ccaf160 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -98,6 +98,7 @@ type Migrator struct { func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { migrator := &Migrator{ appVersion: appVersion, + hooksExecutor: NewHooksExecutor(context), migrationContext: context, parser: sql.NewAlterTableParser(), ghostTableMigrated: make(chan bool), @@ -113,15 +114,6 @@ func NewMigrator(context *base.MigrationContext, appVersion string) *Migrator { return migrator } -// initiateHooksExecutor -func (this *Migrator) initiateHooksExecutor() (err error) { - this.hooksExecutor = NewHooksExecutor(this.migrationContext) - if err := this.hooksExecutor.initHooks(); err != nil { - return err - } - return nil -} - // sleepWhileTrue sleeps indefinitely until the given function returns 'false' // (or fails with error) func (this *Migrator) sleepWhileTrue(operation func() (bool, error)) error { @@ -342,9 +334,6 @@ func (this *Migrator) Migrate() (err error) { go this.listenOnPanicAbort() - if err := this.initiateHooksExecutor(); err != nil { - return err - } if err := this.hooksExecutor.onStartup(); err != nil { return err } diff --git a/script/test b/script/test index 7e757b5..5c32b37 100755 --- a/script/test +++ b/script/test @@ -14,4 +14,4 @@ script/build cd .gopath/src/github.com/github/gh-ost echo "Running unit tests" -go test ./go/... +go test -v -covermode=atomic ./go/...