diff --git a/reexec/example_programmatic_test.go b/reexec/example_programmatic_test.go index 54bf3c4..0a2ef05 100644 --- a/reexec/example_programmatic_test.go +++ b/reexec/example_programmatic_test.go @@ -10,15 +10,20 @@ import ( ) func init() { - reexec.Register("example-child", func() { + reexec.RegisterContext("example-child", func(ctx context.Context) error { fmt.Println("Hello from example-child entrypoint") + return nil }) } // Example_programmatic demonstrates using reexec to programmatically // re-execute the current binary. func Example_programmatic() { - if reexec.Init() { + if ok, err := reexec.DispatchContext(context.Background()); ok { + if err != nil { + fmt.Fprintln(os.Stderr, "entrypoint failed:", err) + os.Exit(1) + } // Matched a reexec entrypoint; stop normal main execution. return } diff --git a/reexec/reexec.go b/reexec/reexec.go index efc7d06..53855f0 100644 --- a/reexec/reexec.go +++ b/reexec/reexec.go @@ -80,15 +80,32 @@ import ( "github.com/moby/sys/reexec/internal/reexecoverride" ) -var entrypoints = make(map[string]func()) +var entrypoints = make(map[string]func(context.Context) error) // Register associates name with an entrypoint function to be executed when // the current binary is invoked with argv[0] equal to name. // // Register is not safe for concurrent use; entrypoints must be registered -// during program initialization, before calling [Init]. +// during program initialization, before calling [Dispatch] or [Init]. // It panics if name contains a path component or is already registered. func Register(name string, entrypoint func()) { + registerEntrypoint(name, func(context.Context) error { + entrypoint() + return nil + }) +} + +// RegisterContext is like [Register] but the entrypoint receives a context and +// reports success or failure by returning an error. +// +// RegisterContext is not safe for concurrent use; entrypoints must be +// registered during program initialization, before calling [Dispatch] or [Init]. +// It panics if name contains a path component or is already registered. +func RegisterContext(name string, entrypoint func(context.Context) error) { + registerEntrypoint(name, entrypoint) +} + +func registerEntrypoint(name string, entrypoint func(context.Context) error) { if filepath.Base(name) != name { panic(fmt.Sprintf("reexec func does not expect a path component: %q", name)) } @@ -99,18 +116,45 @@ func Register(name string, entrypoint func()) { entrypoints[name] = entrypoint } +// Dispatch checks whether the current process was invoked under a registered +// name (based on filepath.Base(os.Args[0])). +// +// If a matching entrypoint is found, it is executed and Dispatch reports a +// match with ok. If ok is false, err is always nil. If ok is true, err is the +// entrypoint's return value. +// +// In the ok=true case, the caller should stop normal main execution (typically +// by returning from main or calling os.Exit). +func Dispatch() (bool, error) { + return dispatch(context.Background()) +} + +// DispatchContext is like [Dispatch] but includes a context. +func DispatchContext(ctx context.Context) (bool, error) { + return dispatch(ctx) +} + // Init checks whether the current process was invoked under a registered name // (based on filepath.Base(os.Args[0])). // // If a matching entrypoint is found, it is executed and Init returns true. In // that case, the caller should stop normal main execution. If no match is found, // Init returns false and normal execution should continue. +// +// Init is provided for compatibility with older callers. New code should +// prefer [Dispatch] or [DispatchContext], which also return the entrypoint's +// error and allow the caller to decide how to report failures. func Init() bool { - if entrypoint, ok := entrypoints[filepath.Base(os.Args[0])]; ok { - entrypoint() - return true + ok, _ := dispatch(context.Background()) + return ok +} + +func dispatch(ctx context.Context) (bool, error) { + entrypoint, ok := entrypoints[filepath.Base(os.Args[0])] + if !ok { + return false, nil } - return false + return true, entrypoint(ctx) } // Command returns an [*exec.Cmd] configured to re-execute the current binary, diff --git a/reexec/reexec_test.go b/reexec/reexec_test.go index 75b9172..14c3dce 100644 --- a/reexec/reexec_test.go +++ b/reexec/reexec_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "os" + "os/exec" "path/filepath" "reflect" "runtime" @@ -22,35 +23,66 @@ const ( testReExec3 = "test-reexec3" ) +var errTestEntrypointFailure = errors.New("test-reexec: simulated failure") + func init() { - reexec.Register(testReExec, func() { - panic("Return Error") + reexec.RegisterContext(testReExec, func(context.Context) error { + return errTestEntrypointFailure }) - reexec.Register(testReExec2, func() { + reexec.RegisterContext(testReExec2, func(context.Context) error { var args string if len(os.Args) > 1 { args = fmt.Sprintf("(args: %#v)", os.Args[1:]) } fmt.Println("Hello", testReExec2, args) - os.Exit(0) + return nil }) reexec.Register(testReExec3, func() { fmt.Println("Hello " + testReExec3) time.Sleep(1 * time.Second) os.Exit(0) }) - if reexec.Init() { - // Make sure we exit in case re-exec didn't os.Exit on its own. + + if ok, err := reexec.DispatchContext(context.Background()); ok { + if err != nil { + _, _ = fmt.Fprintln(os.Stderr, "entrypoint failed:", err) + os.Exit(1) + } + // Make sure we exit in case a re-exec entrypoint didn't os.Exit on its own. os.Exit(0) } } func TestRegister(t *testing.T) { + type registerFunc struct { + doc string + fn func(name string) + } + + registerFuncs := []registerFunc{ + { + doc: "Register", + fn: func(name string) { + reexec.Register(name, func() {}) + }, + }, + { + doc: "RegisterContext", + fn: func(name string) { + reexec.RegisterContext(name, func(context.Context) error { return nil }) + }, + }, + } + tests := []struct { doc string name string expectedErr string }{ + { + doc: "new entrypoint", + name: "test-register-new", + }, { doc: "duplicate name", name: testReExec, @@ -63,20 +95,32 @@ func TestRegister(t *testing.T) { }, } - for _, tc := range tests { - t.Run(tc.doc, func(t *testing.T) { - defer func() { - r := recover() - if r == nil { - t.Errorf("Register() did not panic") - return + for _, rf := range registerFuncs { + for _, tc := range tests { + t.Run(rf.doc+"/"+tc.doc, func(t *testing.T) { + name := tc.name + if tc.doc == "new entrypoint" { + // Use the test name as a unique suffix; replace path separators. + name = tc.name + "-" + strings.ReplaceAll(t.Name(), "/", "-") } - if r != tc.expectedErr { - t.Errorf("got %q, want %q", r, tc.expectedErr) - } - }() - reexec.Register(tc.name, func() {}) - }) + + defer func() { + r := recover() + switch { + case tc.expectedErr == "": + if r != nil { + t.Fatalf("unexpected panic: %v", r) + } + case r == nil: + t.Fatalf("expected panic %q, got nil", tc.expectedErr) + case r != tc.expectedErr: + t.Fatalf("got %q, want %q", r, tc.expectedErr) + } + }() + + rf.fn(name) + }) + } } } @@ -85,6 +129,7 @@ func TestCommand(t *testing.T) { doc string cmdAndArgs []string expOut string + expExit int }{ { doc: "basename", @@ -101,6 +146,12 @@ func TestCommand(t *testing.T) { cmdAndArgs: []string{testReExec2, "--some-flag", "some-value", "arg1", "arg2"}, expOut: `Hello test-reexec2 (args: []string{"--some-flag", "some-value", "arg1", "arg2"})`, }, + { + doc: "entrypoint returns error", + cmdAndArgs: []string{testReExec}, + expOut: "entrypoint failed: " + errTestEntrypointFailure.Error(), + expExit: 1, + }, } for _, tc := range tests { t.Run(tc.doc, func(t *testing.T) { @@ -115,8 +166,19 @@ func TestCommand(t *testing.T) { defer func() { _ = w.Close() }() out, err := cmd.CombinedOutput() - if err != nil { - t.Errorf("Error on re-exec cmd: %v, out: %v", err, string(out)) + switch { + case tc.expExit != 0: + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + if code := exitErr.ExitCode(); code != tc.expExit { + t.Errorf("got exit code %d, want %d; out: %v", code, tc.expExit, string(out)) + } + default: + if err != nil { + t.Errorf("unexpected error: %v (out: %v)", err, string(out)) + } } actual := strings.TrimSpace(string(out)) @@ -127,6 +189,110 @@ func TestCommand(t *testing.T) { } } +func TestDispatch(t *testing.T) { + type ctxKey struct{} + const want = "sentinel" + + reexec.RegisterContext("test-dispatch", func(ctx context.Context) error { + got, _ := ctx.Value(ctxKey{}).(string) + if got != want { + return fmt.Errorf("unexpected context value: got %q, want %q", got, want) + } + return nil + }) + reexec.RegisterContext("test-dispatch-canceled", func(ctx context.Context) error { + return ctx.Err() + }) + reexec.Register("test-dispatch-no-ctx", func() {}) + + tests := []struct { + name string + ctx func(t *testing.T) context.Context + dispatch func(ctx context.Context) (bool, error) + check func(t *testing.T, ok bool, err error) + }{ + { + name: "not-registered", + ctx: func(t *testing.T) context.Context { + return context.Background() + }, + dispatch: reexec.DispatchContext, + check: func(t *testing.T, ok bool, err error) { + t.Helper() + if ok { + t.Errorf("expected ok=false, got true") + } + if err != nil { + t.Errorf("expected err=nil, got %v", err) + } + }, + }, + { + name: "test-dispatch", + ctx: func(t *testing.T) context.Context { + return context.WithValue(t.Context(), ctxKey{}, want) + }, + dispatch: reexec.DispatchContext, + check: func(t *testing.T, ok bool, err error) { + t.Helper() + if !ok { + t.Errorf("expected ok=true, got false") + } + if err != nil { + t.Errorf("expected err=nil, got %v", err) + } + }, + }, + { + name: "test-dispatch-canceled", + ctx: func(t *testing.T) context.Context { + ctx, cancel := context.WithCancel(t.Context()) + cancel() + return ctx + }, + dispatch: reexec.DispatchContext, + check: func(t *testing.T, ok bool, err error) { + t.Helper() + if !ok { + t.Errorf("expected ok=true, got false") + } + if !errors.Is(err, context.Canceled) { + t.Errorf("expected context.Canceled, got %v", err) + } + }, + }, + { + name: "test-dispatch-no-ctx", + ctx: func(t *testing.T) context.Context { + return context.Background() + }, + dispatch: func(context.Context) (bool, error) { + return reexec.Dispatch() + }, + check: func(t *testing.T, ok bool, err error) { + t.Helper() + if !ok { + t.Errorf("expected ok=true, got false") + } + if err != nil { + t.Errorf("expected err=nil, got %v", err) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + argv0 := os.Args[0] + t.Cleanup(func() { os.Args[0] = argv0 }) + + os.Args[0] = tc.name + ok, err := tc.dispatch(tc.ctx(t)) + tc.check(t, ok, err) + }) + } +} + func TestCommandContext(t *testing.T) { testError := errors.New("test-error: the command was canceled") @@ -136,6 +302,7 @@ func TestCommandContext(t *testing.T) { cancel bool expOut string expError error + expExit int }{ { doc: "basename", @@ -164,6 +331,12 @@ func TestCommandContext(t *testing.T) { expOut: "Hello test-reexec3", expError: testError, }, + { + doc: "entrypoint returns error", + cmdAndArgs: []string{testReExec}, + expOut: "entrypoint failed: " + errTestEntrypointFailure.Error(), + expExit: 1, + }, } for _, tc := range tests { @@ -188,8 +361,23 @@ func TestCommandContext(t *testing.T) { cancel() } out, err := cmd.CombinedOutput() - if !errors.Is(err, tc.expError) { - t.Errorf("expected %v, got: %v", tc.expError, err) + switch { + case tc.expError != nil: + if !errors.Is(err, tc.expError) { + t.Errorf("expected %v, got: %v (out: %v)", tc.expError, err, string(out)) + } + case tc.expExit != 0: + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected *exec.ExitError, got %T: %v", err, err) + } + if code := exitErr.ExitCode(); code != tc.expExit { + t.Errorf("got exit code %d, want %d; out: %v", code, tc.expExit, string(out)) + } + default: + if err != nil { + t.Errorf("unexpected error: %v (out: %v)", err, string(out)) + } } actual := strings.TrimSpace(string(out))