diff --git a/parser/expandpath_test.go b/parser/expandpath_test.go deleted file mode 100644 index 845f919cf..000000000 --- a/parser/expandpath_test.go +++ /dev/null @@ -1,123 +0,0 @@ -package parser - -import ( - "os" - "os/user" - "path/filepath" - "runtime" - "testing" -) - -func TestExpandPath(t *testing.T) { - mockCurrentUser := func() (*user.User, error) { - return &user.User{ - Username: "testuser", - HomeDir: func() string { - if os.PathSeparator == '\\' { - return filepath.FromSlash("D:/home/testuser") - } - return "/home/testuser" - }(), - }, nil - } - - mockLookupUser := func(username string) (*user.User, error) { - fakeUsers := map[string]string{ - "testuser": func() string { - if os.PathSeparator == '\\' { - return filepath.FromSlash("D:/home/testuser") - } - return "/home/testuser" - }(), - "anotheruser": func() string { - if os.PathSeparator == '\\' { - return filepath.FromSlash("D:/home/anotheruser") - } - return "/home/anotheruser" - }(), - } - - if homeDir, ok := fakeUsers[username]; ok { - return &user.User{ - Username: username, - HomeDir: homeDir, - }, nil - } - return nil, os.ErrNotExist - } - - pwd, err := os.Getwd() - if err != nil { - t.Fatal(err) - } - - t.Run("unix tests", func(t *testing.T) { - if runtime.GOOS == "windows" { - return - } - - tests := []struct { - path string - relativeDir string - expected string - shouldErr bool - }{ - {"~", "", "/home/testuser", false}, - {"~/myfolder/myfile.txt", "", "/home/testuser/myfolder/myfile.txt", false}, - {"~anotheruser/docs/file.txt", "", "/home/anotheruser/docs/file.txt", false}, - {"~nonexistentuser/file.txt", "", "", true}, - {"relative/path/to/file", "", filepath.Join(pwd, "relative/path/to/file"), false}, - {"/absolute/path/to/file", "", "/absolute/path/to/file", false}, - {"/absolute/path/to/file", "someotherdir/", "/absolute/path/to/file", false}, - {".", pwd, pwd, false}, - {".", "", pwd, false}, - {"somefile", "somedir", filepath.Join(pwd, "somedir", "somefile"), false}, - } - - for _, test := range tests { - result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser) - if (err != nil) != test.shouldErr { - t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr) - } - - if result != test.expected && !test.shouldErr { - t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected) - } - } - }) - - t.Run("windows tests", func(t *testing.T) { - if runtime.GOOS != "windows" { - return - } - - tests := []struct { - path string - relativeDir string - expected string - shouldErr bool - }{ - {"~", "", "D:\\home\\testuser", false}, - {"~/myfolder/myfile.txt", "", "D:\\home\\testuser\\myfolder\\myfile.txt", false}, - {"~anotheruser/docs/file.txt", "", "D:\\home\\anotheruser\\docs\\file.txt", false}, - {"~nonexistentuser/file.txt", "", "", true}, - {"relative\\path\\to\\file", "", filepath.Join(pwd, "relative\\path\\to\\file"), false}, - {"D:\\absolute\\path\\to\\file", "", "D:\\absolute\\path\\to\\file", false}, - {"D:\\absolute\\path\\to\\file", "someotherdir/", "D:\\absolute\\path\\to\\file", false}, - {".", pwd, pwd, false}, - {".", "", pwd, false}, - {"somefile", "somedir", filepath.Join(pwd, "somedir", "somefile"), false}, - } - - for _, test := range tests { - result, err := expandPathImpl(test.path, test.relativeDir, mockCurrentUser, mockLookupUser) - if (err != nil) != test.shouldErr { - t.Errorf("expandPathImpl(%q) returned error: %v, expected error: %v", test.path, err != nil, test.shouldErr) - } - - if result != test.expected && !test.shouldErr { - t.Errorf("expandPathImpl(%q) = %q, want %q", test.path, result, test.expected) - } - } - }) -} diff --git a/parser/parser.go b/parser/parser.go index c2e8f981f..bc16dd399 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -617,43 +617,42 @@ func isValidCommand(cmd string) bool { } } -func expandPathImpl(path, relativeDir string, currentUserFunc func() (*user.User, error), lookupUserFunc func(string) (*user.User, error)) (string, error) { - if filepath.IsAbs(path) || strings.HasPrefix(path, "\\") || strings.HasPrefix(path, "/") { - return filepath.Abs(path) - } else if strings.HasPrefix(path, "~") { - var homeDir string - - if path == "~" || strings.HasPrefix(path, "~/") { - // Current user's home directory - currentUser, err := currentUserFunc() - if err != nil { - return "", fmt.Errorf("failed to get current user: %w", err) - } - homeDir = currentUser.HomeDir - path = strings.TrimPrefix(path, "~") - } else { - // Specific user's home directory - parts := strings.SplitN(path[1:], "/", 2) - userInfo, err := lookupUserFunc(parts[0]) - if err != nil { - return "", fmt.Errorf("failed to find user '%s': %w", parts[0], err) - } - homeDir = userInfo.HomeDir - if len(parts) > 1 { - path = "/" + parts[1] - } else { - path = "" - } - } - - path = filepath.Join(homeDir, path) - } else { - path = filepath.Join(relativeDir, path) +func expandPath(path, dir string) (string, error) { + if filepath.IsAbs(path) { + return path, nil } - return filepath.Abs(path) -} + path, found := strings.CutPrefix(path, "~") + if !found { + // make path relative to dir + if !filepath.IsAbs(dir) { + // if dir is relative, make it absolute relative to cwd + cwd, err := os.Getwd() + if err != nil { + return "", err + } + dir = filepath.Join(cwd, dir) + } + path = filepath.Join(dir, path) + } else if filepath.IsLocal(path) { + // ~/... + // make path relative to specified user's home + split := strings.SplitN(path, "/", 2) + u, err := user.Lookup(split[0]) + if err != nil { + return "", err + } + split[0] = u.HomeDir + path = filepath.Join(split...) + } else { + // ~ or ~/... + // make path relative to current user's home + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + path = filepath.Join(home, path) + } -func expandPath(path, relativeDir string) (string, error) { - return expandPathImpl(path, relativeDir, user.Current, user.Lookup) + return filepath.Clean(path), nil } diff --git a/parser/parser_test.go b/parser/parser_test.go index 1524e890a..9de92ba21 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -8,6 +8,9 @@ import ( "fmt" "io" "os" + "os/user" + "path/filepath" + "runtime" "strings" "testing" "unicode/utf16" @@ -851,3 +854,62 @@ func TestCreateRequestFiles(t *testing.T) { } } } + +func TestExpandPath(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + t.Setenv("USERPROFILE", home) + + cwd, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + + u, err := user.Current() + if err != nil { + t.Fatal(err) + } + + volume := "" + if runtime.GOOS == "windows" { + volume = "D:" + } + + cases := []struct { + input, + dir, + want string + err error + }{ + {"~", "", home, nil}, + {"~/path/to/file", "", filepath.Join(home, filepath.ToSlash("path/to/file")), nil}, + {"~" + u.Username + "/path/to/file", "", filepath.Join(u.HomeDir, filepath.ToSlash("path/to/file")), nil}, + {"~nonexistentuser/path/to/file", "", "", user.UnknownUserError("nonexistentuser")}, + {"relative/path/to/file", "", filepath.Join(cwd, filepath.ToSlash("relative/path/to/file")), nil}, + {volume + "/absolute/path/to/file", "", filepath.ToSlash(volume + "/absolute/path/to/file"), nil}, + {volume + "/absolute/path/to/file", filepath.ToSlash("another/path"), filepath.ToSlash(volume + "/absolute/path/to/file"), nil}, + {".", cwd, cwd, nil}, + {".", "", cwd, nil}, + {"", cwd, cwd, nil}, + {"", "", cwd, nil}, + {"file", "path/to", filepath.Join(cwd, filepath.ToSlash("path/to/file")), nil}, + } + + for _, tt := range cases { + t.Run(tt.input, func(t *testing.T) { + got, err := expandPath(tt.input, tt.dir) + // On Windows, user.Lookup does not map syscall errors to user.UnknownUserError + // so we special case the test to just check for an error. + // See https://cs.opensource.google/go/go/+/refs/tags/go1.25.1:src/os/user/lookup_windows.go;l=455 + if runtime.GOOS != "windows" && !errors.Is(err, tt.err) { + t.Fatalf("expandPath(%q) error = %v, wantErr %v", tt.input, err, tt.err) + } else if tt.err != nil && err == nil { + t.Fatal("test case expected to fail on windows") + } + + if got != tt.want { + t.Errorf("expandPath(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +}