Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 118 additions & 33 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,23 @@ type HttpHeader struct {
value string
}

type Config struct {
version bool
poll bool
templates []string
stdoutTails []string
stderrTails []string
headersFlag []string
delims []string
headers []HttpHeader
urls []url.URL
waits []string
waitTimeout time.Duration
waitRetryInterval time.Duration
noOverwrite bool
args []string
}

func (c *Context) Env() map[string]string {
env := make(map[string]string, len(os.Environ()))
for _, i := range os.Environ() {
Expand Down Expand Up @@ -217,11 +234,9 @@ Arguments:
println(`For more information, see https://github.com/jwilder/dockerize`)
}

func main() {

func registerFlags() {
flag.BoolVar(&version, "version", false, "show version")
flag.BoolVar(&poll, "poll", false, "enable polling")

flag.Var(&templatesFlag, "template", "Template (/template:/dest). Can be passed multiple times. Does also support directories")
flag.BoolVar(&noOverwriteFlag, "no-overwrite", false, "Do not overwrite destination file if it already exists.")
flag.Var(&stdoutTailFlag, "stdout", "Tails a file to stdout. Can be passed multiple times")
Expand All @@ -231,55 +246,90 @@ func main() {
flag.Var(&waitFlag, "wait", "Host (tcp/tcp4/tcp6/http/https/unix/file) to wait for before this container starts. Can be passed multiple times. e.g. tcp://db:5432")
flag.DurationVar(&waitTimeoutFlag, "timeout", 10*time.Second, "Host wait timeout")
flag.DurationVar(&waitRetryInterval, "wait-retry-interval", defaultWaitRetryInterval, "Duration to wait before retrying")

flag.Usage = usage
flag.Parse()
}

if version {
fmt.Println(buildVersion)
return
func parseDelimiters(value string) ([]string, error) {
if value == "" {
return nil, nil
}

if flag.NArg() == 0 && flag.NFlag() == 0 {
usage()
os.Exit(1)
}

if delimsFlag != "" {
delims = strings.Split(delimsFlag, ":")
if len(delims) != 2 {
log.Fatalf("bad delimiters argument: %s. expected \"left:right\"", delimsFlag)
}
delims := strings.Split(value, ":")
if len(delims) != 2 {
return nil, fmt.Errorf("bad delimiters argument: %s. expected \"left:right\"", value)
}
return delims, nil
}

for _, host := range waitFlag {
func parseWaitURLs(hosts hostFlagsVar) ([]url.URL, error) {
urls := make([]url.URL, 0, len(hosts))
for _, host := range hosts {
u, err := url.Parse(host)
if err != nil {
log.Fatalf("bad hostname provided: %s. %s", host, err.Error())
return nil, fmt.Errorf("bad hostname provided: %s. %s", host, err.Error())
}
urls = append(urls, *u)
}
return urls, nil
}

for _, h := range headersFlag {
//validate headers need -wait options
if len(waitFlag) == 0 {
log.Fatalf("-wait-http-header \"%s\" provided with no -wait option", h)
func parseHeaders(values []string, waits hostFlagsVar) ([]HttpHeader, error) {
headers := make([]HttpHeader, 0, len(values))
for _, h := range values {
if len(waits) == 0 {
return nil, fmt.Errorf("-wait-http-header \"%s\" provided with no -wait option", h)
}

const errMsg = "bad HTTP Headers argument: %s. expected \"headerName: headerValue\""
if strings.Contains(h, ":") {
parts := strings.SplitN(h, ":", 2)
if len(parts) != 2 {
log.Fatalf(errMsg, h)
return nil, fmt.Errorf(errMsg, h)
}
headers = append(headers, HttpHeader{name: strings.TrimSpace(parts[0]), value: strings.TrimSpace(parts[1])})
} else {
log.Fatalf(errMsg, h)
continue
}
return nil, fmt.Errorf(errMsg, h)
}
return headers, nil
}

func parseConfigFromFlags() (Config, error) {
parsedDelims, err := parseDelimiters(delimsFlag)
if err != nil {
return Config{}, err
}

parsedURLs, err := parseWaitURLs(waitFlag)
if err != nil {
return Config{}, err
}

for _, t := range templatesFlag {
parsedHeaders, err := parseHeaders(headersFlag, waitFlag)
if err != nil {
return Config{}, err
}

return Config{
version: version,
poll: poll,
templates: templatesFlag,
stdoutTails: stdoutTailFlag,
stderrTails: stderrTailFlag,
headersFlag: headersFlag,
delims: parsedDelims,
headers: parsedHeaders,
urls: parsedURLs,
waits: waitFlag,
waitTimeout: waitTimeoutFlag,
waitRetryInterval: waitRetryInterval,
noOverwrite: noOverwriteFlag,
args: flag.Args(),
}, nil
}

func processTemplates(templates []string) {
for _, t := range templates {
template, dest := t, ""
if strings.Contains(t, ":") {
parts := strings.SplitN(t, ":", 2)
Expand All @@ -299,17 +349,16 @@ func main() {
generateFile(template, dest)
}
}
}

waitForDependencies()

// Setup context
ctx, cancel = context.WithCancel(context.Background())

func startCommand(ctx context.Context, cancel context.CancelFunc) {
if flag.NArg() > 0 {
wg.Add(1)
go runCmd(ctx, cancel, flag.Arg(0), flag.Args()[1:]...)
}
}

func startTailers(ctx context.Context) {
for _, out := range stdoutTailFlag {
wg.Add(1)
go tailFile(ctx, out, poll, os.Stdout)
Expand All @@ -319,6 +368,42 @@ func main() {
wg.Add(1)
go tailFile(ctx, err, poll, os.Stderr)
}
}

func main() {
registerFlags()
flag.Parse()

config, err := parseConfigFromFlags()
if err != nil {
log.Fatal(err)
}

delims = config.delims
headers = config.headers
urls = config.urls
waitFlag = config.waits
waitTimeoutFlag = config.waitTimeout
waitRetryInterval = config.waitRetryInterval
noOverwriteFlag = config.noOverwrite
poll = config.poll

if config.version {
fmt.Println(buildVersion)
return
}

if flag.NArg() == 0 && flag.NFlag() == 0 {
usage()
os.Exit(1)
}

processTemplates(config.templates)
waitForDependencies()

ctx, cancel = context.WithCancel(context.Background())
startCommand(ctx, cancel)
startTailers(ctx)

wg.Wait()
}
120 changes: 120 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,126 @@ func TestLoop(t *testing.T) {
assert.Error(t, err)
}

func TestParseDelimiters(t *testing.T) {
tests := []struct {
name string
input string
want []string
wantErr string
}{
{name: "empty", input: "", want: nil},
{name: "valid", input: "{{:}}", want: []string{"{{", "}}"}},
{name: "valid with spaces", input: "<% : %>", want: []string{"<% ", " %>"}},
{name: "missing separator", input: "{{}}", wantErr: "bad delimiters argument: {{}}. expected \"left:right\""},
{name: "too many separators", input: "a:b:c", wantErr: "bad delimiters argument: a:b:c. expected \"left:right\""},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseDelimiters(tt.input)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
assert.Nil(t, got)
return
}

assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func TestParseWaitURLs(t *testing.T) {
tests := []struct {
name string
input hostFlagsVar
want []url.URL
wantErr string
}{
{name: "empty", input: nil, want: []url.URL{}},
{
name: "multiple urls",
input: hostFlagsVar{"tcp://db:5432", "http://web:8080/health", "file:///tmp/ready"},
want: []url.URL{
{Scheme: "tcp", Host: "db:5432"},
{Scheme: "http", Host: "web:8080", Path: "/health"},
{Scheme: "file", Path: "/tmp/ready"},
},
},
{name: "invalid escape", input: hostFlagsVar{"http://example.com/%zz"}, wantErr: "bad hostname provided: http://example.com/%zz. parse \"http://example.com/%zz\": invalid URL escape \"%zz\""},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseWaitURLs(tt.input)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
assert.Nil(t, got)
return
}

assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func TestParseHeaders(t *testing.T) {
tests := []struct {
name string
values []string
waits hostFlagsVar
want []HttpHeader
wantErr string
}{
{name: "empty", values: nil, waits: hostFlagsVar{"http://example.com"}, want: []HttpHeader{}},
{
name: "valid headers",
values: []string{"Accept-Encoding: gzip", "X-Test: value:with:colon", "Authorization:Bearer token"},
waits: hostFlagsVar{"http://example.com"},
want: []HttpHeader{
{name: "Accept-Encoding", value: "gzip"},
{name: "X-Test", value: "value:with:colon"},
{name: "Authorization", value: "Bearer token"},
},
},
{name: "header without wait", values: []string{"Accept: gzip"}, wantErr: "-wait-http-header \"Accept: gzip\" provided with no -wait option"},
{name: "missing colon", values: []string{"Accept gzip"}, waits: hostFlagsVar{"http://example.com"}, wantErr: "bad HTTP Headers argument: Accept gzip. expected \"headerName: headerValue\""},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseHeaders(tt.values, tt.waits)
if tt.wantErr != "" {
assert.EqualError(t, err, tt.wantErr)
assert.Nil(t, got)
return
}

assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func TestProcessTemplatesFileAndDirectoryArguments(t *testing.T) {
fileTemplateDir := t.TempDir()
filePath := filepath.Join(fileTemplateDir, "source.tmpl")
fileDest := filepath.Join(fileTemplateDir, "out.txt")
assert.NoError(t, os.WriteFile(filePath, []byte("hello"), 0o644))

dirPath := t.TempDir()
assert.NoError(t, os.WriteFile(filepath.Join(dirPath, "child.tmpl"), []byte("hello"), 0o644))
dirDest := t.TempDir()

processTemplates([]string{filePath + ":" + fileDest, dirPath + ":" + dirDest})

_, err := os.Stat(fileDest)
assert.NoError(t, err)
_, err = os.Stat(filepath.Join(dirDest, "child.tmpl"))
assert.NoError(t, err)
}

func TestWaitForSocketUsesPassedTimeoutForDial(t *testing.T) {
oldDialTimeout := dialTimeout
oldTimeout := waitTimeoutFlag
Expand Down