From 699f15796b3220b891a7fcaaef736a8c26dbd992 Mon Sep 17 00:00:00 2001 From: Jason Wilder Date: Fri, 12 Jun 2026 09:43:11 +0000 Subject: [PATCH 1/3] Refactor main.go to extract flag setup and parsed runtime state into sma --- main.go | 146 ++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 111 insertions(+), 35 deletions(-) diff --git a/main.go b/main.go index 2c0dbd6..4936e33 100644 --- a/main.go +++ b/main.go @@ -30,6 +30,23 @@ type HttpHeader struct { value string } +type Config struct { + version bool + poll bool + templates []string + stdoutTails []string + stderrTails []string + headersFlag []string + delims []string + headers []HttpHeader + urls []url.URL + waits []string + waitTimeout time.Duration + waitRetryInterval time.Duration + noOverwrite bool + args []string +} + func (c *Context) Env() map[string]string { env := make(map[string]string, len(os.Environ())) for _, i := range os.Environ() { @@ -217,11 +234,9 @@ Arguments: println(`For more information, see https://github.com/jwilder/dockerize`) } -func main() { - +func registerFlags() { flag.BoolVar(&version, "version", false, "show version") flag.BoolVar(&poll, "poll", false, "enable polling") - flag.Var(&templatesFlag, "template", "Template (/template:/dest). Can be passed multiple times. Does also support directories") flag.BoolVar(&noOverwriteFlag, "no-overwrite", false, "Do not overwrite destination file if it already exists.") flag.Var(&stdoutTailFlag, "stdout", "Tails a file to stdout. Can be passed multiple times") @@ -231,55 +246,117 @@ func main() { flag.Var(&waitFlag, "wait", "Host (tcp/tcp4/tcp6/http/https/unix/file) to wait for before this container starts. Can be passed multiple times. e.g. tcp://db:5432") flag.DurationVar(&waitTimeoutFlag, "timeout", 10*time.Second, "Host wait timeout") flag.DurationVar(&waitRetryInterval, "wait-retry-interval", defaultWaitRetryInterval, "Duration to wait before retrying") - flag.Usage = usage - flag.Parse() - - if version { - fmt.Println(buildVersion) - return - } +} - if flag.NArg() == 0 && flag.NFlag() == 0 { - usage() - os.Exit(1) +func parseDelimiters(value string) ([]string, error) { + if value == "" { + return nil, nil } - if delimsFlag != "" { - delims = strings.Split(delimsFlag, ":") - if len(delims) != 2 { - log.Fatalf("bad delimiters argument: %s. expected \"left:right\"", delimsFlag) - } + delims := strings.Split(value, ":") + if len(delims) != 2 { + return nil, fmt.Errorf("bad delimiters argument: %s. expected \"left:right\"", value) } + return delims, nil +} - for _, host := range waitFlag { +func parseWaitURLs(hosts hostFlagsVar) ([]url.URL, error) { + urls := make([]url.URL, 0, len(hosts)) + for _, host := range hosts { u, err := url.Parse(host) if err != nil { - log.Fatalf("bad hostname provided: %s. %s", host, err.Error()) + return nil, fmt.Errorf("bad hostname provided: %s. %s", host, err.Error()) } urls = append(urls, *u) } + return urls, nil +} - for _, h := range headersFlag { - //validate headers need -wait options - if len(waitFlag) == 0 { - log.Fatalf("-wait-http-header \"%s\" provided with no -wait option", h) +func parseHeaders(values []string, waits hostFlagsVar) ([]HttpHeader, error) { + headers := make([]HttpHeader, 0, len(values)) + for _, h := range values { + if len(waits) == 0 { + return nil, fmt.Errorf("-wait-http-header \"%s\" provided with no -wait option", h) } const errMsg = "bad HTTP Headers argument: %s. expected \"headerName: headerValue\"" if strings.Contains(h, ":") { parts := strings.SplitN(h, ":", 2) if len(parts) != 2 { - log.Fatalf(errMsg, h) + return nil, fmt.Errorf(errMsg, h) } headers = append(headers, HttpHeader{name: strings.TrimSpace(parts[0]), value: strings.TrimSpace(parts[1])}) - } else { - log.Fatalf(errMsg, h) + continue } + return nil, fmt.Errorf(errMsg, h) + } + return headers, nil +} +func parseConfigFromFlags() (Config, error) { + parsedDelims, err := parseDelimiters(delimsFlag) + if err != nil { + return Config{}, err } - for _, t := range templatesFlag { + parsedURLs, err := parseWaitURLs(waitFlag) + if err != nil { + return Config{}, err + } + + parsedHeaders, err := parseHeaders(headersFlag, waitFlag) + if err != nil { + return Config{}, err + } + + return Config{ + version: version, + poll: poll, + templates: templatesFlag, + stdoutTails: stdoutTailFlag, + stderrTails: stderrTailFlag, + headersFlag: headersFlag, + delims: parsedDelims, + headers: parsedHeaders, + urls: parsedURLs, + waits: waitFlag, + waitTimeout: waitTimeoutFlag, + waitRetryInterval: waitRetryInterval, + noOverwrite: noOverwriteFlag, + args: flag.Args(), + }, nil +} + +func main() { + registerFlags() + flag.Parse() + + config, err := parseConfigFromFlags() + if err != nil { + log.Fatal(err) + } + + if config.version { + fmt.Println(buildVersion) + return + } + + if flag.NArg() == 0 && flag.NFlag() == 0 { + usage() + os.Exit(1) + } + + delims = config.delims + headers = config.headers + urls = config.urls + waitFlag = config.waits + waitTimeoutFlag = config.waitTimeout + waitRetryInterval = config.waitRetryInterval + noOverwriteFlag = config.noOverwrite + poll = config.poll + + for _, t := range config.templates { template, dest := t, "" if strings.Contains(t, ":") { parts := strings.SplitN(t, ":", 2) @@ -302,22 +379,21 @@ func main() { waitForDependencies() - // Setup context ctx, cancel = context.WithCancel(context.Background()) - if flag.NArg() > 0 { + if len(config.args) > 0 { wg.Add(1) - go runCmd(ctx, cancel, flag.Arg(0), flag.Args()[1:]...) + go runCmd(ctx, cancel, config.args[0], config.args[1:]...) } - for _, out := range stdoutTailFlag { + for _, out := range config.stdoutTails { wg.Add(1) - go tailFile(ctx, out, poll, os.Stdout) + go tailFile(ctx, out, config.poll, os.Stdout) } - for _, err := range stderrTailFlag { + for _, err := range config.stderrTails { wg.Add(1) - go tailFile(ctx, err, poll, os.Stderr) + go tailFile(ctx, err, config.poll, os.Stderr) } wg.Wait() From 0379e08dd2106a2ebf80e31fbd105ed5d7ac44e7 Mon Sep 17 00:00:00 2001 From: Jason Wilder Date: Fri, 12 Jun 2026 09:43:32 +0000 Subject: [PATCH 2/3] Continue the main.go complexity reduction by extracting the remaining ex --- main.go | 87 +++++++++++++++++++++++++++++++-------------------------- 1 file changed, 48 insertions(+), 39 deletions(-) diff --git a/main.go b/main.go index 4936e33..f4badbd 100644 --- a/main.go +++ b/main.go @@ -328,35 +328,8 @@ func parseConfigFromFlags() (Config, error) { }, nil } -func main() { - registerFlags() - flag.Parse() - - config, err := parseConfigFromFlags() - if err != nil { - log.Fatal(err) - } - - if config.version { - fmt.Println(buildVersion) - return - } - - if flag.NArg() == 0 && flag.NFlag() == 0 { - usage() - os.Exit(1) - } - - delims = config.delims - headers = config.headers - urls = config.urls - waitFlag = config.waits - waitTimeoutFlag = config.waitTimeout - waitRetryInterval = config.waitRetryInterval - noOverwriteFlag = config.noOverwrite - poll = config.poll - - for _, t := range config.templates { +func processTemplates(templates []string) { + for _, t := range templates { template, dest := t, "" if strings.Contains(t, ":") { parts := strings.SplitN(t, ":", 2) @@ -376,25 +349,61 @@ func main() { generateFile(template, dest) } } +} - waitForDependencies() - - ctx, cancel = context.WithCancel(context.Background()) - - if len(config.args) > 0 { +func startCommand(ctx context.Context, cancel context.CancelFunc) { + if flag.NArg() > 0 { wg.Add(1) - go runCmd(ctx, cancel, config.args[0], config.args[1:]...) + go runCmd(ctx, cancel, flag.Arg(0), flag.Args()[1:]...) } +} - for _, out := range config.stdoutTails { +func startTailers(ctx context.Context) { + for _, out := range stdoutTailFlag { wg.Add(1) - go tailFile(ctx, out, config.poll, os.Stdout) + go tailFile(ctx, out, poll, os.Stdout) } - for _, err := range config.stderrTails { + for _, err := range stderrTailFlag { wg.Add(1) - go tailFile(ctx, err, config.poll, os.Stderr) + go tailFile(ctx, err, poll, os.Stderr) } +} + +func main() { + registerFlags() + flag.Parse() + + config, err := parseConfigFromFlags() + if err != nil { + log.Fatal(err) + } + + delims = config.delims + headers = config.headers + urls = config.urls + waitFlag = config.waits + waitTimeoutFlag = config.waitTimeout + waitRetryInterval = config.waitRetryInterval + noOverwriteFlag = config.noOverwrite + poll = config.poll + + if config.version { + fmt.Println(buildVersion) + return + } + + if flag.NArg() == 0 && flag.NFlag() == 0 { + usage() + os.Exit(1) + } + + processTemplates(config.templates) + waitForDependencies() + + ctx, cancel = context.WithCancel(context.Background()) + startCommand(ctx, cancel) + startTailers(ctx) wg.Wait() } From 382c14303068a1d1b789296a798a2feacfb1747b Mon Sep 17 00:00:00 2001 From: Jason Wilder Date: Fri, 12 Jun 2026 09:44:27 +0000 Subject: [PATCH 3/3] Expand main_test.go with targeted unit tests for the new main.go helper --- main_test.go | 120 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/main_test.go b/main_test.go index ba83291..e1a6982 100644 --- a/main_test.go +++ b/main_test.go @@ -181,6 +181,126 @@ func TestLoop(t *testing.T) { assert.Error(t, err) } +func TestParseDelimiters(t *testing.T) { + tests := []struct { + name string + input string + want []string + wantErr string + }{ + {name: "empty", input: "", want: nil}, + {name: "valid", input: "{{:}}", want: []string{"{{", "}}"}}, + {name: "valid with spaces", input: "<% : %>", want: []string{"<% ", " %>"}}, + {name: "missing separator", input: "{{}}", wantErr: "bad delimiters argument: {{}}. expected \"left:right\""}, + {name: "too many separators", input: "a:b:c", wantErr: "bad delimiters argument: a:b:c. expected \"left:right\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseDelimiters(tt.input) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseWaitURLs(t *testing.T) { + tests := []struct { + name string + input hostFlagsVar + want []url.URL + wantErr string + }{ + {name: "empty", input: nil, want: []url.URL{}}, + { + name: "multiple urls", + input: hostFlagsVar{"tcp://db:5432", "http://web:8080/health", "file:///tmp/ready"}, + want: []url.URL{ + {Scheme: "tcp", Host: "db:5432"}, + {Scheme: "http", Host: "web:8080", Path: "/health"}, + {Scheme: "file", Path: "/tmp/ready"}, + }, + }, + {name: "invalid escape", input: hostFlagsVar{"http://example.com/%zz"}, wantErr: "bad hostname provided: http://example.com/%zz. parse \"http://example.com/%zz\": invalid URL escape \"%zz\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseWaitURLs(tt.input) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestParseHeaders(t *testing.T) { + tests := []struct { + name string + values []string + waits hostFlagsVar + want []HttpHeader + wantErr string + }{ + {name: "empty", values: nil, waits: hostFlagsVar{"http://example.com"}, want: []HttpHeader{}}, + { + name: "valid headers", + values: []string{"Accept-Encoding: gzip", "X-Test: value:with:colon", "Authorization:Bearer token"}, + waits: hostFlagsVar{"http://example.com"}, + want: []HttpHeader{ + {name: "Accept-Encoding", value: "gzip"}, + {name: "X-Test", value: "value:with:colon"}, + {name: "Authorization", value: "Bearer token"}, + }, + }, + {name: "header without wait", values: []string{"Accept: gzip"}, wantErr: "-wait-http-header \"Accept: gzip\" provided with no -wait option"}, + {name: "missing colon", values: []string{"Accept gzip"}, waits: hostFlagsVar{"http://example.com"}, wantErr: "bad HTTP Headers argument: Accept gzip. expected \"headerName: headerValue\""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := parseHeaders(tt.values, tt.waits) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestProcessTemplatesFileAndDirectoryArguments(t *testing.T) { + fileTemplateDir := t.TempDir() + filePath := filepath.Join(fileTemplateDir, "source.tmpl") + fileDest := filepath.Join(fileTemplateDir, "out.txt") + assert.NoError(t, os.WriteFile(filePath, []byte("hello"), 0o644)) + + dirPath := t.TempDir() + assert.NoError(t, os.WriteFile(filepath.Join(dirPath, "child.tmpl"), []byte("hello"), 0o644)) + dirDest := t.TempDir() + + processTemplates([]string{filePath + ":" + fileDest, dirPath + ":" + dirDest}) + + _, err := os.Stat(fileDest) + assert.NoError(t, err) + _, err = os.Stat(filepath.Join(dirDest, "child.tmpl")) + assert.NoError(t, err) +} + func TestWaitForSocketUsesPassedTimeoutForDial(t *testing.T) { oldDialTimeout := dialTimeout oldTimeout := waitTimeoutFlag