diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index e681e5a1b..733cf2601 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -19,6 +19,24 @@ Here are a few things you can do that will increase the likelihood of your pull - Keep your change as focused as possible. If there are multiple changes you would like to make that are not dependent upon each other, consider submitting them as separate pull requests. - Write a [good commit message](http://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). +## Development Guidelines + +### Channel Safety + +When working with channels in goroutines, it's critical to prevent deadlocks that can occur when a channel receiver exits due to an error while senders are still trying to send values. Always use `base.SendWithContext` for channel sends to avoid deadlocks: + +```go +// ✅ CORRECT - Uses helper to prevent deadlock +if err := base.SendWithContext(ctx, ch, value); err != nil { + return err // context was cancelled +} + +// ❌ WRONG - Can deadlock if receiver exits +ch <- value +``` + +Even if the destination channel is buffered, deadlocks could still occur if the buffer fills up and the receiver exits, so it's important to use `SendWithContext` in those cases as well. + ## Resources - [Contributing to Open Source on GitHub](https://guides.github.com/activities/contributing-to-open-source/) diff --git a/go/base/context.go b/go/base/context.go index 891e27fef..2c8d28d56 100644 --- a/go/base/context.go +++ b/go/base/context.go @@ -6,6 +6,7 @@ package base import ( + "context" "fmt" "math" "os" @@ -225,6 +226,16 @@ type MigrationContext struct { InCutOverCriticalSectionFlag int64 PanicAbort chan error + // Context for cancellation signaling across all goroutines + // Stored in struct as it spans the entire migration lifecycle, not per-function. + // context.Context is safe for concurrent use by multiple goroutines. + ctx context.Context //nolint:containedctx + cancelFunc context.CancelFunc + + // Stores the fatal error that triggered abort + AbortError error + abortMutex *sync.Mutex + OriginalTableColumnsOnApplier *sql.ColumnList OriginalTableColumns *sql.ColumnList OriginalTableVirtualColumns *sql.ColumnList @@ -293,6 +304,7 @@ type ContextConfig struct { } func NewMigrationContext() *MigrationContext { + ctx, cancelFunc := context.WithCancel(context.Background()) return &MigrationContext{ Uuid: uuid.NewString(), defaultNumRetries: 60, @@ -313,6 +325,9 @@ func NewMigrationContext() *MigrationContext { lastHeartbeatOnChangelogMutex: &sync.Mutex{}, ColumnRenameMap: make(map[string]string), PanicAbort: make(chan error), + ctx: ctx, + cancelFunc: cancelFunc, + abortMutex: &sync.Mutex{}, Log: NewDefaultLogger(), } } @@ -982,3 +997,54 @@ func (this *MigrationContext) GetGhostTriggerName(triggerName string) string { func (this *MigrationContext) ValidateGhostTriggerLengthBelowMaxLength(triggerName string) bool { return utf8.RuneCountInString(triggerName) <= mysql.MaxTableNameLength } + +// GetContext returns the migration context for cancellation checking +func (this *MigrationContext) GetContext() context.Context { + return this.ctx +} + +// SetAbortError stores the fatal error that triggered abort +// Only the first error is stored (subsequent errors are ignored) +func (this *MigrationContext) SetAbortError(err error) { + this.abortMutex.Lock() + defer this.abortMutex.Unlock() + if this.AbortError == nil { + this.AbortError = err + } +} + +// GetAbortError retrieves the stored abort error +func (this *MigrationContext) GetAbortError() error { + this.abortMutex.Lock() + defer this.abortMutex.Unlock() + return this.AbortError +} + +// CancelContext cancels the migration context to signal all goroutines to stop +// The cancel function is safe to call multiple times and from multiple goroutines. +func (this *MigrationContext) CancelContext() { + if this.cancelFunc != nil { + this.cancelFunc() + } +} + +// SendWithContext attempts to send a value to a channel, but returns early +// if the context is cancelled. This prevents goroutine deadlocks when the +// channel receiver has exited due to an error. +// +// Use this instead of bare channel sends (ch <- val) in goroutines to ensure +// proper cleanup when the migration is aborted. +// +// Example: +// +// if err := base.SendWithContext(ctx, ch, value); err != nil { +// return err // context was cancelled +// } +func SendWithContext[T any](ctx context.Context, ch chan<- T, val T) error { + select { + case ch <- val: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} diff --git a/go/base/context_test.go b/go/base/context_test.go index f8bce6f27..a9f62150d 100644 --- a/go/base/context_test.go +++ b/go/base/context_test.go @@ -6,8 +6,10 @@ package base import ( + "errors" "os" "strings" + "sync" "testing" "time" @@ -213,3 +215,58 @@ func TestReadConfigFile(t *testing.T) { } } } + +func TestSetAbortError_StoresFirstError(t *testing.T) { + ctx := NewMigrationContext() + + err1 := errors.New("first error") + err2 := errors.New("second error") + + ctx.SetAbortError(err1) + ctx.SetAbortError(err2) + + got := ctx.GetAbortError() + if got != err1 { //nolint:errorlint // Testing pointer equality for sentinel error + t.Errorf("Expected first error %v, got %v", err1, got) + } +} + +func TestSetAbortError_ThreadSafe(t *testing.T) { + ctx := NewMigrationContext() + + var wg sync.WaitGroup + errs := []error{ + errors.New("error 1"), + errors.New("error 2"), + errors.New("error 3"), + } + + // Launch 3 goroutines trying to set error concurrently + for _, err := range errs { + wg.Add(1) + go func(e error) { + defer wg.Done() + ctx.SetAbortError(e) + }(err) + } + + wg.Wait() + + // Should store exactly one of the errors + got := ctx.GetAbortError() + if got == nil { + t.Fatal("Expected error to be stored, got nil") + } + + // Verify it's one of the errors we sent + found := false + for _, err := range errs { + if got == err { //nolint:errorlint // Testing pointer equality for sentinel error + found = true + break + } + } + if !found { + t.Errorf("Stored error %v not in list of sent errors", got) + } +} diff --git a/go/logic/applier.go b/go/logic/applier.go index 58761d844..6c09dd61e 100644 --- a/go/logic/applier.go +++ b/go/logic/applier.go @@ -695,7 +695,17 @@ func (this *Applier) InitiateHeartbeat() { ticker := time.NewTicker(time.Duration(this.migrationContext.HeartbeatIntervalMilliseconds) * time.Millisecond) defer ticker.Stop() - for range ticker.C { + for { + // Check for context cancellation each iteration + ctx := this.migrationContext.GetContext() + select { + case <-ctx.Done(): + this.migrationContext.Log.Debugf("Heartbeat injection cancelled") + return + case <-ticker.C: + // Process heartbeat + } + if atomic.LoadInt64(&this.finishedMigrating) > 0 { return } @@ -706,7 +716,8 @@ func (this *Applier) InitiateHeartbeat() { continue } if err := injectHeartbeat(); err != nil { - this.migrationContext.PanicAbort <- fmt.Errorf("injectHeartbeat writing failed %d times, last error: %w", numSuccessiveFailures, err) + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("injectHeartbeat writing failed %d times, last error: %w", numSuccessiveFailures, err)) return } } diff --git a/go/logic/migrator.go b/go/logic/migrator.go index aa9a97c1c..ca5f5a729 100644 --- a/go/logic/migrator.go +++ b/go/logic/migrator.go @@ -163,7 +163,8 @@ func (this *Migrator) retryOperation(operation func() error, notFatalHint ...boo // there's an error. Let's try again. } if len(notFatalHint) == 0 { - this.migrationContext.PanicAbort <- err + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } return err } @@ -191,7 +192,8 @@ func (this *Migrator) retryOperationWithExponentialBackoff(operation func() erro } } if len(notFatalHint) == 0 { - this.migrationContext.PanicAbort <- err + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } return err } @@ -200,14 +202,19 @@ func (this *Migrator) retryOperationWithExponentialBackoff(operation func() erro // consumes and drops any further incoming events that may be left hanging. func (this *Migrator) consumeRowCopyComplete() { if err := <-this.rowCopyComplete; err != nil { - this.migrationContext.PanicAbort <- err + // Abort synchronously to ensure checkAbort() sees the error immediately + this.abort(err) + // Don't mark row copy as complete if there was an error + return } atomic.StoreInt64(&this.rowCopyCompleteFlag, 1) this.migrationContext.MarkRowCopyEndTime() go func() { for err := range this.rowCopyComplete { if err != nil { - this.migrationContext.PanicAbort <- err + // Abort synchronously to ensure the error is stored immediately + this.abort(err) + return } } }() @@ -238,14 +245,14 @@ func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err e case Migrated, ReadMigrationRangeValues: // no-op event case GhostTableMigrated: - this.ghostTableMigrated <- true + // Use helper to prevent deadlock if migration aborts before receiver is ready + _ = base.SendWithContext(this.migrationContext.GetContext(), this.ghostTableMigrated, true) case AllEventsUpToLockProcessed: var applyEventFunc tableWriteFunc = func() error { - this.allEventsUpToLockProcessed <- &lockProcessedStruct{ + return base.SendWithContext(this.migrationContext.GetContext(), this.allEventsUpToLockProcessed, &lockProcessedStruct{ state: changelogStateString, coords: dmlEntry.Coordinates.Clone(), - } - return nil + }) } // at this point we know all events up to lock have been read from the streamer, // because the streamer works sequentially. So those events are either already handled, @@ -253,7 +260,8 @@ func (this *Migrator) onChangelogStateEvent(dmlEntry *binlog.BinlogEntry) (err e // So as not to create a potential deadlock, we write this func to applyEventsQueue // asynchronously, understanding it doesn't really matter. go func() { - this.applyEventsQueue <- newApplyEventStructByFunc(&applyEventFunc) + // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits + _ = base.SendWithContext(this.migrationContext.GetContext(), this.applyEventsQueue, newApplyEventStructByFunc(&applyEventFunc)) }() default: return fmt.Errorf("Unknown changelog state: %+v", changelogState) @@ -277,10 +285,24 @@ func (this *Migrator) onChangelogHeartbeatEvent(dmlEntry *binlog.BinlogEntry) (e } } -// listenOnPanicAbort aborts on abort request +// abort stores the error, cancels the context, and logs the abort. +// This is the common abort logic used by both listenOnPanicAbort and +// consumeRowCopyComplete to ensure consistent error handling. +func (this *Migrator) abort(err error) { + // Store the error for Migrate() to return + this.migrationContext.SetAbortError(err) + + // Cancel the context to signal all goroutines to stop + this.migrationContext.CancelContext() + + // Log the error (but don't panic or exit) + this.migrationContext.Log.Errorf("Migration aborted: %v", err) +} + +// listenOnPanicAbort listens for fatal errors and initiates graceful shutdown func (this *Migrator) listenOnPanicAbort() { err := <-this.migrationContext.PanicAbort - this.migrationContext.Log.Fatale(err) + this.abort(err) } // validateAlterStatement validates the `alter` statement meets criteria. @@ -348,10 +370,36 @@ func (this *Migrator) createFlagFiles() (err error) { return nil } +// checkAbort returns abort error if migration was aborted +func (this *Migrator) checkAbort() error { + if abortErr := this.migrationContext.GetAbortError(); abortErr != nil { + return abortErr + } + + ctx := this.migrationContext.GetContext() + if ctx != nil { + select { + case <-ctx.Done(): + // Context cancelled but no abort error stored yet + if abortErr := this.migrationContext.GetAbortError(); abortErr != nil { + return abortErr + } + return ctx.Err() + default: + // Not cancelled + } + } + return nil +} + // Migrate executes the complete migration logic. This is *the* major gh-ost function. func (this *Migrator) Migrate() (err error) { this.migrationContext.Log.Infof("Migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) this.migrationContext.StartTime = time.Now() + + // Ensure context is cancelled on exit (cleanup) + defer this.migrationContext.CancelContext() + if this.migrationContext.Hostname, err = os.Hostname(); err != nil { return err } @@ -375,6 +423,9 @@ func (this *Migrator) Migrate() (err error) { if err := this.initiateInspector(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } // If we are resuming, we will initiateStreaming later when we know // the binlog coordinates to resume streaming from. // If not resuming, the streamer must be initiated before the applier, @@ -383,10 +434,16 @@ func (this *Migrator) Migrate() (err error) { if err := this.initiateStreaming(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } } if err := this.initiateApplier(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } if err := this.createFlagFiles(); err != nil { return err } @@ -493,6 +550,10 @@ func (this *Migrator) Migrate() (err error) { this.migrationContext.Log.Debugf("Operating until row copy is complete") this.consumeRowCopyComplete() this.migrationContext.Log.Infof("Row copy complete") + // Check if row copy was aborted due to error + if err := this.checkAbort(); err != nil { + return err + } if err := this.hooksExecutor.onRowCopyComplete(); err != nil { return err } @@ -532,6 +593,10 @@ func (this *Migrator) Migrate() (err error) { return err } this.migrationContext.Log.Infof("Done migrating %s.%s", sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName)) + // Final check for abort before declaring success + if err := this.checkAbort(); err != nil { + return err + } return nil } @@ -543,6 +608,10 @@ func (this *Migrator) Revert() error { sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OriginalTableName), sql.EscapeName(this.migrationContext.DatabaseName), sql.EscapeName(this.migrationContext.OldTableName)) this.migrationContext.StartTime = time.Now() + + // Ensure context is cancelled on exit (cleanup) + defer this.migrationContext.CancelContext() + var err error if this.migrationContext.Hostname, err = os.Hostname(); err != nil { return err @@ -561,9 +630,15 @@ func (this *Migrator) Revert() error { if err := this.initiateInspector(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } if err := this.initiateApplier(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } if err := this.createFlagFiles(); err != nil { return err } @@ -588,6 +663,9 @@ func (this *Migrator) Revert() error { if err := this.initiateStreaming(); err != nil { return err } + if err := this.checkAbort(); err != nil { + return err + } if err := this.hooksExecutor.onValidated(); err != nil { return err } @@ -1293,7 +1371,8 @@ func (this *Migrator) initiateStreaming() error { this.migrationContext.Log.Debugf("Beginning streaming") err := this.eventsStreamer.StreamEvents(this.canStopStreaming) if err != nil { - this.migrationContext.PanicAbort <- err + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) } this.migrationContext.Log.Debugf("Done streaming") }() @@ -1319,8 +1398,9 @@ func (this *Migrator) addDMLEventsListener() error { this.migrationContext.DatabaseName, this.migrationContext.OriginalTableName, func(dmlEntry *binlog.BinlogEntry) error { - this.applyEventsQueue <- newApplyEventStructByDML(dmlEntry) - return nil + // Use helper to prevent deadlock if buffer fills and executeWriteFuncs exits + // This is critical because this callback blocks the event streamer + return base.SendWithContext(this.migrationContext.GetContext(), this.applyEventsQueue, newApplyEventStructByDML(dmlEntry)) }, ) return err @@ -1398,7 +1478,7 @@ func (this *Migrator) initiateApplier() error { // a chunk of rows onto the ghost table. func (this *Migrator) iterateChunks() error { terminateRowIteration := func(err error) error { - this.rowCopyComplete <- err + _ = base.SendWithContext(this.migrationContext.GetContext(), this.rowCopyComplete, err) return this.migrationContext.Log.Errore(err) } if this.migrationContext.Noop { @@ -1413,6 +1493,9 @@ func (this *Migrator) iterateChunks() error { var hasNoFurtherRangeFlag int64 // Iterate per chunk: for { + if err := this.checkAbort(); err != nil { + return terminateRowIteration(err) + } if atomic.LoadInt64(&this.rowCopyCompleteFlag) == 1 || atomic.LoadInt64(&hasNoFurtherRangeFlag) == 1 { // Done // There's another such check down the line @@ -1459,7 +1542,7 @@ func (this *Migrator) iterateChunks() error { this.migrationContext.Log.Infof("ApplyIterationInsertQuery has SQL warnings! %s", warning) } joinedWarnings := strings.Join(this.migrationContext.MigrationLastInsertSQLWarnings, "; ") - terminateRowIteration(fmt.Errorf("ApplyIterationInsertQuery failed because of SQL warnings: [%s]", joinedWarnings)) + return terminateRowIteration(fmt.Errorf("ApplyIterationInsertQuery failed because of SQL warnings: [%s]", joinedWarnings)) } } @@ -1482,7 +1565,14 @@ func (this *Migrator) iterateChunks() error { return nil } // Enqueue copy operation; to be executed by executeWriteFuncs() - this.copyRowsQueue <- copyRowsFunc + // Use helper to prevent deadlock if executeWriteFuncs exits + if err := base.SendWithContext(this.migrationContext.GetContext(), this.copyRowsQueue, copyRowsFunc); err != nil { + // Context cancelled, check for abort and exit + if abortErr := this.checkAbort(); abortErr != nil { + return terminateRowIteration(abortErr) + } + return terminateRowIteration(err) + } } } @@ -1563,20 +1653,18 @@ func (this *Migrator) Checkpoint(ctx context.Context) (*Checkpoint, error) { this.applier.LastIterationRangeMutex.Unlock() for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - default: - this.applier.CurrentCoordinatesMutex.Lock() - if coords.SmallerThanOrEquals(this.applier.CurrentCoordinates) { - id, err := this.applier.WriteCheckpoint(chk) - chk.Id = id - this.applier.CurrentCoordinatesMutex.Unlock() - return chk, err - } + if err := ctx.Err(); err != nil { + return nil, err + } + this.applier.CurrentCoordinatesMutex.Lock() + if coords.SmallerThanOrEquals(this.applier.CurrentCoordinates) { + id, err := this.applier.WriteCheckpoint(chk) + chk.Id = id this.applier.CurrentCoordinatesMutex.Unlock() - time.Sleep(500 * time.Millisecond) + return chk, err } + this.applier.CurrentCoordinatesMutex.Unlock() + time.Sleep(500 * time.Millisecond) } } @@ -1649,6 +1737,9 @@ func (this *Migrator) executeWriteFuncs() error { return nil } for { + if err := this.checkAbort(); err != nil { + return err + } if atomic.LoadInt64(&this.finishedMigrating) > 0 { return nil } diff --git a/go/logic/migrator_test.go b/go/logic/migrator_test.go index c4fd49233..42b6fed37 100644 --- a/go/logic/migrator_test.go +++ b/go/logic/migrator_test.go @@ -931,3 +931,284 @@ func (suite *MigratorTestSuite) TestRevert() { func TestMigrator(t *testing.T) { suite.Run(t, new(MigratorTestSuite)) } + +func TestPanicAbort_PropagatesError(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Send an error to PanicAbort + testErr := errors.New("test abort error") + go func() { + migrationContext.PanicAbort <- testErr + }() + + // Wait a bit for error to be processed + time.Sleep(100 * time.Millisecond) + + // Verify error was stored + got := migrationContext.GetAbortError() + if got != testErr { //nolint:errorlint // Testing pointer equality for sentinel error + t.Errorf("Expected error %v, got %v", testErr, got) + } + + // Verify context was cancelled + ctx := migrationContext.GetContext() + select { + case <-ctx.Done(): + // Success - context was cancelled + default: + t.Error("Expected context to be cancelled") + } +} + +func TestPanicAbort_FirstErrorWins(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Send first error + err1 := errors.New("first error") + go func() { + migrationContext.PanicAbort <- err1 + }() + + // Wait for first error to be processed + time.Sleep(50 * time.Millisecond) + + // Try to send second error (should be ignored) + err2 := errors.New("second error") + migrationContext.SetAbortError(err2) + + // Verify only first error is stored + got := migrationContext.GetAbortError() + if got != err1 { //nolint:errorlint // Testing pointer equality for sentinel error + t.Errorf("Expected first error %v, got %v", err1, got) + } +} + +func TestAbort_AfterRowCopy(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Give listenOnPanicAbort time to start + time.Sleep(20 * time.Millisecond) + + // Simulate row copy error by sending to rowCopyComplete in a goroutine + // (unbuffered channel, so send must be async) + testErr := errors.New("row copy failed") + go func() { + migrator.rowCopyComplete <- testErr + }() + + // Consume the error (simulating what Migrate() does) + // This is a blocking call that waits for the error + migrator.consumeRowCopyComplete() + + // Wait for the error to be processed by listenOnPanicAbort + time.Sleep(50 * time.Millisecond) + + // Check that error was stored + if got := migrationContext.GetAbortError(); got == nil { + t.Fatal("Expected abort error to be stored after row copy error") + } else if got.Error() != "row copy failed" { + t.Errorf("Expected 'row copy failed', got %v", got) + } + + // Verify context was cancelled + ctx := migrationContext.GetContext() + select { + case <-ctx.Done(): + // Success + case <-time.After(1 * time.Second): + t.Error("Expected context to be cancelled after row copy error") + } +} + +func TestAbort_DuringInspection(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Simulate error during inspection phase + testErr := errors.New("inspection failed") + go func() { + time.Sleep(10 * time.Millisecond) + select { + case migrationContext.PanicAbort <- testErr: + case <-migrationContext.GetContext().Done(): + } + }() + + // Wait for abort to be processed + time.Sleep(50 * time.Millisecond) + + // Call checkAbort (simulating what Migrate() does after initiateInspector) + err := migrator.checkAbort() + if err == nil { + t.Fatal("Expected checkAbort to return error after abort during inspection") + } + + if err.Error() != "inspection failed" { + t.Errorf("Expected 'inspection failed', got %v", err) + } +} + +func TestAbort_DuringStreaming(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Simulate error from streaming goroutine + testErr := errors.New("streaming error") + go func() { + time.Sleep(10 * time.Millisecond) + // Use select pattern like actual code does + select { + case migrationContext.PanicAbort <- testErr: + case <-migrationContext.GetContext().Done(): + } + }() + + // Wait for abort to be processed + time.Sleep(50 * time.Millisecond) + + // Verify error stored and context cancelled + if got := migrationContext.GetAbortError(); got == nil { + t.Fatal("Expected abort error to be stored") + } else if got.Error() != "streaming error" { + t.Errorf("Expected 'streaming error', got %v", got) + } + + // Verify checkAbort catches it + err := migrator.checkAbort() + if err == nil { + t.Fatal("Expected checkAbort to return error after streaming abort") + } +} + +func TestRetryExhaustion_TriggersAbort(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.SetDefaultNumRetries(2) // Only 2 retries + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Operation that always fails + callCount := 0 + operation := func() error { + callCount++ + return errors.New("persistent failure") + } + + // Call retryOperation (with notFatalHint=false so it sends to PanicAbort) + err := migrator.retryOperation(operation) + + // Should have called operation MaxRetries times + if callCount != 2 { + t.Errorf("Expected 2 retry attempts, got %d", callCount) + } + + // Should return the error + if err == nil { + t.Fatal("Expected retryOperation to return error") + } + + // Wait for abort to be processed + time.Sleep(100 * time.Millisecond) + + // Verify error was sent to PanicAbort and stored + if got := migrationContext.GetAbortError(); got == nil { + t.Error("Expected abort error to be stored after retry exhaustion") + } + + // Verify context was cancelled + ctx := migrationContext.GetContext() + select { + case <-ctx.Done(): + // Success + default: + t.Error("Expected context to be cancelled after retry exhaustion") + } +} + +func TestRevert_AbortsOnError(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrationContext.Revert = true + migrationContext.OldTableName = "_test_del" + migrationContext.OriginalTableName = "test" + migrationContext.DatabaseName = "testdb" + migrator := NewMigrator(migrationContext, "1.0.0") + + // Start listenOnPanicAbort + go migrator.listenOnPanicAbort() + + // Simulate error during revert + testErr := errors.New("revert failed") + go func() { + time.Sleep(10 * time.Millisecond) + select { + case migrationContext.PanicAbort <- testErr: + case <-migrationContext.GetContext().Done(): + } + }() + + // Wait for abort to be processed + time.Sleep(50 * time.Millisecond) + + // Verify checkAbort catches it + err := migrator.checkAbort() + if err == nil { + t.Fatal("Expected checkAbort to return error during revert") + } + + if err.Error() != "revert failed" { + t.Errorf("Expected 'revert failed', got %v", err) + } + + // Verify context was cancelled + ctx := migrationContext.GetContext() + select { + case <-ctx.Done(): + // Success + default: + t.Error("Expected context to be cancelled during revert abort") + } +} + +func TestCheckAbort_ReturnsNilWhenNoError(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // No error has occurred + err := migrator.checkAbort() + if err != nil { + t.Errorf("Expected no error, got %v", err) + } +} + +func TestCheckAbort_DetectsContextCancellation(t *testing.T) { + migrationContext := base.NewMigrationContext() + migrator := NewMigrator(migrationContext, "1.0.0") + + // Cancel context directly (without going through PanicAbort) + migrationContext.CancelContext() + + // checkAbort should detect the cancellation + err := migrator.checkAbort() + if err == nil { + t.Fatal("Expected checkAbort to return error when context is cancelled") + } +} diff --git a/go/logic/server.go b/go/logic/server.go index 45e5b2bd4..74097acb7 100644 --- a/go/logic/server.go +++ b/go/logic/server.go @@ -450,7 +450,8 @@ help # This message return NoPrintStatusRule, err } err := fmt.Errorf("User commanded 'panic'. The migration will be aborted without cleanup. Please drop the gh-ost tables before trying again.") - this.migrationContext.PanicAbort <- err + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, err) return NoPrintStatusRule, err } default: diff --git a/go/logic/streamer.go b/go/logic/streamer.go index 63afc3f3d..1c2635138 100644 --- a/go/logic/streamer.go +++ b/go/logic/streamer.go @@ -186,7 +186,12 @@ func (this *EventsStreamer) StreamEvents(canStopStreaming func() bool) error { // The next should block and execute forever, unless there's a serious error. var successiveFailures int var reconnectCoords mysql.BinlogCoordinates + ctx := this.migrationContext.GetContext() for { + // Check for context cancellation each iteration + if err := ctx.Err(); err != nil { + return err + } if canStopStreaming() { return nil } diff --git a/go/logic/throttler.go b/go/logic/throttler.go index 7a6534baf..1ca40f957 100644 --- a/go/logic/throttler.go +++ b/go/logic/throttler.go @@ -362,7 +362,9 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { // Regardless of throttle, we take opportunity to check for panic-abort if this.migrationContext.PanicFlagFile != "" { if base.FileExists(this.migrationContext.PanicFlagFile) { - this.migrationContext.PanicAbort <- fmt.Errorf("Found panic-file %s. Aborting without cleanup", this.migrationContext.PanicFlagFile) + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("Found panic-file %s. Aborting without cleanup", this.migrationContext.PanicFlagFile)) + return nil } } @@ -385,7 +387,9 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { } if criticalLoadMet && this.migrationContext.CriticalLoadIntervalMilliseconds == 0 { - this.migrationContext.PanicAbort <- fmt.Errorf("critical-load met: %s=%d, >=%d", variableName, value, threshold) + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("critical-load met: %s=%d, >=%d", variableName, value, threshold)) + return nil } if criticalLoadMet && this.migrationContext.CriticalLoadIntervalMilliseconds > 0 { this.migrationContext.Log.Errorf("critical-load met once: %s=%d, >=%d. Will check again in %d millis", variableName, value, threshold, this.migrationContext.CriticalLoadIntervalMilliseconds) @@ -393,7 +397,8 @@ func (this *Throttler) collectGeneralThrottleMetrics() error { timer := time.NewTimer(time.Millisecond * time.Duration(this.migrationContext.CriticalLoadIntervalMilliseconds)) <-timer.C if criticalLoadMetAgain, variableName, value, threshold, _ := this.criticalLoadIsMet(); criticalLoadMetAgain { - this.migrationContext.PanicAbort <- fmt.Errorf("critical-load met again after %d millis: %s=%d, >=%d", this.migrationContext.CriticalLoadIntervalMilliseconds, variableName, value, threshold) + // Use helper to prevent deadlock if listenOnPanicAbort already exited + _ = base.SendWithContext(this.migrationContext.GetContext(), this.migrationContext.PanicAbort, fmt.Errorf("critical-load met again after %d millis: %s=%d, >=%d", this.migrationContext.CriticalLoadIntervalMilliseconds, variableName, value, threshold)) } }() } @@ -481,7 +486,16 @@ func (this *Throttler) initiateThrottlerChecks() { ticker := time.NewTicker(100 * time.Millisecond) defer ticker.Stop() - for range ticker.C { + for { + // Check for context cancellation each iteration + ctx := this.migrationContext.GetContext() + select { + case <-ctx.Done(): + return + case <-ticker.C: + // Process throttle check + } + if atomic.LoadInt64(&this.finishedMigrating) > 0 { return }