diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fc72fd0df..82622527a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,14 @@ +## 0.3.1 (2026-06-11) + +* OAuth browser login: `ucloud auth login` / `ucloud auth logout` (RFC 8252 loopback auto-capture, `--no-browser` fallback) +* automatic token lifecycle: proactive refresh before expiry, reactive refresh-and-replay on auth failure (RetCode 174), flock-serialized rotation safe for concurrent processes +* atomic temp+rename writes for config/credential files +* `auth_mode` picks exactly one credential mechanism per profile; AK/SK profiles unchanged +* fix: `config update --base-url` no longer validates against the old gateway +* fix: login validates an existing project_id against the logged-in account +* fix: `init` on an OAuth profile persists the switch back to AK/SK +* token redaction across all log sinks; panic output redacted + ## 0.3.0 (2024-09-20) * support naming in batch on creating uhost diff --git a/README-CN.md b/README-CN.md index 0922dd51a8..57625238c6 100644 --- a/README-CN.md +++ b/README-CN.md @@ -172,6 +172,81 @@ $ ucloud config update --profile xxx --region cn-sh2 $ ucloud config --help ``` +## 认证方式 + +UCloud CLI 支持两种认证方式,请根据使用场景选择: + +| 使用场景 | 推荐方式 | +| --- | --- | +| 交互终端上的人类用户 | OAuth 浏览器登录:`ucloud auth login`(推荐) | +| 脚本、CI/CD 等无人值守自动化 | AK/SK profile:`ucloud init` 或 `ucloud config` | + +### 浏览器登录(OAuth) + +``` +$ ucloud auth login +``` + +执行过程: + +1. CLI 在 127.0.0.1 的临时端口上启动一个本地回调 server,并打开浏览器访问 UCloud 授权页。如果浏览器没有自动打开,复制终端打印的 URL 手动打开即可。 +2. 在浏览器中登录并授权后,浏览器会跳转到 `http://localhost:/authorization`,CLI 自动捕获授权码并展示「登录成功」页面——全程无需复制粘贴,关闭页面回到终端即可。 +3. CLI 用授权码换取 token 并保存。如果当前 profile 还没有配置 region/zone/project,会自动获取并配置默认值: + +``` +Configured default region:cn-bj2 zone:cn-bj2-02 +Configured default project:org-xxxxxx Default +Logged in as you@example.com, token valid until 18:30 +``` + +### 手工回退(本机无浏览器) + +SSH 会话或无图形界面的机器,请加 `--no-browser`: + +``` +$ ucloud auth login --no-browser +``` + +CLI 会打印授权 URL 而不是打开浏览器。在任意设备上打开该 URL,登录并授权后,浏览器会跳转到一个**无法打开的** `http://localhost:/authorization?...` 页面——**这是预期行为**。把地址栏中的完整 URL 复制下来,粘贴回终端即可。 + +默认模式下同样存在这条粘贴回退路径:如果自动捕获在 3 分钟内没有收到回调,CLI 会打印 "Automatic capture timed out. Paste the callback URL here as a fallback:" 并等待粘贴回调 URL。 + +### Token 存储与有效期 + +- Token 存储在 `~/.ucloud/credential.json`,文件权限 0600。 +- access token 有效期约 1 小时,到期后通过 refresh token 静默续期,全程无感知。续期发生在使用时:运行命令时 CLI 会先检查并刷新临期 token;若网关在命令执行中拒绝 token,也会自动刷新并重试。没有后台常驻进程。 +- refresh token 当前有效期为 7 天,且每次续期都会轮换出新的 7 天有效期——只要 7 天内用过一次 CLI,登录态就一直延续;连续 7 天未使用后,下次命令会提示重新执行 `ucloud auth login`。 +- 登录态会一直保持,直到 refresh token 在服务端过期,或执行 `ucloud auth logout`。logout 只删除当前 profile 的本地 token,不会动已存储的 AK/SK。退出 UCloud 网页控制台不影响 CLI 登录态。 + +### 一个 profile 只用一种认证方式 + +- 每个 profile 同一时刻只使用一种认证方式,可在 `ucloud config list` 的 `AuthMode` 列查看(`oauth` 为浏览器登录,空值为 AK/SK 签名)。 +- 执行 `ucloud auth login` 会把当前 profile 切换到 OAuth 模式;已有的 AK/SK 仍保留在配置中,但不再参与签名。 +- 在 OAuth 模式的 profile 上执行 `ucloud init`,会先要求确认,确认后才切回 AK/SK。 +- 命令行同时传入 `--public-key` 和 `--private-key` 时始终优先生效:该次调用使用 AK/SK 签名,与 profile 的认证模式无关。 + +### 代理 + +OAuth token 请求遵循标准的 `HTTPS_PROXY` / `HTTP_PROXY` / `NO_PROXY` 环境变量。 + +### 使用限制 + +- **OAuth 登录按机器隔离。** 不要在多台机器之间拷贝或共享 `~/.ucloud`:refresh token 每次续期都会轮换,一台机器续期会把另一台「挤」下线。多机或共享场景请使用 AK/SK profile。 +- **降级会丢失 token。** 旧版本 ucloud-cli 不认识 token 字段,重写配置文件时会静默丢弃它们。降级后请重新执行 `ucloud auth login`。 + +### 故障排查 + +| 现象 / 报错 | 处理方法 | +| --- | --- | +| `authorization code or refresh token expired or already used (each code works only once)` | 每个授权码只能使用一次且很快过期。重新执行 `ucloud auth login` 并尽快完成流程。 | +| `state mismatch: the pasted URL likely comes from a previous login attempt` | 粘贴的是上一次登录尝试的回调 URL。重新执行 `ucloud auth login`,并粘贴本次的 URL。 | +| `Login expired for profile ''` | refresh token 已失效。重新执行 `ucloud auth login`。 | +| `cannot reach oauth server ... (check network or proxy settings)` | 网络或代理问题。检查网络连通性以及 `HTTPS_PROXY` / `HTTP_PROXY` / `NO_PROXY` 设置。 | +| `'ucloud auth login' requires an interactive terminal` | 当前处于 CI 或管道(非交互)环境。OAuth 登录面向交互人类用户,自动化场景请改用 AK/SK profile。 | +| 浏览器没有自动打开 | 复制终端打印的 URL 手动打开,或改用 `--no-browser`。 | +| (手工模式)浏览器提示 localhost 页面无法打开 | 预期行为——手工模式下 CLI 并未监听该端口。复制地址栏中的完整 URL 粘贴回终端即可。 | +| (手工模式)localhost 页面显示了其他本地程序的内容 | 无害——恰好有其他程序监听了该端口。只有地址栏里的 URL 有用,复制粘贴即可。 | + ## 举例说明 用UCloud CLI在尼日利亚创建数据中心创建一台主机并绑定一个外网IP,然后配置GlobalSSH加速,加速中国大陆到目的主机的SSH登陆 diff --git a/README.md b/README.md index dfe88c9cf3..8e6740ef07 100644 --- a/README.md +++ b/README.md @@ -181,6 +181,81 @@ For more information, run: $ ucloud config --help ``` +## Authentication + +The UCloud CLI supports two ways to authenticate. Pick one based on how you use the CLI: + +| You are | Use | +| --- | --- | +| A human at an interactive terminal | OAuth browser login: `ucloud auth login` (recommended) | +| Scripts, CI/CD or other unattended automation | AK/SK profile: `ucloud init` or `ucloud config` | + +### Log in via browser (OAuth) + +``` +$ ucloud auth login +``` + +What happens: + +1. The CLI starts a temporary local callback server on an ephemeral 127.0.0.1 port and opens your browser at the UCloud authorization page. If the browser does not open, copy the printed URL and open it manually. +2. You log in and approve. The browser is redirected to `http://localhost:/authorization`, where the CLI captures the authorization code automatically and shows a "Login successful" page — no copy-paste needed. Just close the tab and return to the terminal. +3. The CLI exchanges the code for tokens and saves them. If the profile has no region/zone/project configured yet, it also fetches and configures the defaults: + +``` +Configured default region:cn-bj2 zone:cn-bj2-02 +Configured default project:org-xxxxxx Default +Logged in as you@example.com, token valid until 18:30 +``` + +### Manual fallback (no browser on this machine) + +For SSH sessions or headless machines, pass `--no-browser`: + +``` +$ ucloud auth login --no-browser +``` + +The CLI prints the authorization URL instead of opening a browser. Open it on any device, log in and approve. The browser will then be redirected to a `http://localhost:/authorization?...` page that **cannot open — this is expected**. Copy the FULL URL from the address bar and paste it back into the terminal. + +The same paste prompt is also used as a fallback in the default mode: if the automatic capture does not receive the callback within 3 minutes, the CLI prints "Automatic capture timed out. Paste the callback URL here as a fallback:" and waits for the pasted URL. + +### Token storage and lifetime + +- Tokens are stored in `~/.ucloud/credential.json` with file mode 0600. +- The access token is valid for about 1 hour and is renewed silently via the refresh token — no action needed from you. Renewal happens on use: when you run a command, the CLI refreshes the token if it is about to expire, and also recovers automatically if the gateway rejects a token mid-command. There is no background daemon. +- The refresh token is currently valid for 7 days and is replaced with a fresh one on every renewal, so any use of the CLI within that window keeps you logged in indefinitely. After 7 days without use, the next command asks you to run `ucloud auth login` again. +- You stay logged in until the refresh token expires on the server side, or until you run `ucloud auth logout`. Logout only removes the locally stored tokens of the current profile; it does not touch any stored AK/SK keys. Logging out of the UCloud web console does not affect CLI sessions. + +### One profile, one auth method + +- Each profile uses exactly one auth method at a time, shown in the `AuthMode` column of `ucloud config list` (`oauth` for browser login, empty for AK/SK signing). +- Running `ucloud auth login` switches the current profile to OAuth. Existing AK/SK keys are kept stored but no longer used for signing. +- Running `ucloud init` on an OAuth profile asks for confirmation before switching the profile back to AK/SK. +- Passing both `--public-key` and `--private-key` flags on a command always takes precedence: that invocation uses AK/SK signing regardless of the profile's auth mode. + +### Proxies + +OAuth token requests honor the standard `HTTPS_PROXY` / `HTTP_PROXY` / `NO_PROXY` environment variables. + +### Limitations + +- **OAuth login is per machine.** Do not copy or share `~/.ucloud` across machines: the refresh token rotates on every renewal, so a renewal on one machine logs the other machine out. For multi-machine or shared setups, use an AK/SK profile. +- **Downgrading drops tokens.** Older ucloud-cli versions do not know the token fields and silently drop them when rewriting the config files. After downgrading, run `ucloud auth login` again. + +### Troubleshooting + +| Symptom / message | What to do | +| --- | --- | +| `authorization code or refresh token expired or already used (each code works only once)` | Each authorization code works only once and expires quickly. Run `ucloud auth login` again and complete the flow promptly. | +| `state mismatch: the pasted URL likely comes from a previous login attempt` | You pasted a callback URL from an earlier attempt. Run `ucloud auth login` again and paste the URL from THIS attempt. | +| `Login expired for profile ''` | The refresh token is no longer valid. Run `ucloud auth login` again. | +| `cannot reach oauth server ... (check network or proxy settings)` | Network or proxy issue. Check connectivity and your `HTTPS_PROXY` / `HTTP_PROXY` / `NO_PROXY` settings. | +| `'ucloud auth login' requires an interactive terminal` | You are in CI or piping stdin. OAuth login is for interactive humans; use an AK/SK profile instead. | +| Browser did not open | Copy the URL printed in the terminal and open it manually, or use `--no-browser`. | +| (manual mode) browser shows the localhost page cannot open | Expected — the CLI is not listening in manual mode. Copy the full URL from the address bar and paste it into the terminal. | +| (manual mode) the localhost page shows unexpected content from another local program | Harmless — something else happens to listen on that port. Only the URL in the address bar matters; copy and paste it. | + ## For example I want to create a uhost in Nigeria (region: air-nigeria) and bind a public IP, and then configure GlobalSSH to accelerate efficiency of SSH service beyond China mainland. diff --git a/base/client.go b/base/client.go index d6d0745931..71bf3bf363 100644 --- a/base/client.go +++ b/base/client.go @@ -1,6 +1,10 @@ package base import ( + "encoding/json" + "fmt" + "net/url" + "github.com/ucloud/ucloud-sdk-go/private/protocol/http" ppathx "github.com/ucloud/ucloud-sdk-go/private/services/pathx" pudb "github.com/ucloud/ucloud-sdk-go/private/services/udb" @@ -20,6 +24,7 @@ import ( "github.com/ucloud/ucloud-sdk-go/services/vpc" sdk "github.com/ucloud/ucloud-sdk-go/ucloud" "github.com/ucloud/ucloud-sdk-go/ucloud/auth" + uerr "github.com/ucloud/ucloud-sdk-go/ucloud/error" "github.com/ucloud/ucloud-sdk-go/ucloud/request" ) @@ -55,26 +60,135 @@ type Client struct { ucompshare.UCompShareClient } -// NewClient will return a aggregate client -func NewClient(config *sdk.Config, credConfig *CredentialConfig) *Client { - var handler sdk.RequestHandler = func(c *sdk.Client, req request.Common) (request.Common, error) { - err := req.SetProjectId(PickResourceID(req.GetProjectId())) - return req, err - } - var injectCredHeader sdk.HttpRequestHandler = func(c *sdk.Client, req *http.HttpRequest) (*http.HttpRequest, error) { - err := req.SetHeader("Cookie", credConfig.Cookie) - if err != nil { +// newCredHeaderInjector 返回凭据头注入 handler。 +// aksk/CloudShell 行为与历史完全一致(Cookie/Csrf-Token 始终 set,含空值); +// auth_mode==oauth 时剥离 SDK 编码器无条件附加的签名参数(Credential.Apply 即使 +// 空密钥也会算出 Signature),并在 token 非空时追加 Authorization: Bearer, +// 保证 oauth 请求只携带 Bearer 一种凭据机制(凭据模型见 spec §2)。 +func newCredHeaderInjector(credConfig *CredentialConfig) sdk.HttpRequestHandler { + return func(c *sdk.Client, req *http.HttpRequest) (*http.HttpRequest, error) { + if err := req.SetHeader("Cookie", credConfig.Cookie); err != nil { return req, err } - err = req.SetHeader("Csrf-Token", credConfig.CSRFToken) - if err != nil { + if err := req.SetHeader("Csrf-Token", credConfig.CSRFToken); err != nil { return req, err } + if credConfig.AuthMode == AuthModeOAuth { + // 仅对 form-urlencoded body 剥离(JSON 编码器虽当前不可达,但 + // url.ParseQuery 对 JSON 往往"成功",重编码会悄悄毁掉 body) + if req.GetHeaderMap()[http.HeaderNameContentType] == http.MimeFormURLEncoded { + vals, err := url.ParseQuery(string(req.GetRequestBody())) + if err != nil { + // 剥不掉就明确失败:客户端报错优于网关 171 + return req, fmt.Errorf("strip signature params from oauth request failed: %w", err) + } + vals.Del("Signature") + vals.Del("PublicKey") + if err := req.SetRequestBody([]byte(vals.Encode())); err != nil { + return req, err + } + } + if credConfig.AccessToken != "" { + if err := req.SetHeader("Authorization", "Bearer "+credConfig.AccessToken); err != nil { + return req, err + } + } + } + return req, nil + } +} + +// authRetCodeWhitelist 鉴权类 RetCode 白名单(D6)。实测网关(2026-06-11 实探): +// 鉴权失败以 HTTP 200 + RetCode 返回,401 仅作防御性分支保留。 +// 174 "Token Not Exists":伪造与已过期的 Bearer 同为 174(已实测确认);属网关 +// 前置鉴权拒绝,业务必未执行,重放一次安全。网关团队书面确认仍待补档(spec §7)。 +// 170(缺签名,oauth 请求恒带 Bearer 不会触发)、171/172(AK/SK 路径)不入列。 +var authRetCodeWhitelist = map[int]bool{ + 174: true, // Token Not Exists:无效或过期 Bearer +} + +// isAuthFailure 判定是否鉴权类失败:HTTP 401 或 body RetCode 在白名单(网关前置鉴权,业务必未执行)。 +// 注意 SDK 行为:HttpClient.Send 对 status>=400 返回 (nil, StatusError)(vendor +// private/protocol/http/client.go),且默认 errorHTTPHandler 先于本 handler 把它 +// 转成 uerr.ServerError —— 401 只会出现在 err 里、resp 必为 nil;resp 路径仅作 +// RetCode 白名单(HTTP 200 + 鉴权 RetCode)的判定入口。 +func isAuthFailure(resp *http.HttpResponse, err error) bool { + switch e := err.(type) { + case http.StatusError: + if e.StatusCode == 401 { + return true + } + case uerr.ServerError: + if e.StatusCode() == 401 { + return true + } + } + if resp == nil { + return false + } + var body struct { + RetCode int `json:"RetCode"` + } + if jerr := json.Unmarshal(resp.GetBody(), &body); jerr == nil { + return authRetCodeWhitelist[body.RetCode] + } + return false +} + +// newOAuthRetryHandler 反应式兜底(D6,Google 式):鉴权失败 → 刷新 → 自动重放一次。 +// 重放直接走 httpClient.Send,不再经过本 handler,天然不会循环。 +// 刷新对象是构造本 client 的 ac(而非 ConfigIns):cmd/root.go 的 os.Args 扫描 +// 识别不了 -p X/--profile=X 等形式,ConfigIns 可能指向另一个 profile,错刷会把 +// 别人的 Bearer 重放到当前请求上。ac 在所有 oauth 路径上都是 manager 持有的指针 +// (GetAggConfigByProfile/Append 直接存取同一指针),refreshAndSave 的写回因此可靠。 +// req 的 Authorization 由 SetHeader 以 map 赋值覆盖(不会叠加重复头),且 body 中 +// 的签名参数已被 newCredHeaderInjector 剥离,重放仍满足「oauth 请求只带 Bearer」不变式。 +func newOAuthRetryHandler(credConfig *CredentialConfig, ac *AggConfig) sdk.HttpResponseHandler { + return func(c *sdk.Client, req *http.HttpRequest, resp *http.HttpResponse, err error) (*http.HttpResponse, error) { + if ac == nil || credConfig.AuthMode != AuthModeOAuth || credConfig.AccessToken == "" { + return resp, err + } + if !isAuthFailure(resp, err) { + return resp, err + } + // 刷新(flock 串行化 + 拿锁后重读,见 refreshAndSave) + if rerr := refreshAndSave(ac, AggConfigListIns); rerr != nil { + LogWarn(fmt.Sprintf("oauth reactive refresh failed: %v", Redact(rerr.Error()))) + return resp, err + } + credConfig.AccessToken = ac.AccessToken + _ = req.SetHeader("Authorization", "Bearer "+credConfig.AccessToken) // SetHeader 恒返回 nil + LogInfo("auth failure detected, token refreshed, replaying request once") + hc := http.NewHttpClient() + nresp, nerr := hc.Send(req) + if serr, ok := nerr.(http.StatusError); ok { + // 本 handler 位于链尾,重放结果不会再经过默认 errorHTTPHandler, + // 在此对齐其行为:StatusError → uerr.ServerError + nerr = uerr.NewServerStatusError(serr.StatusCode, serr.Message) + } + return nresp, nerr + } +} + +// NewClient will return a aggregate client. +// ac 是构造来源 profile(oauth 401 反应式刷新的对象),允许为 nil(此时不重放)。 +func NewClient(config *sdk.Config, credConfig *CredentialConfig, ac *AggConfig) *Client { + var handler sdk.RequestHandler = func(c *sdk.Client, req request.Common) (request.Common, error) { + err := req.SetProjectId(PickResourceID(req.GetProjectId())) return req, err } - credential := &auth.Credential{ - PublicKey: credConfig.PublicKey, - PrivateKey: credConfig.PrivateKey, + injectCredHeader := newCredHeaderInjector(credConfig) + oauthRetry := newOAuthRetryHandler(credConfig, ac) + // 不变式:一个请求只携带一种凭据机制(auth_mode 唯一决定走哪种)。 + // oauth profile 会保留旧 AK/SK 在磁盘上(供 auth logout 恢复),但它们必须 + // 对 SDK 签名器不可见——否则签名参数与 Bearer 同时上行,网关先验签名 + // 直接报 RetCode 171 Signature VerifyAC Error。oauth 模式下凭据留空; + // 注意 SDK 编码器对空密钥仍会附加 Signature 参数,由 newCredHeaderInjector + // 剥离,最终 Bearer 是唯一凭据。 + credential := &auth.Credential{} + if credConfig.AuthMode != AuthModeOAuth { + credential.PublicKey = credConfig.PublicKey + credential.PrivateKey = credConfig.PrivateKey } var ( uaccountClient = *uaccount.NewClient(config, credential) @@ -97,51 +211,67 @@ func NewClient(config *sdk.Config, credConfig *CredentialConfig) *Client { uaccountClient.Client.AddRequestHandler(handler) uaccountClient.Client.AddHttpRequestHandler(injectCredHeader) + uaccountClient.Client.AddHttpResponseHandler(oauthRetry) uhostClient.Client.AddRequestHandler(handler) uhostClient.Client.AddHttpRequestHandler(injectCredHeader) + uhostClient.Client.AddHttpResponseHandler(oauthRetry) unetClient.Client.AddRequestHandler(handler) unetClient.Client.AddHttpRequestHandler(injectCredHeader) + unetClient.Client.AddHttpResponseHandler(oauthRetry) vpcClient.Client.AddRequestHandler(handler) vpcClient.Client.AddHttpRequestHandler(injectCredHeader) + vpcClient.Client.AddHttpResponseHandler(oauthRetry) udpnClient.Client.AddRequestHandler(handler) udpnClient.Client.AddHttpRequestHandler(injectCredHeader) + udpnClient.Client.AddHttpResponseHandler(oauthRetry) pathxClient.Client.AddRequestHandler(handler) pathxClient.Client.AddHttpRequestHandler(injectCredHeader) + pathxClient.Client.AddHttpResponseHandler(oauthRetry) udiskClient.Client.AddRequestHandler(handler) udiskClient.Client.AddHttpRequestHandler(injectCredHeader) + udiskClient.Client.AddHttpResponseHandler(oauthRetry) ulbClient.Client.AddRequestHandler(handler) ulbClient.Client.AddHttpRequestHandler(injectCredHeader) + ulbClient.Client.AddHttpResponseHandler(oauthRetry) udbClient.Client.AddRequestHandler(handler) udbClient.Client.AddHttpRequestHandler(injectCredHeader) + udbClient.Client.AddHttpResponseHandler(oauthRetry) umemClient.Client.AddRequestHandler(handler) umemClient.Client.AddHttpRequestHandler(injectCredHeader) + umemClient.Client.AddHttpResponseHandler(oauthRetry) uphostClient.Client.AddRequestHandler(handler) uphostClient.Client.AddHttpRequestHandler(injectCredHeader) + uphostClient.Client.AddHttpResponseHandler(oauthRetry) puhostClient.Client.AddRequestHandler(handler) puhostClient.Client.AddHttpRequestHandler(injectCredHeader) + puhostClient.Client.AddHttpResponseHandler(oauthRetry) pudbClient.Client.AddRequestHandler(handler) pudbClient.Client.AddHttpRequestHandler(injectCredHeader) + pudbClient.Client.AddHttpResponseHandler(oauthRetry) pumemClient.Client.AddRequestHandler(handler) pumemClient.Client.AddHttpRequestHandler(injectCredHeader) + pumemClient.Client.AddHttpResponseHandler(oauthRetry) ppathxClient.Client.AddRequestHandler(handler) ppathxClient.Client.AddHttpRequestHandler(injectCredHeader) + ppathxClient.Client.AddHttpResponseHandler(oauthRetry) ulhostClient.Client.AddRequestHandler(handler) ulhostClient.Client.AddHttpRequestHandler(injectCredHeader) + ulhostClient.Client.AddHttpResponseHandler(oauthRetry) return &Client{ uaccountClient, diff --git a/base/client_test.go b/base/client_test.go new file mode 100644 index 0000000000..530cf5dadd --- /dev/null +++ b/base/client_test.go @@ -0,0 +1,414 @@ +// base/client_test.go +package base + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + uhttp "github.com/ucloud/ucloud-sdk-go/private/protocol/http" +) + +func injectorHeaders(t *testing.T, cred *CredentialConfig) map[string]string { + t.Helper() + h := newCredHeaderInjector(cred) + req, err := h(nil, uhttp.NewHttpRequest()) + if err != nil { + t.Fatal(err) + } + return req.GetHeaderMap() +} + +// oauth 模式:注入 Authorization Bearer +func TestInjectorOAuthBearer(t *testing.T) { + headers := injectorHeaders(t, &CredentialConfig{AuthMode: AuthModeOAuth, AccessToken: "tok123"}) + if headers["Authorization"] != "Bearer tok123" { + t.Errorf("Authorization = %q, want Bearer tok123", headers["Authorization"]) + } +} + +// CRITICAL 回归:aksk 模式(含 CloudShell Cookie 注入)头部行为零变化 +func TestInjectorAkskAndCloudShellUnchanged(t *testing.T) { + // aksk:Cookie/Csrf-Token 照旧(空值也照旧 set),绝不出现 Authorization + h1 := injectorHeaders(t, &CredentialConfig{PublicKey: "pub", PrivateKey: "pri"}) + if _, ok := h1["Authorization"]; ok { + t.Error("aksk mode must NOT inject Authorization header") + } + if v, ok := h1["Cookie"]; !ok || v != "" { + t.Errorf("Cookie header behavior changed: %q %v", v, ok) + } + // CloudShell:Cookie/Csrf-Token 注入照旧 + h2 := injectorHeaders(t, &CredentialConfig{Cookie: "ck", CSRFToken: "cs"}) + if h2["Cookie"] != "ck" || h2["Csrf-Token"] != "cs" { + t.Errorf("cloudshell headers changed: %v", h2) + } + if _, ok := h2["Authorization"]; ok { + t.Error("cloudshell mode must NOT inject Authorization header") + } +} + +// oauth 模式但 token 为空:不注入(让网关报错而不是发送 "Bearer ") +func TestInjectorOAuthEmptyToken(t *testing.T) { + headers := injectorHeaders(t, &CredentialConfig{AuthMode: AuthModeOAuth}) + if _, ok := headers["Authorization"]; ok { + t.Error("empty token must not inject Authorization") + } +} + +// recordedRequest 记录业务请求实际携带的 header 与全部参数(query + form 合并) +type recordedRequest struct { + header http.Header + params url.Values +} + +// bizRecorderServer 模拟业务网关:记录请求并返回成功响应 +func bizRecorderServer(t *testing.T, rec *recordedRequest) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { + t.Errorf("parse form: %v", err) + } + rec.header = r.Header.Clone() + rec.params = r.Form // r.Form 含 URL query + body form,两路都覆盖 + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"RetCode":0,"Action":"GetRegionResponse"}`) + })) +} + +func callGetRegion(t *testing.T, ac *AggConfig, rec *recordedRequest) { + t.Helper() + // GetBizClient 会改写包级全局 ClientConfig/AuthCredential,恢复现场避免测试顺序耦合 + oldClientConfig, oldAuthCredential := ClientConfig, AuthCredential + t.Cleanup(func() { + ClientConfig, AuthCredential = oldClientConfig, oldAuthCredential + }) + bc, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + if _, err := bc.GetRegion(bc.NewGetRegionRequest()); err != nil { + t.Fatalf("GetRegion failed: %v", err) + } + if rec.params == nil { + t.Fatal("server did not record any request") + } +} + +// CRITICAL 缺陷回归(RetCode 171):oauth profile 残留 AK/SK(供 logout 恢复)时, +// 请求必须只携带 Bearer 一种凭据,绝不能同时出现 SDK 签名参数。 +func TestOAuthProfileWithRetainedKeysDoesNotSign(t *testing.T) { + rec := &recordedRequest{} + s := bizRecorderServer(t, rec) + defer s.Close() + + ac := &AggConfig{ + Profile: "oauth-leftover", BaseURL: s.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + Region: "cn-bj2", AuthMode: AuthModeOAuth, AccessToken: "tok", + ExpiresAt: time.Now().Add(time.Hour).Unix(), + PublicKey: "leftover-pub", PrivateKey: "leftover-pri", + } + callGetRegion(t, ac, rec) + + if got := rec.header.Get("Authorization"); got != "Bearer tok" { + t.Errorf("Authorization = %q, want %q", got, "Bearer tok") + } + for _, k := range []string{"Signature", "PublicKey"} { + if v, ok := rec.params[k]; ok { + t.Errorf("oauth profile must not send signature param %s=%v (one request carries exactly one credential)", k, v) + } + } +} + +// 401 自动重放矩阵(D6 反应式兜底):401→刷新→重放成功;aksk 模式不重放。 +// 注意 SDK 行为:HttpClient.Send 对 status>=400 返回 (nil, StatusError), +// 401 的 body 在 handler 层不可见,鉴权失败只能从 err 判定。 +func TestOAuthRetryHandler(t *testing.T) { + apiCalls := 0 + api := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls++ + if r.Header.Get("Authorization") == "Bearer good" { + fmt.Fprint(w, `{"RetCode":0,"Action":"GetRegionResponse"}`) + return + } + w.WriteHeader(401) + fmt.Fprint(w, `{"RetCode":170,"Message":"token expired"}`) + })) + defer api.Close() + oauth := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"good","refresh_token":"rt2","expires_in":3600}`) + })) + defer oauth.Close() + + ac := &AggConfig{ + Profile: "pr", Active: true, BaseURL: api.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + AuthMode: AuthModeOAuth, AccessToken: "bad", RefreshToken: "rt1", + ExpiresAt: 9999999999, OAuthBaseURL: oauth.URL, // 未到期 → 不触发主动刷新,逼出反应式路径 + } + m := newTestManager(t, ac) + prevIns, prevList := ConfigIns, AggConfigListIns + prevCC, prevAC := ClientConfig, AuthCredential + ConfigIns, AggConfigListIns = ac, m + t.Cleanup(func() { + ConfigIns, AggConfigListIns = prevIns, prevList + ClientConfig, AuthCredential = prevCC, prevAC + }) + + client, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + resp, err := client.GetRegion(client.NewGetRegionRequest()) + if err != nil { + t.Fatalf("replay should succeed: %v", err) + } + if resp.GetRetCode() != 0 { + t.Errorf("RetCode = %d", resp.GetRetCode()) + } + if apiCalls != 2 { + t.Errorf("expect 1 fail + 1 replay = 2 api calls, got %d", apiCalls) + } + var creds []CredentialConfig + raw, _ := ioutil.ReadFile(".ucloud/credential.json") + json.Unmarshal(raw, &creds) + var persisted *CredentialConfig + for i := range creds { + if creds[i].Profile == "pr" { + persisted = &creds[i] + } + } + if persisted == nil || persisted.AccessToken != "good" || persisted.RefreshToken != "rt2" { + t.Errorf("refreshed token and rotated refresh_token must be persisted: %s", raw) + } +} + +// RetCode 白名单路径(实测网关行为):鉴权失败返回 HTTP 200 + RetCode 174 "Token Not Exists" +// (无效与过期 Bearer 同码,2026-06-11 实测)。SDK 管道:200 时 Send 返回 (resp, nil), +// 默认 errorHTTPHandler 不动 err==nil,body 在本 handler 可读 → 走 isAuthFailure 的 +// resp-body 白名单分支:刷新 → 重放一次成功。 +func TestOAuthRetryHandlerRetCode174(t *testing.T) { + apiCalls := 0 + api := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls++ + w.Header().Set("Content-Type", "application/json") + if r.Header.Get("Authorization") == "Bearer good" { + fmt.Fprint(w, `{"RetCode":0,"Action":"GetRegionResponse"}`) + return + } + fmt.Fprint(w, `{"RetCode":174,"Message":"Token Not Exists"}`) + })) + defer api.Close() + oauth := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"good","refresh_token":"rt2","expires_in":3600}`) + })) + defer oauth.Close() + + ac := &AggConfig{ + Profile: "p174", Active: true, BaseURL: api.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + AuthMode: AuthModeOAuth, AccessToken: "bad", RefreshToken: "rt1", + ExpiresAt: 9999999999, OAuthBaseURL: oauth.URL, // 未到期 → 不触发主动刷新,逼出反应式路径 + } + m := newTestManager(t, ac) + prevIns, prevList := ConfigIns, AggConfigListIns + prevCC, prevAC := ClientConfig, AuthCredential + ConfigIns, AggConfigListIns = ac, m + t.Cleanup(func() { + ConfigIns, AggConfigListIns = prevIns, prevList + ClientConfig, AuthCredential = prevCC, prevAC + }) + + client, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + resp, err := client.GetRegion(client.NewGetRegionRequest()) + if err != nil { + t.Fatalf("replay should succeed: %v", err) + } + if resp.GetRetCode() != 0 { + t.Errorf("RetCode = %d", resp.GetRetCode()) + } + if apiCalls != 2 { + t.Errorf("expect 1 fail + 1 replay = 2 api calls, got %d", apiCalls) + } + var creds []CredentialConfig + raw, _ := ioutil.ReadFile(".ucloud/credential.json") + json.Unmarshal(raw, &creds) + var persisted *CredentialConfig + for i := range creds { + if creds[i].Profile == "p174" { + persisted = &creds[i] + } + } + if persisted == nil || persisted.AccessToken != "good" || persisted.RefreshToken != "rt2" { + t.Errorf("refreshed token and rotated refresh_token must be persisted: %s", raw) + } +} + +// 负路径(174 持续):重放后仍 174 → 只重放一次(共 2 次 api 调用,不循环), +// RetCode 由 SDK 默认 errorHandler 转为 ServerCodeError 上浮。 +func TestOAuthRetryHandlerRetCode174Persists(t *testing.T) { + apiCalls := 0 + api := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls++ + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"RetCode":174,"Message":"Token Not Exists"}`) + })) + defer api.Close() + oauth := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"good","refresh_token":"rt2","expires_in":3600}`) + })) + defer oauth.Close() + + ac := &AggConfig{ + Profile: "p174s", Active: true, BaseURL: api.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + AuthMode: AuthModeOAuth, AccessToken: "bad", RefreshToken: "rt1", + ExpiresAt: 9999999999, OAuthBaseURL: oauth.URL, + } + m := newTestManager(t, ac) + prevList, prevCC, prevAC := AggConfigListIns, ClientConfig, AuthCredential + AggConfigListIns = m + t.Cleanup(func() { AggConfigListIns, ClientConfig, AuthCredential = prevList, prevCC, prevAC }) + + client, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + if _, err := client.GetRegion(client.NewGetRegionRequest()); err == nil { + t.Error("persistent RetCode 174 must surface an error after single replay") + } + if apiCalls != 2 { + t.Errorf("exactly one replay allowed (no loop): got %d api calls", apiCalls) + } +} + +// 负路径 (a):刷新失败(invalid_grant)→ 原始 401 错误上浮,不重放、不 panic +func TestOAuthRetryHandlerRefreshFails(t *testing.T) { + apiCalls := 0 + api := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls++ + w.WriteHeader(401) + fmt.Fprint(w, `{"RetCode":170,"Message":"token expired"}`) + })) + defer api.Close() + oauth := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(400) + fmt.Fprint(w, `{"error":"invalid_grant"}`) + })) + defer oauth.Close() + + ac := &AggConfig{ + Profile: "prf", Active: true, BaseURL: api.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + AuthMode: AuthModeOAuth, AccessToken: "bad", RefreshToken: "rt1", + ExpiresAt: 9999999999, OAuthBaseURL: oauth.URL, + } + m := newTestManager(t, ac) + prevList, prevCC, prevAC := AggConfigListIns, ClientConfig, AuthCredential + AggConfigListIns = m + t.Cleanup(func() { AggConfigListIns, ClientConfig, AuthCredential = prevList, prevCC, prevAC }) + + client, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + if _, err := client.GetRegion(client.NewGetRegionRequest()); err == nil { + t.Error("refresh failure must surface the original 401 error") + } + if apiCalls != 1 { + t.Errorf("refresh failed, must not replay: got %d api calls", apiCalls) + } +} + +// 负路径 (b):重放后仍 401 → 只重放一次(共 2 次 api 调用,不循环),错误上浮 +func TestOAuthRetryHandlerReplayStill401(t *testing.T) { + apiCalls := 0 + api := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls++ + w.WriteHeader(401) + fmt.Fprint(w, `{"RetCode":170,"Message":"still no"}`) + })) + defer api.Close() + oauth := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"good","refresh_token":"rt2","expires_in":3600}`) + })) + defer oauth.Close() + + ac := &AggConfig{ + Profile: "prs", Active: true, BaseURL: api.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + AuthMode: AuthModeOAuth, AccessToken: "bad", RefreshToken: "rt1", + ExpiresAt: 9999999999, OAuthBaseURL: oauth.URL, + } + m := newTestManager(t, ac) + prevList, prevCC, prevAC := AggConfigListIns, ClientConfig, AuthCredential + AggConfigListIns = m + t.Cleanup(func() { AggConfigListIns, ClientConfig, AuthCredential = prevList, prevCC, prevAC }) + + client, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + if _, err := client.GetRegion(client.NewGetRegionRequest()); err == nil { + t.Error("replay still 401 must surface error") + } + if apiCalls != 2 { + t.Errorf("exactly one replay allowed (no loop): got %d api calls", apiCalls) + } +} + +func TestOAuthRetryHandlerSkipsAksk(t *testing.T) { + apiCalls := 0 + api := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + apiCalls++ + w.WriteHeader(401) + fmt.Fprint(w, `{"RetCode":170,"Message":"x"}`) + })) + defer api.Close() + ac := &AggConfig{ + Profile: "pa", Active: true, BaseURL: api.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + PublicKey: "pub", PrivateKey: "pri", + } + _ = newTestManager(t, ac) + prevCC, prevAC := ClientConfig, AuthCredential + t.Cleanup(func() { ClientConfig, AuthCredential = prevCC, prevAC }) + client, err := GetBizClient(ac) + if err != nil { + t.Fatal(err) + } + if _, err := client.GetRegion(client.NewGetRegionRequest()); err == nil { + t.Error("aksk 401 should surface error, not replay-refresh") + } + if apiCalls != 1 { + t.Errorf("aksk mode must not replay, got %d calls", apiCalls) + } +} + +// 反向护栏:aksk profile 照旧签名(Signature + PublicKey 必须在场,且无 Bearer) +func TestAkskProfileStillSigns(t *testing.T) { + rec := &recordedRequest{} + s := bizRecorderServer(t, rec) + defer s.Close() + + ac := &AggConfig{ + Profile: "aksk", BaseURL: s.URL, Timeout: 15, MaxRetryTimes: intPtr(0), + Region: "cn-bj2", PublicKey: "pub", PrivateKey: "pri", + } + callGetRegion(t, ac, rec) + + for _, k := range []string{"Signature", "PublicKey"} { + if _, ok := rec.params[k]; !ok { + t.Errorf("aksk profile must sign requests, missing param %s; got params %v", k, rec.params) + } + } + if _, ok := rec.header["Authorization"]; ok { + t.Error("aksk profile must not send Authorization header") + } +} diff --git a/base/config.go b/base/config.go index b4ecb7f72d..09f352102a 100644 --- a/base/config.go +++ b/base/config.go @@ -39,7 +39,7 @@ const DefaultBaseURL = "https://api.ucloud.cn/" const DefaultProfile = "default" // Version 版本号 -const Version = "0.3.0" +const Version = "0.3.1" var UserAgent = fmt.Sprintf("UCloud-CLI/%s", Version) @@ -95,6 +95,7 @@ type CLIConfig struct { Active bool `json:"active"` //是否生效 MaxRetryTimes *int `json:"max_retry_times"` AgreeUploadLog bool `json:"agree_upload_log"` + OAuthBaseURL string `json:"oauth_base_url,omitempty"` } // CredentialConfig credential element @@ -104,6 +105,11 @@ type CredentialConfig struct { Cookie string `json:"cookie"` CSRFToken string `json:"csrf_token"` Profile string `json:"profile"` + + AuthMode string `json:"auth_mode,omitempty"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` } // AggConfig 聚合配置 config+credential @@ -121,6 +127,11 @@ type AggConfig struct { CSRFToken string `json:"csrf_token"` MaxRetryTimes *int `json:"max_retry_times"` AgreeUploadLog bool `json:"agree_upload_log"` + AuthMode string `json:"auth_mode,omitempty"` + AccessToken string `json:"access_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` + ExpiresAt int64 `json:"expires_at,omitempty"` + OAuthBaseURL string `json:"oauth_base_url,omitempty"` } // ConfigPublicKey 输入公钥 @@ -213,6 +224,7 @@ func (p *AggConfig) copyToCLIConfig(target *CLIConfig) { target.Active = p.Active target.MaxRetryTimes = p.MaxRetryTimes target.AgreeUploadLog = p.AgreeUploadLog + target.OAuthBaseURL = p.OAuthBaseURL } func (p *AggConfig) copyToCredentialConfig(target *CredentialConfig) { @@ -221,22 +233,26 @@ func (p *AggConfig) copyToCredentialConfig(target *CredentialConfig) { target.PublicKey = p.PublicKey target.Cookie = p.Cookie target.CSRFToken = p.CSRFToken + target.AuthMode = p.AuthMode + target.AccessToken = p.AccessToken + target.RefreshToken = p.RefreshToken + target.ExpiresAt = p.ExpiresAt } // AggConfigManager 配置管理 type AggConfigManager struct { activeProfile string configs map[string]*AggConfig - configFile *os.File - credFile *os.File + configPath string + credPath string } // NewAggConfigManager create instance -func NewAggConfigManager(cfgFile, credFile *os.File) (*AggConfigManager, error) { +func NewAggConfigManager(configPath, credPath string) (*AggConfigManager, error) { manager := &AggConfigManager{ configs: make(map[string]*AggConfig), - configFile: cfgFile, - credFile: credFile, + configPath: configPath, + credPath: credPath, } err := manager.Load() @@ -278,7 +294,8 @@ func (p *AggConfigManager) Append(config *AggConfig) error { // UpdateAggConfig update AggConfig append if not exist func (p *AggConfigManager) UpdateAggConfig(config *AggConfig) error { - if _, ok := p.configs[config.Profile]; !ok { + existing, ok := p.configs[config.Profile] + if !ok { return p.Append(config) } @@ -288,6 +305,12 @@ func (p *AggConfigManager) UpdateAggConfig(config *AggConfig) error { } p.activeProfile = config.Profile } + // 调用方传入的可能不是 map 内条目本身(如 --profile 回退默认配置的场景): + // 以传入值为准覆盖 map 条目,否则 Save 会把旧数据落盘、静默丢弃调用方的修改。 + // 保留 map 指针不变,已持有该指针的别名(如 oauth 刷新写回)继续有效。 + if existing != config { + *existing = *config + } return p.Save() } @@ -338,6 +361,11 @@ func (p *AggConfigManager) Load() error { Active: config.Active, MaxRetryTimes: config.MaxRetryTimes, AgreeUploadLog: config.AgreeUploadLog, + AuthMode: cred.AuthMode, + AccessToken: cred.AccessToken, + RefreshToken: cred.RefreshToken, + ExpiresAt: cred.ExpiresAt, + OAuthBaseURL: config.OAuthBaseURL, } } @@ -478,8 +506,8 @@ func (p *AggConfigManager) Save() error { aggConfig.copyToCredentialConfig(credConfig) credcs = append(credcs, credConfig) } - aerr := WriteJSONFile(clics, p.configFile.Name()) - berr := WriteJSONFile(credcs, p.credFile.Name()) + aerr := WriteJSONFileAtomic(clics, p.configPath) + berr := WriteJSONFileAtomic(credcs, p.credPath) if aerr != nil && berr != nil { return fmt.Errorf("save cli config failed: %v | save credentail failed: %v", aerr, berr) @@ -557,8 +585,11 @@ func (p *AggConfigManager) GetActiveAggConfigName() string { func (p *AggConfigManager) parseCLIConfigs() ([]CLIConfig, error) { var configs []CLIConfig - rawConfig, err := ioutil.ReadAll(p.configFile) + rawConfig, err := ioutil.ReadFile(p.configPath) if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, err } if len(rawConfig) == 0 { @@ -580,8 +611,11 @@ func (p *AggConfigManager) parseCLIConfigs() ([]CLIConfig, error) { func (p *AggConfigManager) parseCredentials() ([]CredentialConfig, error) { var credentials []CredentialConfig - rawCred, err := ioutil.ReadAll(p.credFile) + rawCred, err := ioutil.ReadFile(p.credPath) if err != nil { + if os.IsNotExist(err) { + return nil, nil + } return nil, err } @@ -602,6 +636,8 @@ func ListAggConfig(json bool) { for idx, ac := range aggConfigs { aggConfigs[idx].PrivateKey = MosaicString(ac.PrivateKey, 8, 5) aggConfigs[idx].PublicKey = MosaicString(ac.PublicKey, 8, 5) + aggConfigs[idx].AccessToken = MosaicString(ac.AccessToken, 8, 5) + aggConfigs[idx].RefreshToken = MosaicString(ac.RefreshToken, 8, 5) } if json { err := PrintJSON(aggConfigs, os.Stdout) @@ -609,7 +645,7 @@ func ListAggConfig(json bool) { HandleError(err) } } else { - PrintTable(aggConfigs, []string{"Profile", "Active", "ProjectID", "Region", "Zone", "BaseURL", "Timeout", "PublicKey", "PrivateKey", "MaxRetryTimes", "AgreeUploadLog"}) + PrintTable(aggConfigs, []string{"Profile", "Active", "AuthMode", "ProjectID", "Region", "Zone", "BaseURL", "Timeout", "PublicKey", "PrivateKey", "MaxRetryTimes", "AgreeUploadLog"}) } } @@ -731,26 +767,21 @@ func GetBizClient(ac *AggConfig) (*Client, error) { MaxRetries: *ac.MaxRetryTimes, } AuthCredential = &CredentialConfig{ - PublicKey: ac.PublicKey, - PrivateKey: ac.PrivateKey, - Cookie: ac.Cookie, - CSRFToken: ac.CSRFToken, + PublicKey: ac.PublicKey, + PrivateKey: ac.PrivateKey, + Cookie: ac.Cookie, + CSRFToken: ac.CSRFToken, + AuthMode: ac.AuthMode, + AccessToken: ac.AccessToken, + RefreshToken: ac.RefreshToken, + ExpiresAt: ac.ExpiresAt, } - return NewClient(ClientConfig, AuthCredential), err + return NewClient(ClientConfig, AuthCredential, ac), err } func InitConfigInCloudShell() error { - configFile, err := os.OpenFile(ConfigFilePath, os.O_CREATE|os.O_RDONLY, LocalFileMode) - if err != nil { - return err - } - credFile, err := os.OpenFile(CredentialFilePath, os.O_CREATE|os.O_RDONLY, LocalFileMode) - if err != nil { - return err - } - - data, err := ioutil.ReadAll(credFile) - if err != nil { + data, err := ioutil.ReadFile(CredentialFilePath) + if err != nil && !os.IsNotExist(err) { return err } if len(data) > 0 { @@ -772,8 +803,8 @@ func InitConfigInCloudShell() error { return err } - AggConfigM.credFile = credFile - AggConfigM.configFile = configFile + AggConfigM.credPath = CredentialFilePath + AggConfigM.configPath = ConfigFilePath ins, err := AggConfigM.GetActiveAggConfig() if err != nil { return err @@ -789,16 +820,8 @@ func InitConfigInCloudShell() error { // InitConfig 初始化配置 func InitConfig() { - configFile, err := os.OpenFile(ConfigFilePath, os.O_CREATE|os.O_RDONLY, LocalFileMode) - if err != nil && !os.IsNotExist(err) { - HandleError(err) - } - credFile, err := os.OpenFile(CredentialFilePath, os.O_CREATE|os.O_RDONLY, LocalFileMode) - if err != nil && !os.IsNotExist(err) { - HandleError(err) - } - - AggConfigListIns, err = NewAggConfigManager(configFile, credFile) + var err error + AggConfigListIns, err = NewAggConfigManager(ConfigFilePath, CredentialFilePath) if err != nil { LogError(err.Error()) return @@ -843,6 +866,7 @@ func mergeConfigIns(ins *AggConfig) { if Global.PublicKey != "" && Global.PrivateKey != "" { ins.PrivateKey = Global.PrivateKey ins.PublicKey = Global.PublicKey + ins.AuthMode = "" // flag 显式给了 AK/SK:走签名,抑制 Bearer 注入(D5) } } diff --git a/base/config_test.go b/base/config_test.go index 1bb244ee8b..1f906ce617 100644 --- a/base/config_test.go +++ b/base/config_test.go @@ -33,17 +33,7 @@ func TestAggConfigManager(t *testing.T) { } }() - configFile, err := os.OpenFile(".ucloud/config.json", os.O_CREATE|os.O_RDONLY, LocalFileMode) - if err != nil { - t.Error(err) - } - - credFile, err := os.OpenFile(".ucloud/credential.json", os.O_CREATE|os.O_RDONLY, LocalFileMode) - if err != nil { - t.Error(err) - } - - acManager, err := NewAggConfigManager(configFile, credFile) + acManager, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") if err != nil { t.Error(err) } @@ -63,27 +53,121 @@ func TestEmptyAggConfigManager(t *testing.T) { } }() - configFile, err := os.OpenFile(".ucloud/config.json", os.O_CREATE|os.O_RDONLY, LocalFileMode) + acManager, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") if err != nil { t.Error(err) } - credFile, err := os.OpenFile(".ucloud/credential.json", os.O_CREATE|os.O_RDONLY, LocalFileMode) + err = acManager.Load() if err != nil { - t.Error(err) + t.Fatal(err) } - acManager, err := NewAggConfigManager(configFile, credFile) + if len(acManager.configs) != 0 { + t.Errorf("expect length of configs is 2, accpet %d", len(acManager.configs)) + } +} + +// CRITICAL 回归:旧 credential.json(无 oauth 字段)必须照常加载且 Save 后不丢数据 +func TestOldCredentialCompat(t *testing.T) { + os.MkdirAll(".ucloud", 0700) + defer os.RemoveAll(".ucloud") + ioutil.WriteFile(".ucloud/config.json", []byte(cliConfigJSON), LocalFileMode) + ioutil.WriteFile(".ucloud/credential.json", []byte(credentialJSON), LocalFileMode) + + m, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") if err != nil { - t.Error(err) + t.Fatal(err) + } + ac, ok := m.GetAggConfigByProfile("uweb") + if !ok { + t.Fatal("profile uweb missing") + } + if ac.AuthMode != "" || ac.AccessToken != "" { + t.Errorf("old file should yield empty oauth fields, got %+v", ac) } + if ac.PublicKey == "" { + t.Error("aksk fields must survive") + } + if err := m.Save(); err != nil { + t.Fatal(err) + } +} - err = acManager.Load() +// oauth 字段写入后能读回(含轮换写回场景的字段完整性) +func TestOAuthFieldsRoundTrip(t *testing.T) { + os.MkdirAll(".ucloud", 0700) + defer os.RemoveAll(".ucloud") + m, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") if err != nil { t.Fatal(err) } + ac := &AggConfig{ + Profile: "oauthp", Active: true, BaseURL: DefaultBaseURL, Timeout: 15, + MaxRetryTimes: intPtr(3), + AuthMode: AuthModeOAuth, AccessToken: "at", RefreshToken: "rt", ExpiresAt: 1234567890, + OAuthBaseURL: "https://oauth.example.com", + } + if err := m.Append(ac); err != nil { + t.Fatal(err) + } - if len(acManager.configs) != 0 { - t.Errorf("expect length of configs is 2, accpet %d", len(acManager.configs)) + // 重新读盘验证 + m2, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") + if err != nil { + t.Fatal(err) + } + got, ok := m2.GetAggConfigByProfile("oauthp") + if !ok { + t.Fatal("profile oauthp missing after reload") + } + if got.AuthMode != AuthModeOAuth || got.AccessToken != "at" || got.RefreshToken != "rt" || + got.ExpiresAt != 1234567890 || got.OAuthBaseURL != "https://oauth.example.com" { + t.Errorf("oauth fields lost on round trip: %+v", got) + } +} + +// UpdateAggConfig 必须以传入的 config 为准:当传入指针与 map 内条目不是同一个对象时 +// (如 `ucloud --profile <不存在>` 回退到包级默认 ConfigIns 而盘上已有同名 profile), +// 不能静默把 map 里的旧数据存盘、丢掉调用方的数据。 +func TestUpdateAggConfigPointerMismatch(t *testing.T) { + os.MkdirAll(".ucloud", 0700) + defer os.RemoveAll(".ucloud") + m, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") + if err != nil { + t.Fatal(err) + } + old := &AggConfig{ + Profile: "x", Active: true, Region: "cn-bj2", Zone: "cn-bj2-04", + PublicKey: "oldpub", PrivateKey: "oldpri", + BaseURL: DefaultBaseURL, Timeout: 15, MaxRetryTimes: intPtr(3), + } + if err := m.Append(old); err != nil { + t.Fatal(err) + } + + // 独立构造的另一个指针,同 Profile、不同字段值 + fresh := &AggConfig{ + Profile: "x", Active: true, Region: "hk", Zone: "hk-02", + PublicKey: "newpub", PrivateKey: "newpri", + BaseURL: DefaultBaseURL, Timeout: 30, MaxRetryTimes: intPtr(5), + } + if err := m.UpdateAggConfig(fresh); err != nil { + t.Fatal(err) + } + + m2, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") + if err != nil { + t.Fatal(err) + } + got, ok := m2.GetAggConfigByProfile("x") + if !ok { + t.Fatal("profile x missing after reload") + } + if got.Region != "hk" || got.Zone != "hk-02" || got.PublicKey != "newpub" || + got.PrivateKey != "newpri" || got.Timeout != 30 { + t.Errorf("passed config was silently dropped, stale data persisted: %+v", got) } } + +func intPtr(i int) *int { return &i } diff --git a/base/log.go b/base/log.go index da036404df..bc43ad2e7a 100644 --- a/base/log.go +++ b/base/log.go @@ -49,16 +49,34 @@ func initLog() error { return nil } -func logCmd() { - args := make([]string, len(os.Args)) - copy(args, os.Args) +// redactCmdArgs 脱敏命令行参数:flag 值遮蔽(名单含 oauth 敏感词)+ 整体过 Redact 兜底 +func redactCmdArgs(osArgs []string) []string { + args := make([]string, len(osArgs)) + copy(args, osArgs) for idx, arg := range args { - for _, word := range []string{"password", "private-key", "public-key"} { + for _, word := range []string{"password", "private-key", "public-key", "code", "token", "authorization"} { if strings.Contains(arg, word) && idx <= len(args)-2 { args[idx+1] = strings.Repeat("*", 8) } } } + for idx := range args { + args[idx] = Redact(args[idx]) + } + return args +} + +// redactLogLines 日志出口统一脱敏(Phase 3 扩面:错误包装/调试输出经 Log* 的部分) +func redactLogLines(logs []string) []string { + out := make([]string, len(logs)) + for i, line := range logs { + out[i] = Redact(line) + } + return out +} + +func logCmd() { + args := redactCmdArgs(os.Args) LogInfo(fmt.Sprintf("command: %s", strings.Join(args, " "))) } @@ -83,6 +101,7 @@ func LogInfo(logs ...string) { if ok { return } + logs = redactLogLines(logs) mu.Lock() defer mu.Unlock() goID := curGoroutineID() @@ -100,6 +119,7 @@ func LogPrint(logs ...string) { if ok { return } + logs = redactLogLines(logs) mu.Lock() defer mu.Unlock() goID := curGoroutineID() @@ -118,6 +138,7 @@ func LogWarn(logs ...string) { if ok { return } + logs = redactLogLines(logs) mu.Lock() defer mu.Unlock() goID := curGoroutineID() @@ -136,6 +157,7 @@ func LogError(logs ...string) { if ok { return } + logs = redactLogLines(logs) mu.Lock() defer mu.Unlock() goID := curGoroutineID() diff --git a/base/log_test.go b/base/log_test.go new file mode 100644 index 0000000000..34acda29b5 --- /dev/null +++ b/base/log_test.go @@ -0,0 +1,79 @@ +// base/log_test.go +package base + +import ( + "bytes" + "os" + "strings" + "testing" +) + +func TestRedactCmdArgs(t *testing.T) { + args := []string{"ucloud", "login", "--private-key", "PRIKEY", "--code", "CODE1", "--token", "TOK1", "--authorization", "AUTH1"} + got := strings.Join(redactCmdArgs(args), " ") + for _, secret := range []string{"PRIKEY", "CODE1", "TOK1", "AUTH1"} { + if strings.Contains(got, secret) { + t.Errorf("redactCmdArgs leaked %q: %s", secret, got) + } + } + if !strings.Contains(got, "login") { + t.Error("non-sensitive args must be preserved") + } +} + +// 整行兜底:args 中内嵌的 URL query 形态的 code/token 也要被遮蔽 +func TestRedactCmdArgsURLForm(t *testing.T) { + args := []string{"ucloud", "x", "http://localhost/authorization?code=SEC&state=ST"} + got := strings.Join(redactCmdArgs(args), " ") + if strings.Contains(got, "SEC") { + t.Errorf("url-embedded code leaked: %s", got) + } +} + +// 出口接线测试:直接走 LogInfo,确认脱敏真的接在出口上(防止 redactLogLines 调用行被误删而无测试失败) +func TestLogInfoOutletWired(t *testing.T) { + if logger == nil { + t.Fatal("logger not initialized by package init") + } + // COMP_LINE 存在时 Log* 直接 return,须确保未设置 + if v, ok := os.LookupEnv("COMP_LINE"); ok { + os.Unsetenv("COMP_LINE") + t.Cleanup(func() { os.Setenv("COMP_LINE", v) }) + } + // 关闭上传路径,避免测试触网 + prevUpload := ConfigIns.AgreeUploadLog + ConfigIns.AgreeUploadLog = false + t.Cleanup(func() { ConfigIns.AgreeUploadLog = prevUpload }) + + var buf bytes.Buffer + prevOut := logger.Out + logger.SetOutput(&buf) + t.Cleanup(func() { logger.SetOutput(prevOut) }) + + LogInfo(`Authorization: Bearer SECRET-WIRE`) + + got := buf.String() + if got == "" { + t.Fatal("LogInfo wrote nothing to logger output") + } + if strings.Contains(got, "SECRET-WIRE") { + t.Errorf("LogInfo outlet leaked token: %s", got) + } + if !strings.Contains(got, "********") { + t.Errorf("LogInfo outlet missing redaction placeholder: %s", got) + } +} + +// 扩面:任何经 Log* 出口的行都不得泄漏 token(HandleError → LogError 同样被覆盖) +func TestLogOutputsRedacted(t *testing.T) { + lines := redactLogLines([]string{ + `request failed: Authorization: Bearer SECRET-AT`, + `refresh response: {"access_token":"SECRET-AT2","refresh_token":"SECRET-RT"}`, + }) + joined := strings.Join(lines, "\n") + for _, s := range []string{"SECRET-AT", "SECRET-AT2", "SECRET-RT"} { + if strings.Contains(joined, s) { + t.Errorf("log line leaked %q: %s", s, joined) + } + } +} diff --git a/base/oauth.go b/base/oauth.go new file mode 100644 index 0000000000..e7711724a0 --- /dev/null +++ b/base/oauth.go @@ -0,0 +1,410 @@ +// base/oauth.go +package base + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io/ioutil" + "net/http" + "net/url" + "os" + "regexp" + "strings" + "time" + + "github.com/gofrs/flock" + "github.com/mattn/go-isatty" +) + +// OAuth 常量(spec D2:client_secret 嵌入二进制为知情裁定,同 gcloud/gh)。 +const ( + defaultOAuthBaseURL = "https://oauth2.ucloud.cn" + oauthClientID = "WP77AwxvUgWt2JqaRCKn" + oauthClientSecret = "mksUQLod9VaUKMt3wESdgteTFCgVasiUwLSPqq5e" + oauthRedirectPath = "/authorization" + oauthScope = "openid email offline_access full_access" +) + +// BuildLoopbackRedirectURI 按后端规则拼 loopback redirect_uri:host 必须是字面量 localhost +// (127.0.0.1 会被后端拒),端口为内核分配的临时端口(>=1024)。 +func BuildLoopbackRedirectURI(port int) string { + return fmt.Sprintf("http://localhost:%d%s", port, oauthRedirectPath) +} + +// AuthModeOAuth auth_mode 取值:OAuth 浏览器登录。空串/其他值一律视为 AK/SK 签名模式。 +const AuthModeOAuth = "oauth" + +// TokenExpirySkew 主动刷新的时钟偏斜余量(D6) +const TokenExpirySkew = 5 * time.Minute + +// GenerateState 生成 CSRF state:32 字节随机 base64url +func GenerateState() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", fmt.Errorf("generate state failed: %v", err) + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// GetOAuthBaseURL 生效的 OAuth 域名:profile 配置优先,否则内置默认(D9.2) +func GetOAuthBaseURL(cfg *AggConfig) (string, error) { + if cfg.OAuthBaseURL != "" { + return strings.TrimSuffix(cfg.OAuthBaseURL, "/"), nil + } + return defaultOAuthBaseURL, nil +} + +// BuildAuthorizeURL 拼授权 URL(流程步骤①) +func BuildAuthorizeURL(oauthBase, redirectURI, state string) string { + v := url.Values{} + v.Set("response_type", "code") + v.Set("client_id", oauthClientID) + v.Set("redirect_uri", redirectURI) + v.Set("scope", oauthScope) + v.Set("state", state) + return fmt.Sprintf("%s/authorize?%s", oauthBase, v.Encode()) +} + +// SanitizeCallbackInput 容忍前后空白/引号/终端折行引入的内部空白与换行(D7 输入容错) +func SanitizeCallbackInput(input string) string { + s := strings.TrimSpace(input) + s = strings.Trim(s, `"'`) + return strings.Map(func(r rune) rune { + switch r { + case '\n', '\r', ' ', '\t': + return -1 + } + return r + }, s) +} + +const callbackFormatHint = "expected format: http://localhost/authorization?code=xxx&state=yyy" + +// ParseCallbackURL 校验 state 并提取 code(流程步骤③) +func ParseCallbackURL(input, expectState string) (string, error) { + s := SanitizeCallbackInput(input) + u, err := url.Parse(s) + if err != nil { + return "", fmt.Errorf("cannot parse the pasted URL, no authorization code found; %s", callbackFormatHint) + } + q := u.Query() + if e := q.Get("error"); e != "" { + if e == "access_denied" { + return "", fmt.Errorf("authorization was denied in the browser. Run 'ucloud auth login' to try again") + } + return "", fmt.Errorf("oauth server returned error %q. Run 'ucloud auth login' to try again", e) + } + code := q.Get("code") + if code == "" { + return "", fmt.Errorf("no authorization code in the pasted URL; %s", callbackFormatHint) + } + if q.Get("state") != expectState { + return "", fmt.Errorf("state mismatch: the pasted URL likely comes from a previous login attempt. Run 'ucloud auth login' again and paste the URL from THIS attempt") + } + return code, nil +} + +// TokenExpiredAt 判断 token 是否需要刷新(留 TokenExpirySkew 余量) +func TokenExpiredAt(expiresAt int64, now time.Time) bool { + if expiresAt == 0 { + return true + } + return now.Add(TokenExpirySkew).Unix() >= expiresAt +} + +// TokenExpired TokenExpiredAt 的当前时间封装 +func TokenExpired(expiresAt int64) bool { + return TokenExpiredAt(expiresAt, time.Now()) +} + +// ParseIDTokenEmail 解 id_token payload 取 email。不验签,仅用于 UI 展示(D2 知情裁定);id_token 不落盘。 +func ParseIDTokenEmail(idToken string) (string, error) { + parts := strings.Split(idToken, ".") + if len(parts) != 3 { + return "", fmt.Errorf("malformed id_token") + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + // 部分 OIDC 实现会输出带 '=' 填充的 base64url,去填充后重试一次 + payload, err = base64.RawURLEncoding.DecodeString(strings.TrimRight(parts[1], "=")) + if err != nil { + return "", fmt.Errorf("decode id_token payload failed: %v", err) + } + } + var claims struct { + Email string `json:"email"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return "", fmt.Errorf("parse id_token payload failed: %v", err) + } + return claims.Email, nil +} + +// redactPatterns 覆盖 query 参数、JSON 字段、HTTP 头三种形态的敏感值 +var redactPatterns = []*regexp.Regexp{ + regexp.MustCompile(`(?i)((?:^|[?&\s])code=)[^&\s"']+`), + regexp.MustCompile(`(?i)((?:^|[?&\s])state=)[^&\s"']+`), + regexp.MustCompile(`(?i)(access_token"?\s*[:=]\s*"?)[^,}&\s"']+`), + regexp.MustCompile(`(?i)(refresh_token"?\s*[:=]\s*"?)[^,}&\s"']+`), + regexp.MustCompile(`(?i)(id_token"?\s*[:=]\s*"?)[^,}&\s"']+`), + regexp.MustCompile(`(?i)(authorization:?\s*bearer\s+)\S+`), +} + +// Redact 脱敏 code/token/authorization(D7 最小脱敏,UC1 提前到 Phase 1) +func Redact(s string) string { + for _, p := range redactPatterns { + s = p.ReplaceAllString(s, "${1}********") + } + return s +} + +// IsStdinTTY 判断 stdin 是否为交互终端(AP-1)。 +// 不能用 os.ModeCharDevice:/dev/null 也是字符设备,cron/CI 重定向会被误判为交互。 +// go-isatty 走真实终端检查(unix ioctl / windows console API),Cygwin/mintty 下 stdin 是管道,单独判。 +func IsStdinTTY() bool { + fd := os.Stdin.Fd() + return isatty.IsTerminal(fd) || isatty.IsCygwinTerminal(fd) +} + +// OAuthLoginRequiredHint oauth 模式但 token 缺失时的提示(AP-1/AP-3,走 stderr) +func OAuthLoginRequiredHint(profile string, isTTY bool) string { + if isTTY { + return fmt.Sprintf("Profile '%s' uses OAuth login but has no token. Run 'ucloud auth login' first", profile) + } + return fmt.Sprintf("Profile '%s' uses OAuth login, which cannot work in a non-interactive environment. For automation/CI, use an AK/SK profile: ucloud config --profile --public-key --private-key ", profile) +} + +// OAuthRefreshFailedHint refresh_token 失效/刷新失败时的提示(AP-3 模板) +func OAuthRefreshFailedHint(profile string, isTTY bool, err error) string { + if isTTY { + return fmt.Sprintf("Login expired for profile '%s' (%s). Run 'ucloud auth login' again", profile, Redact(err.Error())) + } + return fmt.Sprintf("OAuth login for profile '%s' cannot be renewed in a non-interactive environment (%s). For unattended scenarios, use an AK/SK profile instead", profile, Redact(err.Error())) +} + +// CheckOAuthRunnable oauth 模式启动检查;ok=false 时调用方应将 msg 输出到 stderr 并以非零码退出。 +// 此处刻意忽略 ExpiresAt——过期由 EnsureFreshToken(Task 6)处理,本函数只检查 token 是否存在。 +func CheckOAuthRunnable(cfg *AggConfig, isTTY bool) (string, bool) { + if cfg.AccessToken == "" || cfg.RefreshToken == "" { + return OAuthLoginRequiredHint(cfg.Profile, isTTY), false + } + return "", true +} + +// TokenResponse /token 端点响应(流程步骤④) +type TokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token"` + ExpiresIn int64 `json:"expires_in"` + TokenType string `json:"token_type"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +// oauthHTTPClient 使用默认 Transport:自动遵守 HTTPS_PROXY/HTTP_PROXY/NO_PROXY(ProxyFromEnvironment) +var oauthHTTPClient = &http.Client{Timeout: 30 * time.Second} + +func requestToken(oauthBase string, form url.Values) (*TokenResponse, error) { + endpoint := strings.TrimSuffix(oauthBase, "/") + "/token" + resp, err := oauthHTTPClient.PostForm(endpoint, form) + if err != nil { + return nil, fmt.Errorf("cannot reach oauth server %s (check network or proxy settings): %v", endpoint, err) + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read oauth server response failed: %v", err) + } + var tr TokenResponse + if jerr := json.Unmarshal(body, &tr); jerr != nil { + if resp.StatusCode >= 500 { + return nil, fmt.Errorf("oauth server error (HTTP %d), retry later", resp.StatusCode) + } + return nil, fmt.Errorf("unexpected oauth server response (HTTP %d): %s", resp.StatusCode, Redact(string(body))) + } + if tr.Error != "" { + return nil, translateOAuthError(tr.Error, tr.ErrorDescription) + } + if resp.StatusCode >= 500 { + return nil, fmt.Errorf("oauth server error (HTTP %d), retry later", resp.StatusCode) + } + if tr.AccessToken == "" { + return nil, fmt.Errorf("oauth server returned no access_token (HTTP %d)", resp.StatusCode) + } + return &tr, nil +} + +// translateOAuthError 按 AP-3 模板翻译 OAuth 错误码:原因 + 下一步命令 +func translateOAuthError(code, desc string) error { + switch code { + case "invalid_grant": + return fmt.Errorf("authorization code or refresh token expired or already used (each code works only once). Run 'ucloud auth login' again and paste the URL promptly") + case "access_denied": + return fmt.Errorf("authorization was denied. Run 'ucloud auth login' to try again") + default: + return fmt.Errorf("oauth server rejected the request: %s (%s). Run 'ucloud auth login' to start over", code, Redact(desc)) + } +} + +// ExchangeToken 授权码换 token(流程步骤④) +func ExchangeToken(oauthBase, redirectURI, code string) (*TokenResponse, error) { + form := url.Values{} + form.Set("grant_type", "authorization_code") + form.Set("code", code) + form.Set("client_id", oauthClientID) + form.Set("client_secret", oauthClientSecret) + form.Set("redirect_uri", redirectURI) + return requestToken(oauthBase, form) +} + +// RefreshToken 刷新 access_token;响应中的新 refresh_token 表示轮换(D3:旧的立即作废,必须写回) +func RefreshToken(oauthBase, refreshToken string) (*TokenResponse, error) { + form := url.Values{} + form.Set("grant_type", "refresh_token") + form.Set("refresh_token", refreshToken) + form.Set("client_id", oauthClientID) + form.Set("client_secret", oauthClientSecret) + return requestToken(oauthBase, form) +} + +// ApplyTokenResponse 把 /token 响应写入 cfg;轮换语义:响应带新 refresh_token 则覆盖(D3) +func ApplyTokenResponse(cfg *AggConfig, tr *TokenResponse) { + cfg.AuthMode = AuthModeOAuth + cfg.AccessToken = tr.AccessToken + if tr.RefreshToken != "" { + cfg.RefreshToken = tr.RefreshToken + } + cfg.ExpiresAt = time.Now().Unix() + tr.ExpiresIn +} + +// EnsureFreshToken 主动刷新(D6):过期(含 5min 偏斜余量)则 refresh 并写回。 +// 「刷新+写回」由 refreshAndSave 内的 flock 串行化,拿锁后重读磁盘(Task 11)。 +func EnsureFreshToken(cfg *AggConfig, manager *AggConfigManager) error { + if !TokenExpired(cfg.ExpiresAt) { + return nil + } + return refreshAndSave(cfg, manager) +} + +// credentialLockPath flock 锁文件路径;包级变量便于测试注入 +var credentialLockPath = "" + +func getCredentialLockPath() string { + if credentialLockPath != "" { + return credentialLockPath + } + return GetConfigDir() + "/credential.lock" +} + +// credentialLockTimeout 拿锁超时(D3:超时明确报错) +const credentialLockTimeout = 10 * time.Second + +// refreshAndSave 串行化「刷新+写回」临界区(D3/D9.4): +// flock 跨进程互斥 → 拿锁后重读磁盘(他进程可能已刷新并轮换)→ 仍过期才真正刷新。 +func refreshAndSave(cfg *AggConfig, manager *AggConfigManager) error { + staleToken := cfg.AccessToken + + fl := flock.New(getCredentialLockPath()) + ctx, cancel := context.WithTimeout(context.Background(), credentialLockTimeout) + defer cancel() + ok, err := fl.TryLockContext(ctx, 200*time.Millisecond) + if err != nil && !errors.Is(err, context.DeadlineExceeded) { + // 硬错误(如锁文件无权限),与拿锁超时是两回事,必须带上原始错误 + return fmt.Errorf("acquire credential lock %s failed: %v", getCredentialLockPath(), err) + } + if !ok { + return fmt.Errorf("timed out acquiring credential lock %s after %v: another ucloud process may be refreshing, retry later", getCredentialLockPath(), credentialLockTimeout) + } + defer fl.Unlock() + + // 拿锁后重读:他进程已刷新则直接采用,避免用已作废的 refresh_token 二次刷新 + if disk, derr := readCredentialFromDisk(manager.credPath, cfg.Profile); derr == nil && disk != nil { + if disk.AccessToken != "" && disk.AccessToken != staleToken && !TokenExpired(disk.ExpiresAt) { + cfg.AccessToken = disk.AccessToken + cfg.RefreshToken = disk.RefreshToken + cfg.ExpiresAt = disk.ExpiresAt + cfg.AuthMode = AuthModeOAuth + return nil + } + if disk.RefreshToken != "" { + cfg.RefreshToken = disk.RefreshToken // 轮换后的最新 refresh_token 以磁盘为准 + } + } + + oauthBase, err := GetOAuthBaseURL(cfg) + if err != nil { + return err + } + tr, err := RefreshToken(oauthBase, cfg.RefreshToken) + if err != nil { + return err + } + ApplyTokenResponse(cfg, tr) + // Save() 会把内存里全部 profile 整写落盘,而本进程内存是 t0 快照:他进程可能已在 + // t0 之后轮换了其它 profile 的 refresh_token(D3 旧的立即作废)。落盘前重读磁盘, + // 把「非当前 profile」的 oauth 字段以磁盘为准合并,否则会把轮换结果覆盖回陈旧值, + // 导致对方 profile 下次刷新 invalid_grant(被迫重新登录)。 + if creds, rerr := readAllCredentialsFromDisk(manager.credPath); rerr == nil { + mergeOtherProfilesOAuthFromDisk(manager.configs, creds, cfg.Profile) + } + if err := manager.Save(); err != nil { + return fmt.Errorf("token refreshed but saving credential failed: %v", err) + } + return nil +} + +// mergeOtherProfilesOAuthFromDisk 把磁盘版凭据中「非当前 profile」的 oauth 四字段 +// (auth_mode/access_token/refresh_token/expires_at)合并进内存。只合并这四个字段: +// flock 临界区内唯一的合法并发写就是 oauth 刷新轮换,AK/SK、cookie 等字段不受锁保护、 +// 不在此处静默采纳。当前 profile 保持本次刷新后的内存值。 +func mergeOtherProfilesOAuthFromDisk(configs map[string]*AggConfig, diskCreds []CredentialConfig, currentProfile string) { + for i := range diskCreds { + dc := &diskCreds[i] + if dc.Profile == currentProfile { + continue + } + ac, ok := configs[dc.Profile] + if !ok { + continue + } + ac.AuthMode = dc.AuthMode + ac.AccessToken = dc.AccessToken + ac.RefreshToken = dc.RefreshToken + ac.ExpiresAt = dc.ExpiresAt + } +} + +// readAllCredentialsFromDisk 重新读盘取全部 profile 的最新凭据(不经 manager 缓存) +func readAllCredentialsFromDisk(credPath string) ([]CredentialConfig, error) { + raw, err := ioutil.ReadFile(credPath) + if err != nil { + return nil, err + } + if len(raw) == 0 { + return nil, nil + } + var creds []CredentialConfig + if err := json.Unmarshal(raw, &creds); err != nil { + return nil, err + } + return creds, nil +} + +// readCredentialFromDisk 重新读盘取指定 profile 的最新凭据(不经 manager 缓存) +func readCredentialFromDisk(credPath, profile string) (*CredentialConfig, error) { + creds, err := readAllCredentialsFromDisk(credPath) + if err != nil { + return nil, err + } + for i := range creds { + if creds[i].Profile == profile { + return &creds[i], nil + } + } + return nil, nil +} diff --git a/base/oauth_http_test.go b/base/oauth_http_test.go new file mode 100644 index 0000000000..2c6e1c6b44 --- /dev/null +++ b/base/oauth_http_test.go @@ -0,0 +1,119 @@ +// base/oauth_http_test.go +package base + +import ( + "fmt" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func tokenServer(t *testing.T, status int, body string, gotForm *map[string]string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/token" { + t.Errorf("unexpected path %s", r.URL.Path) + } + r.ParseForm() + if gotForm != nil { + m := map[string]string{} + for k := range r.PostForm { + m[k] = r.PostForm.Get(k) + } + *gotForm = m + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + fmt.Fprint(w, body) + })) +} + +// 换 token 3 分支:成功 / invalid_grant / 5xx +func TestExchangeToken(t *testing.T) { + t.Run("success", func(t *testing.T) { + var form map[string]string + s := tokenServer(t, 200, `{"access_token":"at1","refresh_token":"rt1","id_token":"idt","expires_in":3600,"token_type":"Bearer"}`, &form) + defer s.Close() + tr, err := ExchangeToken(s.URL, "http://localhost:8723/authorization", "code1") + if err != nil { + t.Fatal(err) + } + if tr.AccessToken != "at1" || tr.RefreshToken != "rt1" || tr.ExpiresIn != 3600 { + t.Errorf("unexpected token response: %+v", tr) + } + if form["grant_type"] != "authorization_code" || form["code"] != "code1" || form["redirect_uri"] != "http://localhost:8723/authorization" { + t.Errorf("bad form: %v", form) + } + }) + t.Run("invalid_grant translated", func(t *testing.T) { + s := tokenServer(t, 400, `{"error":"invalid_grant","error_description":"code expired"}`, nil) + defer s.Close() + _, err := ExchangeToken(s.URL, "http://localhost:8723/authorization", "old") + if err == nil || !strings.Contains(err.Error(), "ucloud auth login") || !strings.Contains(err.Error(), "expired or already used") { + t.Errorf("invalid_grant should translate to actionable message, got %v", err) + } + }) + t.Run("default error redacts description", func(t *testing.T) { + s := tokenServer(t, 400, `{"error":"invalid_request","error_description":"authorization: Bearer sk-secret-token-123"}`, nil) + defer s.Close() + _, err := ExchangeToken(s.URL, "http://localhost:8723/authorization", "c") + if err == nil || !strings.Contains(err.Error(), "rejected the request") { + t.Fatalf("default branch should surface rejection, got %v", err) + } + if strings.Contains(err.Error(), "sk-secret-token-123") { + t.Errorf("error_description must be redacted, got %v", err) + } + }) + t.Run("server 5xx", func(t *testing.T) { + s := tokenServer(t, 500, `oops`, nil) + defer s.Close() + _, err := ExchangeToken(s.URL, "http://localhost:8723/authorization", "c") + if err == nil || !strings.Contains(err.Error(), "server error") { + t.Errorf("5xx should say server error + retry, got %v", err) + } + }) +} + +// 刷新 3 分支:成功(轮换) / refresh 失效 / 网络不可达 +func TestRefreshToken(t *testing.T) { + t.Run("success with rotation", func(t *testing.T) { + var form map[string]string + s := tokenServer(t, 200, `{"access_token":"at2","refresh_token":"rt2-rotated","expires_in":3600}`, &form) + defer s.Close() + tr, err := RefreshToken(s.URL, "rt1") + if err != nil { + t.Fatal(err) + } + if tr.RefreshToken != "rt2-rotated" { + t.Errorf("rotated refresh token not surfaced: %+v", tr) + } + if form["grant_type"] != "refresh_token" || form["refresh_token"] != "rt1" { + t.Errorf("bad form: %v", form) + } + }) + t.Run("invalid refresh token", func(t *testing.T) { + s := tokenServer(t, 400, `{"error":"invalid_grant","error_description":"refresh token revoked"}`, nil) + defer s.Close() + if _, err := RefreshToken(s.URL, "dead"); err == nil { + t.Error("expect error for revoked refresh token") + } + }) + t.Run("unreachable", func(t *testing.T) { + _, err := RefreshToken("http://127.0.0.1:1", "rt") + if err == nil || !strings.Contains(err.Error(), "cannot reach oauth server") { + t.Errorf("network error should be distinguished, got %v", err) + } + }) +} + +// 钉死:oauthHTTPClient 必须遵守 HTTPS_PROXY 等代理环境变量(默认 Transport 或显式 ProxyFromEnvironment) +func TestOAuthClientHonorsProxyEnv(t *testing.T) { + if oauthHTTPClient.Transport == nil { + return // nil Transport == http.DefaultTransport,自带 ProxyFromEnvironment + } + tr, ok := oauthHTTPClient.Transport.(*http.Transport) + if !ok || tr.Proxy == nil { + t.Error("oauthHTTPClient custom transport must set Proxy: http.ProxyFromEnvironment") + } +} diff --git a/base/oauth_refresh_test.go b/base/oauth_refresh_test.go new file mode 100644 index 0000000000..112c23a1dc --- /dev/null +++ b/base/oauth_refresh_test.go @@ -0,0 +1,217 @@ +// base/oauth_refresh_test.go +package base + +import ( + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync" + "testing" + "time" +) + +func newTestManager(t *testing.T, ac *AggConfig) *AggConfigManager { + t.Helper() + os.MkdirAll(".ucloud", 0700) + t.Cleanup(func() { os.RemoveAll(".ucloud") }) + credentialLockPath = ".ucloud/credential.lock" + t.Cleanup(func() { credentialLockPath = "" }) + m, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") + if err != nil { + t.Fatal(err) + } + if err := m.Append(ac); err != nil { + t.Fatal(err) + } + return m +} + +func TestApplyTokenResponse(t *testing.T) { + cfg := &AggConfig{Profile: "p", RefreshToken: "old-rt"} + ApplyTokenResponse(cfg, &TokenResponse{AccessToken: "at", ExpiresIn: 3600}) + if cfg.AuthMode != AuthModeOAuth || cfg.AccessToken != "at" { + t.Errorf("token not applied: %+v", cfg) + } + if cfg.RefreshToken != "old-rt" { + t.Error("empty refresh_token in response must keep the old one") + } + if cfg.ExpiresAt < time.Now().Unix()+3500 || cfg.ExpiresAt > time.Now().Unix()+3700 { + t.Errorf("expires_at wrong: %d", cfg.ExpiresAt) + } + // 轮换:新 refresh_token 覆盖旧(D3) + ApplyTokenResponse(cfg, &TokenResponse{AccessToken: "at2", RefreshToken: "new-rt", ExpiresIn: 3600}) + if cfg.RefreshToken != "new-rt" { + t.Error("rotated refresh_token must overwrite") + } +} + +func TestEnsureFreshToken(t *testing.T) { + refreshCalls := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + refreshCalls++ + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"new-at","refresh_token":"new-rt","expires_in":3600}`) + })) + defer s.Close() + + t.Run("expired triggers refresh and persists", func(t *testing.T) { + ac := &AggConfig{ + Profile: "p1", Active: true, BaseURL: DefaultBaseURL, Timeout: 15, MaxRetryTimes: intPtr(3), + AuthMode: AuthModeOAuth, AccessToken: "old-at", RefreshToken: "old-rt", + ExpiresAt: time.Now().Unix() - 100, OAuthBaseURL: s.URL, + } + m := newTestManager(t, ac) + if err := EnsureFreshToken(ac, m); err != nil { + t.Fatal(err) + } + if ac.AccessToken != "new-at" || ac.RefreshToken != "new-rt" { + t.Errorf("token not refreshed in memory: %+v", ac) + } + raw, _ := ioutil.ReadFile(".ucloud/credential.json") + if !strings.Contains(string(raw), "new-rt") { + t.Errorf("rotated refresh token not persisted: %s", raw) + } + }) + + t.Run("fresh token skips refresh", func(t *testing.T) { + before := refreshCalls + ac := &AggConfig{ + Profile: "p2", Active: true, BaseURL: DefaultBaseURL, Timeout: 15, MaxRetryTimes: intPtr(3), + AuthMode: AuthModeOAuth, AccessToken: "at", RefreshToken: "rt", + ExpiresAt: time.Now().Add(time.Hour).Unix(), OAuthBaseURL: s.URL, + } + m := newTestManager(t, ac) + if err := EnsureFreshToken(ac, m); err != nil { + t.Fatal(err) + } + if refreshCalls != before { + t.Error("fresh token must not hit /token") + } + }) +} + +// 跨 profile 凭据保护:进程 A(t0 加载 X/Y)刷新 Y 落盘时,不得用内存里的陈旧 X +// 覆盖他进程 B(t1)已轮换写盘的 X 凭据——否则 X 下次刷新必 invalid_grant(D3 旧 refresh_token 立即作废)。 +func TestRefreshAndSaveKeepsOtherProfilesRotatedTokens(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + fmt.Fprint(w, `{"access_token":"y-new-at","refresh_token":"y-new-rt","expires_in":3600}`) + })) + defer s.Close() + + // t0:进程 A 加载两个 oauth profile(X 未过期,Y 已过期) + acX := &AggConfig{ + Profile: "px", Active: true, BaseURL: DefaultBaseURL, Timeout: 15, MaxRetryTimes: intPtr(3), + AuthMode: AuthModeOAuth, AccessToken: "x-old-at", RefreshToken: "x-old-rt", + ExpiresAt: time.Now().Add(time.Hour).Unix(), + } + m := newTestManager(t, acX) + acY := &AggConfig{ + Profile: "py", BaseURL: DefaultBaseURL, Timeout: 15, MaxRetryTimes: intPtr(3), + AuthMode: AuthModeOAuth, AccessToken: "y-old-at", RefreshToken: "y-old-rt", + ExpiresAt: time.Now().Unix() - 100, OAuthBaseURL: s.URL, + } + if err := m.Append(acY); err != nil { + t.Fatal(err) + } + + // t1:模拟进程 B 刷新 X 并轮换 refresh_token,直接写盘(不经 A 的 manager) + raw, err := ioutil.ReadFile(".ucloud/credential.json") + if err != nil { + t.Fatal(err) + } + var creds []CredentialConfig + if err := json.Unmarshal(raw, &creds); err != nil { + t.Fatal(err) + } + for i := range creds { + if creds[i].Profile == "px" { + creds[i].AccessToken = "x-rotated-at" + creds[i].RefreshToken = "x-rotated-rt" + creds[i].ExpiresAt = time.Now().Add(2 * time.Hour).Unix() + } + } + out, err := json.Marshal(creds) + if err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(".ucloud/credential.json", out, 0600); err != nil { + t.Fatal(err) + } + + // t2:进程 A 刷新 Y 并 Save + if err := EnsureFreshToken(acY, m); err != nil { + t.Fatal(err) + } + + diskX, err := readCredentialFromDisk(".ucloud/credential.json", "px") + if err != nil || diskX == nil { + t.Fatalf("reload px from disk failed: %v (%v)", diskX, err) + } + if diskX.AccessToken != "x-rotated-at" || diskX.RefreshToken != "x-rotated-rt" { + t.Errorf("process B's rotated X tokens were overwritten by A's stale copy: access=%s refresh=%s", + diskX.AccessToken, diskX.RefreshToken) + } + diskY, err := readCredentialFromDisk(".ucloud/credential.json", "py") + if err != nil || diskY == nil { + t.Fatalf("reload py from disk failed: %v (%v)", diskY, err) + } + if diskY.AccessToken != "y-new-at" || diskY.RefreshToken != "y-new-rt" { + t.Errorf("Y's refreshed tokens not persisted: access=%s refresh=%s", diskY.AccessToken, diskY.RefreshToken) + } +} + +// 并发刷新仅一次轮换(D3):两个并发 EnsureFreshToken 只允许打一次 /token +func TestConcurrentRefreshSingleRotation(t *testing.T) { + var mu sync.Mutex + refreshCalls := 0 + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + mu.Lock() + refreshCalls++ + n := refreshCalls + mu.Unlock() + time.Sleep(100 * time.Millisecond) // 放大竞争窗口 + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"access_token":"at-%d","refresh_token":"rt-%d","expires_in":3600}`, n, n) + })) + defer s.Close() + + ac := &AggConfig{ + Profile: "pc", Active: true, BaseURL: DefaultBaseURL, Timeout: 15, MaxRetryTimes: intPtr(3), + AuthMode: AuthModeOAuth, AccessToken: "old-at", RefreshToken: "old-rt", + ExpiresAt: time.Now().Unix() - 100, OAuthBaseURL: s.URL, + } + m := newTestManager(t, ac) + + // 模拟两个进程:各自持有独立的 AggConfig 副本与 manager 视图。 + // 注意必须用独立 manager(m2):若复用 m,ac2 的刷新结果不在 m.configs 中, + // Save() 不会落盘,磁盘重读看到的仍是旧凭据,测不出真实的跨进程行为。 + m2, err := NewAggConfigManager(".ucloud/config.json", ".ucloud/credential.json") + if err != nil { + t.Fatal(err) + } + ac2, ok := m2.GetAggConfigByProfile("pc") + if !ok { + t.Fatal("profile pc not loaded by second manager") + } + + var wg sync.WaitGroup + errs := make([]error, 2) + wg.Add(2) + go func() { defer wg.Done(); errs[0] = EnsureFreshToken(ac, m) }() + go func() { defer wg.Done(); errs[1] = EnsureFreshToken(ac2, m2) }() + wg.Wait() + + for i, err := range errs { + if err != nil { + t.Errorf("refresher %d failed: %v", i, err) + } + } + if refreshCalls != 1 { + t.Errorf("expect exactly 1 rotation, got %d", refreshCalls) + } +} diff --git a/base/oauth_test.go b/base/oauth_test.go new file mode 100644 index 0000000000..4d19aabb07 --- /dev/null +++ b/base/oauth_test.go @@ -0,0 +1,171 @@ +// base/oauth_test.go +package base + +import ( + "os" + "strings" + "testing" + "time" +) + +// /dev/null 是字符设备但不是终端;cron/CI 常用 `ucloud xxx =1024),返回 listener 与端口。 +func allocateLoopbackListener() (net.Listener, int, error) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + return nil, 0, fmt.Errorf("cannot open a local callback port: %v", err) + } + port := ln.Addr().(*net.TCPAddr).Port + return ln, port, nil +} + +type callbackResult struct { + code string + err error +} + +const callbackSuccessHTML = ` + +Login successful + +

Login successful

+

You can close this tab and return to the terminal.

+

登录成功,可关闭此页面返回终端。

+ +` + +// startCallbackServer 在给定 listener 上起一个临时 HTTP server,只处理 GET /authorization。 +// 结果通过返回的 channel(缓冲1)投递,sync.Once 保证只投递一次。投递规则: +// - error 参数(如 access_denied)→ 回 400 并投递错误(中止登录); +// - code + state 匹配 → 回成功页并投递 code; +// - 缺 code 或 state 不匹配(本地探针/陈旧标签页等噪音请求)→ 仅回 400、不投递, +// 继续等待真正的回调(上层 loginCallbackTimeout 兜底)。 +func startCallbackServer(ln net.Listener, expectState string) (*http.Server, <-chan callbackResult) { + ch := make(chan callbackResult, 1) + var once sync.Once + + srv := &http.Server{} + mux := http.NewServeMux() + mux.HandleFunc("/authorization", func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + if e := q.Get("error"); e != "" { + var err error + if e == "access_denied" { + err = fmt.Errorf("authorization was denied in the browser. Run 'ucloud auth login' to try again") + } else { + err = fmt.Errorf("oauth server returned error %q. Run 'ucloud auth login' to try again", e) + } + http.Error(w, "Login failed. Return to the terminal for details.", http.StatusBadRequest) + once.Do(func() { + ch <- callbackResult{err: err} + }) + return + } + + code := q.Get("code") + if code == "" || q.Get("state") != expectState { + // 噪音请求:不消耗 once,登录继续等待真正的回调 + http.Error(w, "Login failed. Return to the terminal for details.", http.StatusBadRequest) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, callbackSuccessHTML) + once.Do(func() { + ch <- callbackResult{code: code} + }) + }) + srv.Handler = mux + + go srv.Serve(ln) + return srv, ch +} diff --git a/cmd/callback_test.go b/cmd/callback_test.go new file mode 100644 index 0000000000..577ce131b8 --- /dev/null +++ b/cmd/callback_test.go @@ -0,0 +1,126 @@ +// cmd/callback_test.go +package cmd + +import ( + "fmt" + "net/http" + "testing" + "time" +) + +// setupCallback 在已分配 listener 上起 callback server,返回端口与结果 channel。 +func setupCallback(t *testing.T, expectState string) (int, <-chan callbackResult) { + t.Helper() + ln, port, err := allocateLoopbackListener() + if err != nil { + t.Fatalf("allocate listener: %v", err) + } + srv, ch := startCallbackServer(ln, expectState) + t.Cleanup(func() { srv.Close() }) + return port, ch +} + +// get 向 callback server 发一个回调请求,返回 HTTP 状态码。 +func get(t *testing.T, port int, query string) int { + t.Helper() + url := fmt.Sprintf("http://127.0.0.1:%d/authorization?%s", port, query) + resp, err := http.Get(url) + if err != nil { + t.Fatalf("GET callback: %v", err) + } + resp.Body.Close() + return resp.StatusCode +} + +// drive 起 callback server,发一个回调请求,返回投递的结果(仅用于必然投递的场景)。 +func drive(t *testing.T, expectState, query string) callbackResult { + t.Helper() + port, ch := setupCallback(t, expectState) + get(t, port, query) + + select { + case res := <-ch: + return res + case <-time.After(3 * time.Second): + t.Fatal("callback result not delivered") + return callbackResult{} + } +} + +// assertNoDelivery 断言 channel 在短窗口内保持为空(噪音请求不得投递)。 +func assertNoDelivery(t *testing.T, ch <-chan callbackResult) { + t.Helper() + select { + case res := <-ch: + t.Fatalf("noise request must not deliver a result, got %+v", res) + case <-time.After(100 * time.Millisecond): + } +} + +func TestCallbackSuccess(t *testing.T) { + res := drive(t, "st", "code=abc&state=st") + if res.err != nil { + t.Fatalf("expected success, got err %v", res.err) + } + if res.code != "abc" { + t.Errorf("code = %q, want abc", res.code) + } +} + +// 噪音请求(state 不匹配,如陈旧标签页的旧回调):回 400 但不投递,继续等待真正的回调。 +func TestCallbackStateMismatch(t *testing.T) { + port, ch := setupCallback(t, "st") + if status := get(t, port, "code=x&state=WRONG"); status != http.StatusBadRequest { + t.Errorf("state-mismatch noise status = %d, want 400", status) + } + assertNoDelivery(t, ch) + + // 同一 server 上真正的回调仍然成功 + if status := get(t, port, "code=real&state=st"); status != http.StatusOK { + t.Errorf("genuine callback status = %d, want 200", status) + } + select { + case res := <-ch: + if res.err != nil { + t.Fatalf("genuine callback after noise: unexpected err %v", res.err) + } + if res.code != "real" { + t.Errorf("code = %q, want real", res.code) + } + case <-time.After(3 * time.Second): + t.Fatal("genuine callback not delivered after noise request") + } +} + +// 噪音请求(有 state 无 code,如本地探针):回 400 但不投递,继续等待真正的回调。 +func TestCallbackStateWithoutCode(t *testing.T) { + port, ch := setupCallback(t, "st") + if status := get(t, port, "state=st"); status != http.StatusBadRequest { + t.Errorf("missing-code noise status = %d, want 400", status) + } + assertNoDelivery(t, ch) + + // 同一 server 上真正的回调仍然成功 + if status := get(t, port, "code=abc&state=st"); status != http.StatusOK { + t.Errorf("genuine callback status = %d, want 200", status) + } + select { + case res := <-ch: + if res.err != nil { + t.Fatalf("genuine callback after noise: unexpected err %v", res.err) + } + if res.code != "abc" { + t.Errorf("code = %q, want abc", res.code) + } + case <-time.After(3 * time.Second): + t.Fatal("genuine callback not delivered after noise request") + } +} + +// error 参数(用户在浏览器里拒绝授权)是明确的失败信号:必须投递并中止登录。 +func TestCallbackAccessDenied(t *testing.T) { + res := drive(t, "st", "error=access_denied&state=st") + if res.err == nil { + t.Fatal("expected access_denied error") + } +} diff --git a/cmd/configure.go b/cmd/configure.go index fdc5682ed0..1af95dedb8 100644 --- a/cmd/configure.go +++ b/cmd/configure.go @@ -37,8 +37,7 @@ const helloUcloud = ` | | | | __/ | | (_) | | |_| | \__/\ | (_) | |_| | (_| | \_| |_/\___|_|_|\___/ \___/ \____/_|\___/ \__,_|\__,_| -If you want add or modify your configurations, run 'ucloud config add/update' -` +If you want add or modify your configurations, run 'ucloud config add/update'` // NewCmdInit ucloud init func NewCmdInit() *cobra.Command { @@ -47,7 +46,22 @@ func NewCmdInit() *cobra.Command { Short: "Initialize UCloud CLI options", Long: `Initialize UCloud CLI options such as private-key,public-key,default region,zone and project.`, Run: func(cmd *cobra.Command, args []string) { + fromOAuth := base.ConfigIns.AuthMode == base.AuthModeOAuth + if fromOAuth { + ok := base.Confirm(false, fmt.Sprintf("Profile '%s' currently uses OAuth login (auth_mode=oauth). Continue with AK/SK setup and switch this profile to key-based auth? (y/n):", base.ConfigIns.Profile)) + if !ok { + return + } + clearOAuthState(base.ConfigIns) + } + if base.ConfigIns.PrivateKey != "" && base.ConfigIns.PublicKey != "" { + if fromOAuth { + if err := switchProfileToAKSK(base.ConfigIns); err != nil { + base.HandleError(err) + return + } + } printHello() return } @@ -92,7 +106,7 @@ func NewCmdInit() *cobra.Command { fmt.Printf("Active profile name:%s\n", base.ConfigIns.Profile) fmt.Println("You can change the default settings by running 'ucloud config update'") base.ConfigIns.ConfigUploadLog() - err = base.AggConfigListIns.Append(base.ConfigIns) + err = saveInitProfile(base.ConfigIns) if err != nil { base.HandleError(fmt.Errorf("Error: %v", err)) } else { @@ -104,6 +118,26 @@ func NewCmdInit() *cobra.Command { return cmd } +// saveInitProfile 持久化 init 完整配置流程的结果;profile 已存在时(OAuth-only profile +// 切回 AK/SK 的场景)覆盖保存——依赖 ConfigIns 即 manager map 内的同一指针(InitConfig 保证) +func saveInitProfile(cfg *base.AggConfig) error { + return base.AggConfigListIns.UpdateAggConfig(cfg) +} + +// clearOAuthState 清除 profile 的 oauth 状态(口径与 'ucloud auth logout' 一致),不落盘 +func clearOAuthState(cfg *base.AggConfig) { + cfg.AuthMode = "" + cfg.AccessToken = "" + cfg.RefreshToken = "" + cfg.ExpiresAt = 0 +} + +// switchProfileToAKSK 把 OAuth profile 切回 AK/SK:清除 oauth 状态并落盘 +func switchProfileToAKSK(cfg *base.AggConfig) error { + clearOAuthState(cfg) + return base.AggConfigListIns.UpdateAggConfig(cfg) +} + func printHello() { userInfo, err := getUserInfo() if err != nil { @@ -437,32 +471,11 @@ func NewCmdConfigUpdate() *cobra.Command { base.AggConfigListIns.UpdateAggConfig(cacheConfig) } - //如有设置Region和Zone,确保设置的Region和Zone真实存在 - if cfg.Region != "" { - cacheConfig.Region = cfg.Region - } - if cfg.Zone != "" { - cacheConfig.Zone = cfg.Zone - } - - region, zone, err := getReasonableRegionZone(cacheConfig) - if err != nil { - base.HandleError(err) - return - } - - cacheConfig.Region = region - cacheConfig.Zone = zone - - if cfg.ProjectID != "" { - cacheConfig.ProjectID = base.PickResourceID(cfg.ProjectID) - } - - project, err := getReasonableProject(cacheConfig) - if err != nil { - base.HandleError(err) + //先应用连接类参数(base-url/timeout-sec/max-retry-times),确保接下来的远程校验 + //打到新网关而不是旧的(可能已不可用的)网关,避免旧base-url坏掉后无法改回的死锁 + if cfg.BaseURL != "" { + cacheConfig.BaseURL = cfg.BaseURL } - cacheConfig.ProjectID = project if timeout != "" { seconds, err := strconv.Atoi(timeout) @@ -492,10 +505,33 @@ func NewCmdConfigUpdate() *cobra.Command { return } - if cfg.BaseURL != "" { - cacheConfig.BaseURL = cfg.BaseURL + //如有设置Region和Zone,确保设置的Region和Zone真实存在 + if cfg.Region != "" { + cacheConfig.Region = cfg.Region + } + if cfg.Zone != "" { + cacheConfig.Zone = cfg.Zone + } + + region, zone, err := getReasonableRegionZone(cacheConfig) + if err != nil { + base.HandleError(err) + return } + cacheConfig.Region = region + cacheConfig.Zone = zone + + if cfg.ProjectID != "" { + cacheConfig.ProjectID = base.PickResourceID(cfg.ProjectID) + } + + project, err := getReasonableProject(cacheConfig) + if err != nil { + base.HandleError(err) + } + cacheConfig.ProjectID = project + if active == "true" { cacheConfig.Active = true } else if active == "false" { diff --git a/cmd/configure_test.go b/cmd/configure_test.go new file mode 100644 index 0000000000..508049f3ce --- /dev/null +++ b/cmd/configure_test.go @@ -0,0 +1,205 @@ +package cmd + +import ( + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/ucloud/ucloud-cli/base" +) + +// 回归:oauth profile 已存 AK/SK 时(auth login 保留密钥的常见形态), +// init 确认切回 AK/SK 后必须把 auth_mode/token 清除并落盘,否则下次启动仍走 OAuth +func TestSwitchProfileToAKSKPersistsToDisk(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + credPath := filepath.Join(dir, "credential.json") + cliJSON := `[{"profile":"oa","active":true,"region":"cn-bj2","zone":"cn-bj2-04","base_url":"https://api.ucloud.cn/","timeout_sec":15,"max_retry_times":3}]` + credJSON := `[{"public_key":"pub","private_key":"pri","profile":"oa","auth_mode":"oauth","access_token":"at","refresh_token":"rt","expires_at":1234567890}]` + if err := ioutil.WriteFile(cfgPath, []byte(cliJSON), base.LocalFileMode); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(credPath, []byte(credJSON), base.LocalFileMode); err != nil { + t.Fatal(err) + } + + m, err := base.NewAggConfigManager(cfgPath, credPath) + if err != nil { + t.Fatal(err) + } + cfg, ok := m.GetAggConfigByProfile("oa") + if !ok { + t.Fatal("profile oa missing") + } + + oldM, oldC := base.AggConfigListIns, base.ConfigIns + base.AggConfigListIns, base.ConfigIns = m, cfg + defer func() { base.AggConfigListIns, base.ConfigIns = oldM, oldC }() + + if err := switchProfileToAKSK(cfg); err != nil { + t.Fatal(err) + } + + // 重新读盘验证持久化,而非只看内存 + m2, err := base.NewAggConfigManager(cfgPath, credPath) + if err != nil { + t.Fatal(err) + } + got, ok := m2.GetAggConfigByProfile("oa") + if !ok { + t.Fatal("profile oa missing after reload") + } + if got.AuthMode != "" || got.AccessToken != "" || got.RefreshToken != "" || got.ExpiresAt != 0 { + t.Errorf("oauth state must be cleared on disk, got auth_mode=%q access_token=%q refresh_token=%q expires_at=%d", + got.AuthMode, got.AccessToken, got.RefreshToken, got.ExpiresAt) + } + if got.PublicKey != "pub" || got.PrivateKey != "pri" { + t.Errorf("AK/SK must survive the switch, got public_key=%q private_key=%q", got.PublicKey, got.PrivateKey) + } +} + +// fakeGatewayServer 模拟业务网关:响应远程校验所需的 GetRegion/GetProjectList +func fakeGatewayServer(t *testing.T) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, _ := io.ReadAll(r.Body) + payload := r.URL.RawQuery + string(body) + w.Header().Set("Content-Type", "application/json") + switch { + case strings.Contains(payload, "GetRegion"): + fmt.Fprint(w, `{"RetCode":0,"Action":"GetRegionResponse","Regions":[{"Region":"cn-bj2","Zone":"cn-bj2-04","IsDefault":true}]}`) + case strings.Contains(payload, "GetProjectList"): + fmt.Fprint(w, `{"RetCode":0,"Action":"GetProjectListResponse","ProjectSet":[{"ProjectId":"org-123","ProjectName":"Default","IsDefault":true}]}`) + default: + fmt.Fprint(w, `{"RetCode":230,"Message":"unexpected action"}`) + } + })) +} + +// 回归:config update --base-url 必须在远程校验(getReasonableRegionZone 等)之前生效, +// 否则旧 base_url 指向坏网关时校验永远打到坏网关,新地址无法保存(鸡生蛋死锁)。 +func TestConfigUpdateAppliesBaseURLBeforeValidation(t *testing.T) { + gateway := fakeGatewayServer(t) + defer gateway.Close() + + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + credPath := filepath.Join(dir, "credential.json") + // 存量 base_url 指向必然连不通的地址,复现坏网关现场 + cliJSON := `[{"profile":"up","active":true,"project_id":"org-123","region":"cn-bj2","zone":"cn-bj2-04","base_url":"http://127.0.0.1:1/","timeout_sec":3,"max_retry_times":0}]` + credJSON := `[{"public_key":"pub","private_key":"pri","profile":"up"}]` + if err := ioutil.WriteFile(cfgPath, []byte(cliJSON), base.LocalFileMode); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(credPath, []byte(credJSON), base.LocalFileMode); err != nil { + t.Fatal(err) + } + + m, err := base.NewAggConfigManager(cfgPath, credPath) + if err != nil { + t.Fatal(err) + } + + // GetBizClient 会改写包级全局 ClientConfig/AuthCredential,恢复现场避免测试顺序耦合 + oldM, oldCC, oldAC := base.AggConfigListIns, base.ClientConfig, base.AuthCredential + base.AggConfigListIns = m + defer func() { + base.AggConfigListIns, base.ClientConfig, base.AuthCredential = oldM, oldCC, oldAC + }() + + cmd := NewCmdConfigUpdate() + if err := cmd.Flags().Set("profile", "up"); err != nil { + t.Fatal(err) + } + if err := cmd.Flags().Set("base-url", gateway.URL); err != nil { + t.Fatal(err) + } + cmd.Run(cmd, nil) + + // 重新读盘验证持久化,而非只看内存 + m2, err := base.NewAggConfigManager(cfgPath, credPath) + if err != nil { + t.Fatal(err) + } + got, ok := m2.GetAggConfigByProfile("up") + if !ok { + t.Fatal("profile up missing after reload") + } + if got.BaseURL != gateway.URL { + t.Errorf("base_url on disk = %q, want new gateway %q (remote validation must run against the NEW base-url)", got.BaseURL, gateway.URL) + } +} + +// 回归:OAuth-only profile(auth_mode=oauth 且未存 AK/SK,auth login 直接创建的形态) +// 执行 init 确认切回 AK/SK 并走完整配置流程后,末尾持久化不能因 profile 已存在而失败, +// 否则整套新配置(密钥、region、project)全部不落盘 +func TestInitSaveOverwritesExistingOAuthOnlyProfile(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, "config.json") + credPath := filepath.Join(dir, "credential.json") + cliJSON := `[{"profile":"oa","active":true,"base_url":"https://api.ucloud.cn/","timeout_sec":15,"max_retry_times":3}]` + credJSON := `[{"public_key":"","private_key":"","profile":"oa","auth_mode":"oauth","access_token":"at","refresh_token":"rt","expires_at":1234567890}]` + if err := ioutil.WriteFile(cfgPath, []byte(cliJSON), base.LocalFileMode); err != nil { + t.Fatal(err) + } + if err := ioutil.WriteFile(credPath, []byte(credJSON), base.LocalFileMode); err != nil { + t.Fatal(err) + } + + m, err := base.NewAggConfigManager(cfgPath, credPath) + if err != nil { + t.Fatal(err) + } + cfg, ok := m.GetAggConfigByProfile("oa") + if !ok { + t.Fatal("profile oa missing") + } + + oldM, oldC := base.AggConfigListIns, base.ConfigIns + base.AggConfigListIns, base.ConfigIns = m, cfg + defer func() { base.AggConfigListIns, base.ConfigIns = oldM, oldC }() + + // 模拟 NewCmdInit Run 完整配置路径对 ConfigIns(即 manager map 内同一指针)的写入 + clearOAuthState(cfg) + cfg.PublicKey = "newpub" + cfg.PrivateKey = "newpri" + cfg.Region = "cn-bj2" + cfg.Zone = "cn-bj2-04" + cfg.ProjectID = "org-new" + cfg.Timeout = base.DefaultTimeoutSec + cfg.BaseURL = base.DefaultBaseURL + cfg.Active = true + + if err := saveInitProfile(cfg); err != nil { + t.Fatalf("save must overwrite existing profile instead of failing, got: %v", err) + } + + // 重新读盘验证持久化,而非只看内存 + m2, err := base.NewAggConfigManager(cfgPath, credPath) + if err != nil { + t.Fatal(err) + } + got, ok := m2.GetAggConfigByProfile("oa") + if !ok { + t.Fatal("profile oa missing after reload") + } + if got.PublicKey != "newpub" || got.PrivateKey != "newpri" { + t.Errorf("new AK/SK must land on disk, got public_key=%q private_key=%q", got.PublicKey, got.PrivateKey) + } + if got.Region != "cn-bj2" || got.Zone != "cn-bj2-04" || got.ProjectID != "org-new" { + t.Errorf("region/zone/project must land on disk, got region=%q zone=%q project_id=%q", got.Region, got.Zone, got.ProjectID) + } + if got.AuthMode != "" { + t.Errorf("auth_mode must be cleared on disk, got %q", got.AuthMode) + } + // 切回 AK/SK 后 token 必须清除,口径与 switchProfileToAKSK / 'ucloud auth logout' 一致 + if got.AccessToken != "" || got.RefreshToken != "" || got.ExpiresAt != 0 { + t.Errorf("oauth tokens must be cleared on disk, got access_token=%q refresh_token=%q expires_at=%d", + got.AccessToken, got.RefreshToken, got.ExpiresAt) + } +} diff --git a/cmd/login.go b/cmd/login.go new file mode 100644 index 0000000000..0706da3b2c --- /dev/null +++ b/cmd/login.go @@ -0,0 +1,307 @@ +// cmd/login.go +package cmd + +import ( + "bufio" + "fmt" + "os" + "time" + + "github.com/spf13/cobra" + + "github.com/ucloud/ucloud-sdk-go/services/uaccount" + + "github.com/ucloud/ucloud-cli/base" +) + +const loginLongHelp = `Log in to UCloud via your browser (OAuth authorization code flow). + +How it works (default): + 1. ucloud-cli opens your browser at the UCloud authorization page. + 2. You log in and approve. The browser is redirected to a local callback + that ucloud-cli is listening on — captured automatically, no copy-paste. + 3. ucloud-cli exchanges the code for tokens, saves them to + ~/.ucloud/credential.json (0600), and auto-configures the default + region/zone/project for this profile. + +Headless / SSH: pass --no-browser. ucloud-cli prints the authorization URL; +open it on any device, log in, then copy the FULL callback URL from the +address bar and paste it back into the terminal. + +Tokens are valid for about 1 hour and renew silently via the refresh token. +OAuth login targets interactive human use. For scripts and CI/CD, use an +AK/SK profile instead: ucloud config --profile --public-key ... --private-key ...` + +// oauthHelpTmpl 在全局 helpTmpl(不渲染 Long)前面补上 Long 段,仅作用于 login/logout +const oauthHelpTmpl = `{{with (or .Long .Short)}}{{. | trimTrailingWhitespaces}} + +{{end}}` + helpTmpl + +// NewCmdAuth ucloud auth 命令组:浏览器登录相关子命令 +func NewCmdAuth() *cobra.Command { + cmd := &cobra.Command{ + Use: "auth", + Short: "Authenticate ucloud-cli via browser (OAuth)", + Long: "Browser-based OAuth authentication for ucloud-cli. Subcommands: login, logout", + } + cmd.SetHelpTemplate(oauthHelpTmpl) + cmd.AddCommand(NewCmdLogin()) + cmd.AddCommand(NewCmdLogout()) + return cmd +} + +// NewCmdLogin ucloud auth login +func NewCmdLogin() *cobra.Command { + var noBrowser bool + cmd := &cobra.Command{ + Use: "login", + Short: "Log in to UCloud via browser (OAuth)", + Long: loginLongHelp, + Args: cobra.NoArgs, + Example: "ucloud auth login\nucloud auth login --no-browser", + Run: func(cmd *cobra.Command, args []string) { + runLogin(noBrowser) + }, + } + cmd.Flags().BoolVar(&noBrowser, "no-browser", false, "Print the authorization URL instead of opening a browser (for headless/SSH environments)") + cmd.SetHelpTemplate(oauthHelpTmpl) + return cmd +} + +func runLogin(noBrowser bool) { + // AP-1:非 TTY fail-fast + if !base.IsStdinTTY() { + fmt.Fprintln(os.Stderr, "'ucloud auth login' requires an interactive terminal. For automation/CI, use an AK/SK profile: ucloud config --profile --public-key --private-key ") + os.Exit(1) + } + + cfg := base.ConfigIns + oauthBase, err := base.GetOAuthBaseURL(cfg) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + state, err := base.GenerateState() + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + + var code, redirectURI string + if noBrowser { + code, redirectURI = runLoginManual(oauthBase, state) + } else { + code, redirectURI = runLoginAuto(oauthBase, state) + } + + tr, err := base.ExchangeToken(oauthBase, redirectURI, code) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + + // D5:已有 AK/SK 时打印一行告知 + 回退指引 + if cfg.PublicKey != "" || cfg.PrivateKey != "" { + fmt.Printf("Note: profile '%s' had AK/SK configured; it now switches to OAuth (auth_mode=oauth). AK/SK keys are kept; to switch back, run 'ucloud auth logout' then 'ucloud init'\n", cfg.Profile) + } + + base.ApplyTokenResponse(cfg, tr) + cfg.Active = true + if _, ok := base.AggConfigListIns.GetAggConfigByProfile(cfg.Profile); ok { + err = base.AggConfigListIns.UpdateAggConfig(cfg) + } else { + err = base.AggConfigListIns.Append(cfg) + } + if err != nil { + fmt.Fprintf(os.Stderr, "save credential failed: %v\n", err) + os.Exit(1) + } + + // AP-2:首登补链——自动配置 region/zone/project(Bearer 调用,复用 init 逻辑) + if cfg.Region == "" || cfg.Zone == "" { + region, rerr := fetchRegionWithConfig(cfg) + if rerr != nil { + fmt.Printf("Warning: fetch default region failed (%v). Set it later: ucloud config update --profile %s --region --zone \n", rerr, cfg.Profile) + } else { + cfg.Region = region.DefaultRegion + cfg.Zone = region.DefaultZone + fmt.Printf("Configured default region:%s zone:%s\n", cfg.Region, cfg.Zone) + } + } + // 既有 project_id 也要用新账号的项目列表校验:跨账号/跨站点遗留的 project_id + // 若原样保留,后续业务命令全部 RetCode 292 "Project not exists" + if projects, perr := fetchProjectListWithConfig(cfg); perr != nil { + fmt.Printf("Warning: fetch project list failed (%v). Set it later: ucloud config update --profile %s --project-id \n", perr, cfg.Profile) + } else if id, notice, rerr := resolveLoginProject(cfg.ProjectID, projects); rerr != nil { + fmt.Printf("Warning: resolve default project failed (%v). Set it later: ucloud config update --profile %s --project-id \n", rerr, cfg.Profile) + } else { + cfg.ProjectID = id + if notice != "" { + fmt.Println(notice) + } + } + if err := base.AggConfigListIns.UpdateAggConfig(cfg); err != nil { + fmt.Printf("Warning: saving default region/project failed (%v). Set them later: ucloud config update --profile %s --region --zone --project-id \n", err, cfg.Profile) + } + + // ⑥ 输出 email + 过期时间(id_token 仅解析不落盘) + until := time.Unix(cfg.ExpiresAt, 0).Format("15:04") + if email, eerr := base.ParseIDTokenEmail(tr.IDToken); eerr == nil && email != "" { + fmt.Printf("Logged in as %s, token valid until %s\n", email, until) + } else { + fmt.Printf("Logged in, token valid until %s\n", until) + } +} + +// resolveLoginProject 决定登录后 profile 应使用的 project(AP-2 的校验补丁): +// existing 为空 → 选账号默认项目(首登补链);existing 在列表内 → 保持不变,无提示; +// existing 不在列表内(跨账号/跨站点遗留)→ 切到默认项目并返回提示。 +// 返回 (projectID, notice);notice 非空时调用方原样打印。列表无默认项目时返回 errNoDefaultProject。 +func resolveLoginProject(existing string, projects []uaccount.ProjectListInfo) (string, string, error) { + var defaultID, defaultName string + for _, p := range projects { + if existing != "" && p.ProjectId == existing { + return existing, "", nil + } + if p.IsDefault { + defaultID, defaultName = p.ProjectId, p.ProjectName + } + } + if defaultID == "" { + return "", "", errNoDefaultProject + } + if existing == "" { + return defaultID, fmt.Sprintf("Configured default project:%s %s", defaultID, defaultName), nil + } + notice := fmt.Sprintf("Existing project '%s' does not belong to this account; switching to default project '%s' %s", existing, defaultID, defaultName) + return defaultID, notice, nil +} + +// loginCallbackTimeout 自动捕获的等待上限;超时回退到手工粘贴 +const loginCallbackTimeout = 3 * time.Minute + +// runLoginManual --no-browser 手工模式:分配一个 >=1024 端口(仅取号,立即释放 listener), +// 打印 URL,从 stdin 读回调 URL。返回 (code, redirectURI);出错时直接退出。 +func runLoginManual(oauthBase, state string) (string, string) { + ln, port, err := allocateLoopbackListener() + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + ln.Close() + redirectURI := base.BuildLoopbackRedirectURI(port) + authorizeURL := base.BuildAuthorizeURL(oauthBase, redirectURI, state) + + fmt.Println("Logging in via browser (manual paste). 3 steps:") + fmt.Println(" 1. Open the URL below and finish login & authorization.") + fmt.Println(" 2. The browser will be redirected to a localhost page that CANNOT") + fmt.Printf(" open (%s?...). THIS IS EXPECTED.\n", redirectURI) + fmt.Println(" 3. Copy the FULL URL from the address bar and paste it here.") + fmt.Println() + fmt.Printf("Open this URL in your browser:\n\n %s\n\n", authorizeURL) + + code, err := readCallbackCode(state) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + return code, redirectURI +} + +// runLoginAuto 默认模式:起本地回调 server 自动捕获 code,超时回退手工粘贴。 +// 返回 (code, redirectURI);遇到主动错误(拒绝授权/state 不匹配)直接退出。 +func runLoginAuto(oauthBase, state string) (string, string) { + ln, port, err := allocateLoopbackListener() + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + redirectURI := base.BuildLoopbackRedirectURI(port) + authorizeURL := base.BuildAuthorizeURL(oauthBase, redirectURI, state) + + srv, ch := startCallbackServer(ln, state) + + fmt.Println("A browser window will open; finish the login there and return here — no copy-paste needed.") + fmt.Printf("If it does not open, visit:\n\n %s\n\n", authorizeURL) + openbrowser(authorizeURL) + + select { + case res := <-ch: + srv.Close() + if res.err != nil { + fmt.Fprintln(os.Stderr, res.err) + os.Exit(1) + } + return res.code, redirectURI + case <-time.After(loginCallbackTimeout): + srv.Close() + fmt.Fprintln(os.Stderr, "Automatic capture timed out. Paste the callback URL here as a fallback:") + code, err := readCallbackCode(state) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } + return code, redirectURI + } +} + +// readCallbackCode 读回调 URL:容忍折行(粘贴的多行一次到达时合并),允许重试 3 次 +func readCallbackCode(state string) (string, error) { + reader := bufio.NewReader(os.Stdin) + for attempt := 1; attempt <= 3; attempt++ { + fmt.Print("Paste the full callback URL here: ") + raw, err := readWrappedLine(reader) + if err != nil { + return "", fmt.Errorf("read input failed: %v", err) + } + code, perr := base.ParseCallbackURL(raw, state) + if perr == nil { + return code, nil + } + fmt.Fprintln(os.Stderr, perr) + } + return "", fmt.Errorf("too many invalid inputs. Run 'ucloud auth login' again") +} + +// readWrappedLine 读一行;若粘贴内容因终端折行带来多行(缓冲区中仍有数据),继续读完合并 +func readWrappedLine(r *bufio.Reader) (string, error) { + line, err := r.ReadString('\n') + if err != nil && line == "" { + return "", err + } + for r.Buffered() > 0 { + next, nerr := r.ReadString('\n') + line += next + if nerr != nil { + break + } + } + return line, nil +} + +// NewCmdLogout ucloud auth logout +func NewCmdLogout() *cobra.Command { + cmd := &cobra.Command{ + Use: "logout", + Short: "Log out: remove local OAuth tokens of the current profile", + Long: "Log out: remove local OAuth tokens (access_token/refresh_token) of the current profile from ~/.ucloud/credential.json", + Args: cobra.NoArgs, + Example: "ucloud auth logout", + Run: func(cmd *cobra.Command, args []string) { + cfg := base.ConfigIns + if cfg.AuthMode != base.AuthModeOAuth && cfg.AccessToken == "" { + fmt.Printf("Profile '%s' is not logged in via OAuth, nothing to do\n", cfg.Profile) + return + } + clearOAuthState(cfg) + if err := base.AggConfigListIns.UpdateAggConfig(cfg); err != nil { + base.HandleError(err) + return + } + // AP-4:不加服务端有效期提示(用户裁定,spec 风险 #5) + fmt.Printf("Logged out: local tokens of profile '%s' removed\n", cfg.Profile) + }, + } + cmd.SetHelpTemplate(oauthHelpTmpl) + return cmd +} diff --git a/cmd/login_test.go b/cmd/login_test.go new file mode 100644 index 0000000000..fc3118457b --- /dev/null +++ b/cmd/login_test.go @@ -0,0 +1,59 @@ +package cmd + +import ( + "strings" + "testing" + + "github.com/ucloud/ucloud-sdk-go/services/uaccount" +) + +// 回归:auth login 后已有 project_id 必须用新账号的项目列表校验。 +// 跨账号/跨站点遗留的 project_id 若原样保留,后续业务命令全部 RetCode 292 "Project not exists"。 +func TestResolveLoginProject(t *testing.T) { + projects := []uaccount.ProjectListInfo{ + {ProjectId: "org-111", ProjectName: "Default", IsDefault: true}, + {ProjectId: "org-222", ProjectName: "Dev"}, + } + + // case 1: project_id 为空 → 选账号默认项目(原有首登补链行为) + id, notice, err := resolveLoginProject("", projects) + if err != nil { + t.Fatalf("empty existing: unexpected error: %v", err) + } + if id != "org-111" { + t.Errorf("empty existing: id = %q, want default org-111", id) + } + if !strings.Contains(notice, "org-111") || !strings.Contains(notice, "Default") { + t.Errorf("empty existing: notice = %q, want it to mention default project id and name", notice) + } + + // case 2: project_id 属于当前账号 → 保持不变且无提示(AP-2 不覆写用户设置) + id, notice, err = resolveLoginProject("org-222", projects) + if err != nil { + t.Fatalf("existing in list: unexpected error: %v", err) + } + if id != "org-222" { + t.Errorf("existing in list: id = %q, want kept org-222", id) + } + if notice != "" { + t.Errorf("existing in list: notice = %q, want empty (no behavior change)", notice) + } + + // case 3: project_id 不属于当前账号 → 切到默认项目并给出明确提示 + id, notice, err = resolveLoginProject("org-stale", projects) + if err != nil { + t.Fatalf("existing not in list: unexpected error: %v", err) + } + if id != "org-111" { + t.Errorf("existing not in list: id = %q, want default org-111", id) + } + if !strings.Contains(notice, "org-stale") || !strings.Contains(notice, "org-111") { + t.Errorf("existing not in list: notice = %q, want it to mention stale id and new default", notice) + } + + // case 4: 列表里没有默认项目 → 返回错误(调用方仅告警,不中断登录) + noDefault := []uaccount.ProjectListInfo{{ProjectId: "org-333", ProjectName: "Solo"}} + if _, _, err = resolveLoginProject("org-stale", noDefault); err == nil { + t.Error("no default project: want error, got nil") + } +} diff --git a/cmd/region.go b/cmd/region.go index 17275532fb..42f7b51c05 100644 --- a/cmd/region.go +++ b/cmd/region.go @@ -201,6 +201,21 @@ func getDefaultProjectWithConfig(cfg *base.AggConfig) (string, string, error) { return "", "", errNoDefaultProject } +// fetchProjectListWithConfig 用指定 profile 的凭证拉取完整项目列表(含默认标记) +func fetchProjectListWithConfig(cfg *base.AggConfig) ([]uaccount.ProjectListInfo, error) { + bc, err := base.GetBizClient(cfg) + if err != nil { + return nil, err + } + + req := bc.NewGetProjectListRequest() + resp, err := bc.GetProjectList(req) + if err != nil { + return nil, err + } + return resp.ProjectSet, nil +} + func fetchProjectWithConfig(cfg *base.AggConfig) (map[string]bool, error) { bc, err := base.GetBizClient(cfg) if err != nil { diff --git a/cmd/root.go b/cmd/root.go index 283f9f4a7f..cf96d30da8 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -107,6 +107,7 @@ Use "{{.CommandPath}} --help" for details.{{end}} func addChildren(root *cobra.Command) { out := base.Cxt.GetWriter() root.AddCommand(NewCmdInit()) + root.AddCommand(NewCmdAuth()) root.AddCommand(NewCmdDoc(out)) root.AddCommand(NewCmdConfig()) root.AddCommand(NewCmdRegion(out)) @@ -131,7 +132,7 @@ func addChildren(root *cobra.Command) { root.AddCommand(NewCmdAPI(out)) root.AddCommand(NewCmdSignature()) for _, c := range root.Commands() { - if c.Name() != "init" && c.Name() != "gendoc" && c.Name() != "config" { + if c.Name() != "init" && c.Name() != "gendoc" && c.Name() != "config" && c.Name() != "auth" { c.PersistentFlags().StringVar(&global.PublicKey, "public-key", global.PublicKey, "Set public-key to override the public-key in local config file") c.PersistentFlags().StringVar(&global.PrivateKey, "private-key", global.PrivateKey, "Set private-key to override the private-key in local config file") c.PersistentFlags().StringVar(&global.BaseURL, "base-url", "", "Set base-url to override the base-url in local config file") @@ -144,6 +145,13 @@ func addChildren(root *cobra.Command) { // Execute adds all child commands to the root command and sets flags appropriately. // This is called by main.main(). It only needs to happen once to the rootCmd. func Execute() { + // Phase 3 脱敏扩面:panic 路径兜底,避免 panic 消息(可能含 token/header)原样落到 stderr + defer func() { + if r := recover(); r != nil { + fmt.Fprintln(os.Stderr, base.Redact(fmt.Sprintf("panic: %v", r))) + os.Exit(1) + } + }() cmd := NewCmdRoot() if base.InCloudShell { err := base.InitConfigInCloudShell() @@ -156,7 +164,7 @@ func Execute() { mode := os.Getenv("UCLOUD_CLI_DEBUG") if mode == "on" || global.Debug { base.ClientConfig.LogLevel = log.DebugLevel - base.BizClient = base.NewClient(base.ClientConfig, base.AuthCredential) + base.BizClient = base.NewClient(base.ClientConfig, base.AuthCredential, base.ConfigIns) } addChildren(cmd) @@ -236,17 +244,64 @@ func initialize(cmd *cobra.Command) { base.ClientConfig.Zone = zone } - if (cmd.Name() != "config" && cmd.Name() != "init" && cmd.Name() != "version") && (cmd.Parent() != nil && cmd.Parent().Name() != "config") { - if base.InCloudShell { - return + if isAuthSkippedCmd(cmd) { + return + } + if base.InCloudShell { + return + } + + if base.ConfigIns.AuthMode == base.AuthModeOAuth { + // AP-1:oauth 凭据缺失/失效 → stderr + 非零退出(不复制下方 aksk 路径的 exit 0 反模式) + isTTY := base.IsStdinTTY() + if msg, ok := base.CheckOAuthRunnable(base.ConfigIns, isTTY); !ok { + fmt.Fprintln(os.Stderr, msg) + os.Exit(1) + } + if err := base.EnsureFreshToken(base.ConfigIns, base.AggConfigListIns); err != nil { + fmt.Fprintln(os.Stderr, base.OAuthRefreshFailedHint(base.ConfigIns.Profile, isTTY, err)) + os.Exit(1) } - if base.ConfigIns.PrivateKey == "" { - base.Cxt.Println("private-key is empty. Execute command 'ucloud init|config' to configure it or run 'ucloud config list' to check your configurations") - os.Exit(0) + // 刷新可能换了 token,重建 client 让新 Bearer 生效。 + // GetBizClient 会重建 ClientConfig 并硬编码 FatalLevel,若 Execute 已按 + // UCLOUD_CLI_DEBUG 设了 DebugLevel,需在重建后恢复(SDK logger 在 + // NewClient 时捕获 LogLevel,必须再 rebuild 一次才生效)。 + debugOn := base.ClientConfig.LogLevel == log.DebugLevel + bc, err := base.GetBizClient(base.ConfigIns) + if err != nil { + base.HandleError(err) + } else { + base.BizClient = bc } - if base.ConfigIns.PublicKey == "" { - base.Cxt.Println("public-key is empty. Execute command 'ucloud init|config' to configure it or run 'ucloud config list' to check your configurations") - os.Exit(0) + if debugOn { + base.ClientConfig.LogLevel = log.DebugLevel + base.BizClient = base.NewClient(base.ClientConfig, base.AuthCredential, base.ConfigIns) } + return + } + + // 既有 AK/SK 检查,原样保留(CRITICAL 回归约束:行为与文案零变化) + if base.ConfigIns.PrivateKey == "" { + base.Cxt.Println("private-key is empty. Execute command 'ucloud init|config' to configure it or run 'ucloud config list' to check your configurations") + os.Exit(0) + } + if base.ConfigIns.PublicKey == "" { + base.Cxt.Println("public-key is empty. Execute command 'ucloud init|config' to configure it or run 'ucloud config list' to check your configurations") + os.Exit(0) + } +} + +// isAuthSkippedCmd 启动凭据检查跳过清单(D7:login/logout/help/version/config/init) +func isAuthSkippedCmd(cmd *cobra.Command) bool { + if cmd.Parent() == nil { + return true // root 命令本身(--version/--config/help),与历史行为一致 + } + switch cmd.Name() { + case "config", "init", "version", "login", "logout", "help", "auth": + return true + } + if cmd.Parent() != nil && (cmd.Parent().Name() == "config" || cmd.Parent().Name() == "auth") { + return true } + return false } diff --git a/go.mod b/go.mod index 58b024d2b4..c77f9b186b 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,8 @@ go 1.19 require ( github.com/fatih/color v1.13.0 + github.com/gofrs/flock v0.8.1 + github.com/mattn/go-isatty v0.0.14 github.com/satori/go.uuid v1.2.0 github.com/sirupsen/logrus v1.3.0 github.com/spf13/cobra v0.0.3 @@ -18,7 +20,6 @@ require ( github.com/konsorten/go-windows-terminal-sequences v1.0.1 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/mattn/go-colorable v0.1.9 // indirect - github.com/mattn/go-isatty v0.0.14 // indirect github.com/pkg/errors v0.8.0 // indirect github.com/russross/blackfriday v1.5.2 // indirect golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 // indirect diff --git a/go.sum b/go.sum index 587a20d04f..ce2a8b3875 100644 --- a/go.sum +++ b/go.sum @@ -9,6 +9,8 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/fatih/color v1.13.0 h1:8LOYc1KYPPmyKMuN8QV2DNRWNbLo6LZ0iLs8+mlH53w= github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/gofrs/flock v0.8.1 h1:+gYjHKf32LDeiEEFhQaotPbLuUXjY5ZqxKgXy7n59aw= +github.com/gofrs/flock v0.8.1/go.mod h1:F1TvTiK9OcQqauNUHlbJvyl9Qa1QvF/gOUDKA14jxHU= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= @@ -49,8 +51,6 @@ github.com/spf13/viper v1.3.2/go.mod h1:ZiWeW+zYFKm7srdB9IoDzzZXaJaI5eL9QjNiN/DM github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= -github.com/ucloud/ucloud-sdk-go v0.22.17 h1:EFn+GxVKS5Tj8hIPie3qL6Zgk25fmWcHqJ06K8wl+Qo= -github.com/ucloud/ucloud-sdk-go v0.22.17/go.mod h1:dyLmFHmUfgb4RZKYQP9IArlvQ2pxzFthfhwxRzOEPIw= github.com/ucloud/ucloud-sdk-go v0.22.25 h1:ceKeH7WFnpUt9nJSubn+mnxS1iKGrk/Q+HLwa0iYwmQ= github.com/ucloud/ucloud-sdk-go v0.22.25/go.mod h1:dyLmFHmUfgb4RZKYQP9IArlvQ2pxzFthfhwxRzOEPIw= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= diff --git a/tests/oauth_cli_matrix.sh b/tests/oauth_cli_matrix.sh new file mode 100755 index 0000000000..5c18747a92 --- /dev/null +++ b/tests/oauth_cli_matrix.sh @@ -0,0 +1,89 @@ +#!/usr/bin/env bash +# tests/oauth_cli_matrix.sh — CLI 环境矩阵(D8):非 TTY / 无浏览器 / stdin pipe / init↔login 共存 / profile 切换 +# 用法:bash tests/oauth_cli_matrix.sh +# 黑盒驱动构建产物,用独立 HOME 沙箱,不触碰真实 ~/.ucloud +set -u + +ROOT="$(cd "$(dirname "$0")/.." && pwd)" +BIN="$ROOT/out/ucloud-matrix-test" + +# 先用真实 HOME 构建(asdf/版本管理器的 go shim 依赖 $HOME 解析工具链),再切沙箱 HOME +go build -mod=vendor -o "$BIN" "$ROOT/main.go" || { echo "build failed"; exit 1; } + +SANDBOX="$(mktemp -d)" +export HOME="$SANDBOX" +PASS=0; FAIL=0 + +check() { # check + local desc="$1" want="$2" got="$3" out="$4" needle="$5" + if [ "$got" = "$want" ] && echo "$out" | grep -q "$needle"; then + PASS=$((PASS+1)); echo "[OK] $desc" + else + FAIL=$((FAIL+1)); echo "[FAIL] $desc (exit=$got want=$want; output: $out)" + fi +} + +# 1. 非 TTY login fail-fast:stderr + 非零退出 +ERR=$(echo "" | "$BIN" auth login 2>&1 >/dev/null); RC=$? +check "non-tty login fail-fast to stderr" 1 "$RC" "$ERR" "interactive terminal" + +# 2. stdin pipe 跑业务命令(无任何配置):aksk 路径既有提示零回归(注意:历史行为 exit 0,保持) +OUT=$("$BIN" region 2>&1 "$HOME/.ucloud/config.json" <<'EOF' +[{"project_id":"org-x","region":"cn-bj2","zone":"cn-bj2-04","base_url":"https://api.ucloud.cn/","timeout_sec":15,"profile":"default","active":true,"max_retry_times":3}] +EOF +cat > "$HOME/.ucloud/credential.json" <<'EOF' +[{"public_key":"","private_key":"","cookie":"","csrf_token":"","profile":"default","auth_mode":"oauth"}] +EOF +# 注意:非 TTY 用 pipe 模拟而非 &1 >/dev/null); RC=$? +check "oauth missing token: stderr + nonzero + AK/SK pointer (non-tty)" 1 "$RC" "$ERR" "AK/SK" + +# 4. logout 在未登录 profile 上幂等 +cat > "$HOME/.ucloud/credential.json" <<'EOF' +[{"public_key":"pub","private_key":"pri","cookie":"","csrf_token":"","profile":"default"}] +EOF +OUT=$("$BIN" auth logout 2>&1); RC=$? +check "logout on non-oauth profile is a no-op" 0 "$RC" "$OUT" "not logged in" + +# 5. profile 切换:oauth profile + aksk profile 并存,--profile 选中 aksk 的不受 oauth 影响 +cat > "$HOME/.ucloud/config.json" <<'EOF' +[{"project_id":"org-x","region":"cn-bj2","zone":"cn-bj2-04","base_url":"https://api.ucloud.cn/","timeout_sec":15,"profile":"oa","active":true,"max_retry_times":3}, + {"project_id":"org-y","region":"cn-bj2","zone":"cn-bj2-04","base_url":"https://api.ucloud.cn/","timeout_sec":15,"profile":"ak","active":false,"max_retry_times":3}] +EOF +cat > "$HOME/.ucloud/credential.json" <<'EOF' +[{"public_key":"","private_key":"","cookie":"","csrf_token":"","profile":"oa","auth_mode":"oauth"}, + {"public_key":"pub","private_key":"pri","cookie":"","csrf_token":"","profile":"ak"}] +EOF +OUT=$("$BIN" config list 2>&1); RC=$? +check "config list shows AuthMode column" 0 "$RC" "$OUT" "AuthMode" +ERR=$(echo "" | "$BIN" region --profile oa 2>&1 >/dev/null); RC=$? +check "profile switch: oauth profile without token fails nonzero" 1 "$RC" "$ERR" "Profile 'oa'" + +# 6. --no-browser flag 存在(help 可见;真实流程属手动 E2E) +OUT=$("$BIN" auth login --help 2>&1); RC=$? +check "login --help mentions --no-browser" 0 "$RC" "$OUT" "no-browser" + +# 7. init 在 oauth profile(已存 AK/SK)上确认 y:auth_mode/token 必须清除并落盘 +# base_url/oauth_base_url 指向不可达地址,printHello/refresh 立刻失败,不出外网;只断言盘上状态 +cat > "$HOME/.ucloud/config.json" <<'EOF' +[{"project_id":"org-x","region":"cn-bj2","zone":"cn-bj2-04","base_url":"http://127.0.0.1:1/","oauth_base_url":"http://127.0.0.1:1/","timeout_sec":15,"profile":"default","active":true,"max_retry_times":0}] +EOF +cat > "$HOME/.ucloud/credential.json" <<'EOF' +[{"public_key":"pub","private_key":"pri","cookie":"","csrf_token":"","profile":"default","auth_mode":"oauth","access_token":"at","refresh_token":"rt","expires_at":123}] +EOF +printf 'y\n' | "$BIN" init >/dev/null 2>&1 +if ! grep -q '"auth_mode"' "$HOME/.ucloud/credential.json" && grep -q '"public_key": *"pub"' "$HOME/.ucloud/credential.json"; then + PASS=$((PASS+1)); echo "[OK] init on oauth profile with AK/SK: confirm y clears auth_mode on disk" +else + FAIL=$((FAIL+1)); echo "[FAIL] init on oauth profile with AK/SK: confirm y clears auth_mode on disk (credential: $(cat "$HOME/.ucloud/credential.json"))" +fi + +echo "" +echo "matrix result: $PASS passed, $FAIL failed" +rm -rf "$SANDBOX" "$BIN" +[ "$FAIL" -eq 0 ] diff --git a/vendor/github.com/gofrs/flock/.gitignore b/vendor/github.com/gofrs/flock/.gitignore new file mode 100644 index 0000000000..daf913b1b3 --- /dev/null +++ b/vendor/github.com/gofrs/flock/.gitignore @@ -0,0 +1,24 @@ +# Compiled Object files, Static and Dynamic libs (Shared Objects) +*.o +*.a +*.so + +# Folders +_obj +_test + +# Architecture specific extensions/prefixes +*.[568vq] +[568vq].out + +*.cgo1.go +*.cgo2.c +_cgo_defun.c +_cgo_gotypes.go +_cgo_export.* + +_testmain.go + +*.exe +*.test +*.prof diff --git a/vendor/github.com/gofrs/flock/.travis.yml b/vendor/github.com/gofrs/flock/.travis.yml new file mode 100644 index 0000000000..b16d040fa8 --- /dev/null +++ b/vendor/github.com/gofrs/flock/.travis.yml @@ -0,0 +1,10 @@ +language: go +go: + - 1.14.x + - 1.15.x +script: go test -v -check.vv -race ./... +sudo: false +notifications: + email: + on_success: never + on_failure: always diff --git a/vendor/github.com/gofrs/flock/LICENSE b/vendor/github.com/gofrs/flock/LICENSE new file mode 100644 index 0000000000..8b8ff36fe4 --- /dev/null +++ b/vendor/github.com/gofrs/flock/LICENSE @@ -0,0 +1,27 @@ +Copyright (c) 2015-2020, Tim Heckman +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of gofrs nor the names of its contributors may be used + to endorse or promote products derived from this software without + specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/gofrs/flock/README.md b/vendor/github.com/gofrs/flock/README.md new file mode 100644 index 0000000000..71ce63692e --- /dev/null +++ b/vendor/github.com/gofrs/flock/README.md @@ -0,0 +1,41 @@ +# flock +[![TravisCI Build Status](https://img.shields.io/travis/gofrs/flock/master.svg?style=flat)](https://travis-ci.org/gofrs/flock) +[![GoDoc](https://img.shields.io/badge/godoc-flock-blue.svg?style=flat)](https://godoc.org/github.com/gofrs/flock) +[![License](https://img.shields.io/badge/license-BSD_3--Clause-brightgreen.svg?style=flat)](https://github.com/gofrs/flock/blob/master/LICENSE) +[![Go Report Card](https://goreportcard.com/badge/github.com/gofrs/flock)](https://goreportcard.com/report/github.com/gofrs/flock) + +`flock` implements a thread-safe sync.Locker interface for file locking. It also +includes a non-blocking TryLock() function to allow locking without blocking execution. + +## License +`flock` is released under the BSD 3-Clause License. See the `LICENSE` file for more details. + +## Go Compatibility +This package makes use of the `context` package that was introduced in Go 1.7. As such, this +package has an implicit dependency on Go 1.7+. + +## Installation +``` +go get -u github.com/gofrs/flock +``` + +## Usage +```Go +import "github.com/gofrs/flock" + +fileLock := flock.New("/var/lock/go-lock.lock") + +locked, err := fileLock.TryLock() + +if err != nil { + // handle locking error +} + +if locked { + // do work + fileLock.Unlock() +} +``` + +For more detailed usage information take a look at the package API docs on +[GoDoc](https://godoc.org/github.com/gofrs/flock). diff --git a/vendor/github.com/gofrs/flock/appveyor.yml b/vendor/github.com/gofrs/flock/appveyor.yml new file mode 100644 index 0000000000..909b4bf7cb --- /dev/null +++ b/vendor/github.com/gofrs/flock/appveyor.yml @@ -0,0 +1,25 @@ +version: '{build}' + +build: false +deploy: false + +clone_folder: 'c:\gopath\src\github.com\gofrs\flock' + +environment: + GOPATH: 'c:\gopath' + GOVERSION: '1.15' + +init: + - git config --global core.autocrlf input + +install: + - rmdir c:\go /s /q + - appveyor DownloadFile https://storage.googleapis.com/golang/go%GOVERSION%.windows-amd64.msi + - msiexec /i go%GOVERSION%.windows-amd64.msi /q + - set Path=c:\go\bin;c:\gopath\bin;%Path% + - go version + - go env + +test_script: + - go get -t ./... + - go test -race -v ./... diff --git a/vendor/github.com/gofrs/flock/flock.go b/vendor/github.com/gofrs/flock/flock.go new file mode 100644 index 0000000000..95c784ca50 --- /dev/null +++ b/vendor/github.com/gofrs/flock/flock.go @@ -0,0 +1,144 @@ +// Copyright 2015 Tim Heckman. All rights reserved. +// Use of this source code is governed by the BSD 3-Clause +// license that can be found in the LICENSE file. + +// Package flock implements a thread-safe interface for file locking. +// It also includes a non-blocking TryLock() function to allow locking +// without blocking execution. +// +// Package flock is released under the BSD 3-Clause License. See the LICENSE file +// for more details. +// +// While using this library, remember that the locking behaviors are not +// guaranteed to be the same on each platform. For example, some UNIX-like +// operating systems will transparently convert a shared lock to an exclusive +// lock. If you Unlock() the flock from a location where you believe that you +// have the shared lock, you may accidentally drop the exclusive lock. +package flock + +import ( + "context" + "os" + "runtime" + "sync" + "time" +) + +// Flock is the struct type to handle file locking. All fields are unexported, +// with access to some of the fields provided by getter methods (Path() and Locked()). +type Flock struct { + path string + m sync.RWMutex + fh *os.File + l bool + r bool +} + +// New returns a new instance of *Flock. The only parameter +// it takes is the path to the desired lockfile. +func New(path string) *Flock { + return &Flock{path: path} +} + +// NewFlock returns a new instance of *Flock. The only parameter +// it takes is the path to the desired lockfile. +// +// Deprecated: Use New instead. +func NewFlock(path string) *Flock { + return New(path) +} + +// Close is equivalent to calling Unlock. +// +// This will release the lock and close the underlying file descriptor. +// It will not remove the file from disk, that's up to your application. +func (f *Flock) Close() error { + return f.Unlock() +} + +// Path returns the path as provided in NewFlock(). +func (f *Flock) Path() string { + return f.path +} + +// Locked returns the lock state (locked: true, unlocked: false). +// +// Warning: by the time you use the returned value, the state may have changed. +func (f *Flock) Locked() bool { + f.m.RLock() + defer f.m.RUnlock() + return f.l +} + +// RLocked returns the read lock state (locked: true, unlocked: false). +// +// Warning: by the time you use the returned value, the state may have changed. +func (f *Flock) RLocked() bool { + f.m.RLock() + defer f.m.RUnlock() + return f.r +} + +func (f *Flock) String() string { + return f.path +} + +// TryLockContext repeatedly tries to take an exclusive lock until one of the +// conditions is met: TryLock succeeds, TryLock fails with error, or Context +// Done channel is closed. +func (f *Flock) TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) { + return tryCtx(ctx, f.TryLock, retryDelay) +} + +// TryRLockContext repeatedly tries to take a shared lock until one of the +// conditions is met: TryRLock succeeds, TryRLock fails with error, or Context +// Done channel is closed. +func (f *Flock) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) { + return tryCtx(ctx, f.TryRLock, retryDelay) +} + +func tryCtx(ctx context.Context, fn func() (bool, error), retryDelay time.Duration) (bool, error) { + if ctx.Err() != nil { + return false, ctx.Err() + } + for { + if ok, err := fn(); ok || err != nil { + return ok, err + } + select { + case <-ctx.Done(): + return false, ctx.Err() + case <-time.After(retryDelay): + // try again + } + } +} + +func (f *Flock) setFh() error { + // open a new os.File instance + // create it if it doesn't exist, and open the file read-only. + flags := os.O_CREATE + if runtime.GOOS == "aix" { + // AIX cannot preform write-lock (ie exclusive) on a + // read-only file. + flags |= os.O_RDWR + } else { + flags |= os.O_RDONLY + } + fh, err := os.OpenFile(f.path, flags, os.FileMode(0600)) + if err != nil { + return err + } + + // set the filehandle on the struct + f.fh = fh + return nil +} + +// ensure the file handle is closed if no lock is held +func (f *Flock) ensureFhState() { + if !f.l && !f.r && f.fh != nil { + f.fh.Close() + f.fh = nil + } +} diff --git a/vendor/github.com/gofrs/flock/flock_aix.go b/vendor/github.com/gofrs/flock/flock_aix.go new file mode 100644 index 0000000000..7277c1b6b2 --- /dev/null +++ b/vendor/github.com/gofrs/flock/flock_aix.go @@ -0,0 +1,281 @@ +// Copyright 2019 Tim Heckman. All rights reserved. Use of this source code is +// governed by the BSD 3-Clause license that can be found in the LICENSE file. + +// Copyright 2018 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// This code implements the filelock API using POSIX 'fcntl' locks, which attach +// to an (inode, process) pair rather than a file descriptor. To avoid unlocking +// files prematurely when the same file is opened through different descriptors, +// we allow only one read-lock at a time. +// +// This code is adapted from the Go package: +// cmd/go/internal/lockedfile/internal/filelock + +//+build aix + +package flock + +import ( + "errors" + "io" + "os" + "sync" + "syscall" + + "golang.org/x/sys/unix" +) + +type lockType int16 + +const ( + readLock lockType = unix.F_RDLCK + writeLock lockType = unix.F_WRLCK +) + +type cmdType int + +const ( + tryLock cmdType = unix.F_SETLK + waitLock cmdType = unix.F_SETLKW +) + +type inode = uint64 + +type inodeLock struct { + owner *Flock + queue []<-chan *Flock +} + +var ( + mu sync.Mutex + inodes = map[*Flock]inode{} + locks = map[inode]inodeLock{} +) + +// Lock is a blocking call to try and take an exclusive file lock. It will wait +// until it is able to obtain the exclusive file lock. It's recommended that +// TryLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already exclusive-locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +// +// If the *Flock has a shared lock (RLock), this may transparently replace the +// shared lock with an exclusive lock on some UNIX-like operating systems. Be +// careful when using exclusive locks in conjunction with shared locks +// (RLock()), because calling Unlock() may accidentally release the exclusive +// lock that was once a shared lock. +func (f *Flock) Lock() error { + return f.lock(&f.l, writeLock) +} + +// RLock is a blocking call to try and take a shared file lock. It will wait +// until it is able to obtain the shared file lock. It's recommended that +// TryRLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already shared-locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +func (f *Flock) RLock() error { + return f.lock(&f.r, readLock) +} + +func (f *Flock) lock(locked *bool, flag lockType) error { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return err + } + defer f.ensureFhState() + } + + if _, err := f.doLock(waitLock, flag, true); err != nil { + return err + } + + *locked = true + return nil +} + +func (f *Flock) doLock(cmd cmdType, lt lockType, blocking bool) (bool, error) { + // POSIX locks apply per inode and process, and the lock for an inode is + // released when *any* descriptor for that inode is closed. So we need to + // synchronize access to each inode internally, and must serialize lock and + // unlock calls that refer to the same inode through different descriptors. + fi, err := f.fh.Stat() + if err != nil { + return false, err + } + ino := inode(fi.Sys().(*syscall.Stat_t).Ino) + + mu.Lock() + if i, dup := inodes[f]; dup && i != ino { + mu.Unlock() + return false, &os.PathError{ + Path: f.Path(), + Err: errors.New("inode for file changed since last Lock or RLock"), + } + } + + inodes[f] = ino + + var wait chan *Flock + l := locks[ino] + if l.owner == f { + // This file already owns the lock, but the call may change its lock type. + } else if l.owner == nil { + // No owner: it's ours now. + l.owner = f + } else if !blocking { + // Already owned: cannot take the lock. + mu.Unlock() + return false, nil + } else { + // Already owned: add a channel to wait on. + wait = make(chan *Flock) + l.queue = append(l.queue, wait) + } + locks[ino] = l + mu.Unlock() + + if wait != nil { + wait <- f + } + + err = setlkw(f.fh.Fd(), cmd, lt) + + if err != nil { + f.doUnlock() + if cmd == tryLock && err == unix.EACCES { + return false, nil + } + return false, err + } + + return true, nil +} + +func (f *Flock) Unlock() error { + f.m.Lock() + defer f.m.Unlock() + + // if we aren't locked or if the lockfile instance is nil + // just return a nil error because we are unlocked + if (!f.l && !f.r) || f.fh == nil { + return nil + } + + if err := f.doUnlock(); err != nil { + return err + } + + f.fh.Close() + + f.l = false + f.r = false + f.fh = nil + + return nil +} + +func (f *Flock) doUnlock() (err error) { + var owner *Flock + mu.Lock() + ino, ok := inodes[f] + if ok { + owner = locks[ino].owner + } + mu.Unlock() + + if owner == f { + err = setlkw(f.fh.Fd(), waitLock, unix.F_UNLCK) + } + + mu.Lock() + l := locks[ino] + if len(l.queue) == 0 { + // No waiters: remove the map entry. + delete(locks, ino) + } else { + // The first waiter is sending us their file now. + // Receive it and update the queue. + l.owner = <-l.queue[0] + l.queue = l.queue[1:] + locks[ino] = l + } + delete(inodes, f) + mu.Unlock() + + return err +} + +// TryLock is the preferred function for taking an exclusive file lock. This +// function takes an RW-mutex lock before it tries to lock the file, so there is +// the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the exclusive +// file lock, the function will return false instead of waiting for the lock. If +// we get the lock, we also set the *Flock instance as being exclusive-locked. +func (f *Flock) TryLock() (bool, error) { + return f.try(&f.l, writeLock) +} + +// TryRLock is the preferred function for taking a shared file lock. This +// function takes an RW-mutex lock before it tries to lock the file, so there is +// the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the shared file +// lock, the function will return false instead of waiting for the lock. If we +// get the lock, we also set the *Flock instance as being share-locked. +func (f *Flock) TryRLock() (bool, error) { + return f.try(&f.r, readLock) +} + +func (f *Flock) try(locked *bool, flag lockType) (bool, error) { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return true, nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return false, err + } + defer f.ensureFhState() + } + + haslock, err := f.doLock(tryLock, flag, false) + if err != nil { + return false, err + } + + *locked = haslock + return haslock, nil +} + +// setlkw calls FcntlFlock with cmd for the entire file indicated by fd. +func setlkw(fd uintptr, cmd cmdType, lt lockType) error { + for { + err := unix.FcntlFlock(fd, int(cmd), &unix.Flock_t{ + Type: int16(lt), + Whence: io.SeekStart, + Start: 0, + Len: 0, // All bytes. + }) + if err != unix.EINTR { + return err + } + } +} diff --git a/vendor/github.com/gofrs/flock/flock_unix.go b/vendor/github.com/gofrs/flock/flock_unix.go new file mode 100644 index 0000000000..c315a3e290 --- /dev/null +++ b/vendor/github.com/gofrs/flock/flock_unix.go @@ -0,0 +1,197 @@ +// Copyright 2015 Tim Heckman. All rights reserved. +// Use of this source code is governed by the BSD 3-Clause +// license that can be found in the LICENSE file. + +// +build !aix,!windows + +package flock + +import ( + "os" + "syscall" +) + +// Lock is a blocking call to try and take an exclusive file lock. It will wait +// until it is able to obtain the exclusive file lock. It's recommended that +// TryLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already exclusive-locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +// +// If the *Flock has a shared lock (RLock), this may transparently replace the +// shared lock with an exclusive lock on some UNIX-like operating systems. Be +// careful when using exclusive locks in conjunction with shared locks +// (RLock()), because calling Unlock() may accidentally release the exclusive +// lock that was once a shared lock. +func (f *Flock) Lock() error { + return f.lock(&f.l, syscall.LOCK_EX) +} + +// RLock is a blocking call to try and take a shared file lock. It will wait +// until it is able to obtain the shared file lock. It's recommended that +// TryRLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already shared-locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +func (f *Flock) RLock() error { + return f.lock(&f.r, syscall.LOCK_SH) +} + +func (f *Flock) lock(locked *bool, flag int) error { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return err + } + defer f.ensureFhState() + } + + if err := syscall.Flock(int(f.fh.Fd()), flag); err != nil { + shouldRetry, reopenErr := f.reopenFDOnError(err) + if reopenErr != nil { + return reopenErr + } + + if !shouldRetry { + return err + } + + if err = syscall.Flock(int(f.fh.Fd()), flag); err != nil { + return err + } + } + + *locked = true + return nil +} + +// Unlock is a function to unlock the file. This file takes a RW-mutex lock, so +// while it is running the Locked() and RLocked() functions will be blocked. +// +// This function short-circuits if we are unlocked already. If not, it calls +// syscall.LOCK_UN on the file and closes the file descriptor. It does not +// remove the file from disk. It's up to your application to do. +// +// Please note, if your shared lock became an exclusive lock this may +// unintentionally drop the exclusive lock if called by the consumer that +// believes they have a shared lock. Please see Lock() for more details. +func (f *Flock) Unlock() error { + f.m.Lock() + defer f.m.Unlock() + + // if we aren't locked or if the lockfile instance is nil + // just return a nil error because we are unlocked + if (!f.l && !f.r) || f.fh == nil { + return nil + } + + // mark the file as unlocked + if err := syscall.Flock(int(f.fh.Fd()), syscall.LOCK_UN); err != nil { + return err + } + + f.fh.Close() + + f.l = false + f.r = false + f.fh = nil + + return nil +} + +// TryLock is the preferred function for taking an exclusive file lock. This +// function takes an RW-mutex lock before it tries to lock the file, so there is +// the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the exclusive +// file lock, the function will return false instead of waiting for the lock. If +// we get the lock, we also set the *Flock instance as being exclusive-locked. +func (f *Flock) TryLock() (bool, error) { + return f.try(&f.l, syscall.LOCK_EX) +} + +// TryRLock is the preferred function for taking a shared file lock. This +// function takes an RW-mutex lock before it tries to lock the file, so there is +// the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the shared file +// lock, the function will return false instead of waiting for the lock. If we +// get the lock, we also set the *Flock instance as being share-locked. +func (f *Flock) TryRLock() (bool, error) { + return f.try(&f.r, syscall.LOCK_SH) +} + +func (f *Flock) try(locked *bool, flag int) (bool, error) { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return true, nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return false, err + } + defer f.ensureFhState() + } + + var retried bool +retry: + err := syscall.Flock(int(f.fh.Fd()), flag|syscall.LOCK_NB) + + switch err { + case syscall.EWOULDBLOCK: + return false, nil + case nil: + *locked = true + return true, nil + } + if !retried { + if shouldRetry, reopenErr := f.reopenFDOnError(err); reopenErr != nil { + return false, reopenErr + } else if shouldRetry { + retried = true + goto retry + } + } + + return false, err +} + +// reopenFDOnError determines whether we should reopen the file handle +// in readwrite mode and try again. This comes from util-linux/sys-utils/flock.c: +// Since Linux 3.4 (commit 55725513) +// Probably NFSv4 where flock() is emulated by fcntl(). +func (f *Flock) reopenFDOnError(err error) (bool, error) { + if err != syscall.EIO && err != syscall.EBADF { + return false, nil + } + if st, err := f.fh.Stat(); err == nil { + // if the file is able to be read and written + if st.Mode()&0600 == 0600 { + f.fh.Close() + f.fh = nil + + // reopen in read-write mode and set the filehandle + fh, err := os.OpenFile(f.path, os.O_CREATE|os.O_RDWR, os.FileMode(0600)) + if err != nil { + return false, err + } + f.fh = fh + return true, nil + } + } + + return false, nil +} diff --git a/vendor/github.com/gofrs/flock/flock_winapi.go b/vendor/github.com/gofrs/flock/flock_winapi.go new file mode 100644 index 0000000000..fe405a255a --- /dev/null +++ b/vendor/github.com/gofrs/flock/flock_winapi.go @@ -0,0 +1,76 @@ +// Copyright 2015 Tim Heckman. All rights reserved. +// Use of this source code is governed by the BSD 3-Clause +// license that can be found in the LICENSE file. + +// +build windows + +package flock + +import ( + "syscall" + "unsafe" +) + +var ( + kernel32, _ = syscall.LoadLibrary("kernel32.dll") + procLockFileEx, _ = syscall.GetProcAddress(kernel32, "LockFileEx") + procUnlockFileEx, _ = syscall.GetProcAddress(kernel32, "UnlockFileEx") +) + +const ( + winLockfileFailImmediately = 0x00000001 + winLockfileExclusiveLock = 0x00000002 + winLockfileSharedLock = 0x00000000 +) + +// Use of 0x00000000 for the shared lock is a guess based on some the MS Windows +// `LockFileEX` docs, which document the `LOCKFILE_EXCLUSIVE_LOCK` flag as: +// +// > The function requests an exclusive lock. Otherwise, it requests a shared +// > lock. +// +// https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx + +func lockFileEx(handle syscall.Handle, flags uint32, reserved uint32, numberOfBytesToLockLow uint32, numberOfBytesToLockHigh uint32, offset *syscall.Overlapped) (bool, syscall.Errno) { + r1, _, errNo := syscall.Syscall6( + uintptr(procLockFileEx), + 6, + uintptr(handle), + uintptr(flags), + uintptr(reserved), + uintptr(numberOfBytesToLockLow), + uintptr(numberOfBytesToLockHigh), + uintptr(unsafe.Pointer(offset))) + + if r1 != 1 { + if errNo == 0 { + return false, syscall.EINVAL + } + + return false, errNo + } + + return true, 0 +} + +func unlockFileEx(handle syscall.Handle, reserved uint32, numberOfBytesToLockLow uint32, numberOfBytesToLockHigh uint32, offset *syscall.Overlapped) (bool, syscall.Errno) { + r1, _, errNo := syscall.Syscall6( + uintptr(procUnlockFileEx), + 5, + uintptr(handle), + uintptr(reserved), + uintptr(numberOfBytesToLockLow), + uintptr(numberOfBytesToLockHigh), + uintptr(unsafe.Pointer(offset)), + 0) + + if r1 != 1 { + if errNo == 0 { + return false, syscall.EINVAL + } + + return false, errNo + } + + return true, 0 +} diff --git a/vendor/github.com/gofrs/flock/flock_windows.go b/vendor/github.com/gofrs/flock/flock_windows.go new file mode 100644 index 0000000000..ddb534ccef --- /dev/null +++ b/vendor/github.com/gofrs/flock/flock_windows.go @@ -0,0 +1,142 @@ +// Copyright 2015 Tim Heckman. All rights reserved. +// Use of this source code is governed by the BSD 3-Clause +// license that can be found in the LICENSE file. + +package flock + +import ( + "syscall" +) + +// ErrorLockViolation is the error code returned from the Windows syscall when a +// lock would block and you ask to fail immediately. +const ErrorLockViolation syscall.Errno = 0x21 // 33 + +// Lock is a blocking call to try and take an exclusive file lock. It will wait +// until it is able to obtain the exclusive file lock. It's recommended that +// TryLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +func (f *Flock) Lock() error { + return f.lock(&f.l, winLockfileExclusiveLock) +} + +// RLock is a blocking call to try and take a shared file lock. It will wait +// until it is able to obtain the shared file lock. It's recommended that +// TryRLock() be used over this function. This function may block the ability to +// query the current Locked() or RLocked() status due to a RW-mutex lock. +// +// If we are already locked, this function short-circuits and returns +// immediately assuming it can take the mutex lock. +func (f *Flock) RLock() error { + return f.lock(&f.r, winLockfileSharedLock) +} + +func (f *Flock) lock(locked *bool, flag uint32) error { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return err + } + defer f.ensureFhState() + } + + if _, errNo := lockFileEx(syscall.Handle(f.fh.Fd()), flag, 0, 1, 0, &syscall.Overlapped{}); errNo > 0 { + return errNo + } + + *locked = true + return nil +} + +// Unlock is a function to unlock the file. This file takes a RW-mutex lock, so +// while it is running the Locked() and RLocked() functions will be blocked. +// +// This function short-circuits if we are unlocked already. If not, it calls +// UnlockFileEx() on the file and closes the file descriptor. It does not remove +// the file from disk. It's up to your application to do. +func (f *Flock) Unlock() error { + f.m.Lock() + defer f.m.Unlock() + + // if we aren't locked or if the lockfile instance is nil + // just return a nil error because we are unlocked + if (!f.l && !f.r) || f.fh == nil { + return nil + } + + // mark the file as unlocked + if _, errNo := unlockFileEx(syscall.Handle(f.fh.Fd()), 0, 1, 0, &syscall.Overlapped{}); errNo > 0 { + return errNo + } + + f.fh.Close() + + f.l = false + f.r = false + f.fh = nil + + return nil +} + +// TryLock is the preferred function for taking an exclusive file lock. This +// function does take a RW-mutex lock before it tries to lock the file, so there +// is the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the exclusive +// file lock, the function will return false instead of waiting for the lock. If +// we get the lock, we also set the *Flock instance as being exclusive-locked. +func (f *Flock) TryLock() (bool, error) { + return f.try(&f.l, winLockfileExclusiveLock) +} + +// TryRLock is the preferred function for taking a shared file lock. This +// function does take a RW-mutex lock before it tries to lock the file, so there +// is the possibility that this function may block for a short time if another +// goroutine is trying to take any action. +// +// The actual file lock is non-blocking. If we are unable to get the shared file +// lock, the function will return false instead of waiting for the lock. If we +// get the lock, we also set the *Flock instance as being shared-locked. +func (f *Flock) TryRLock() (bool, error) { + return f.try(&f.r, winLockfileSharedLock) +} + +func (f *Flock) try(locked *bool, flag uint32) (bool, error) { + f.m.Lock() + defer f.m.Unlock() + + if *locked { + return true, nil + } + + if f.fh == nil { + if err := f.setFh(); err != nil { + return false, err + } + defer f.ensureFhState() + } + + _, errNo := lockFileEx(syscall.Handle(f.fh.Fd()), flag|winLockfileFailImmediately, 0, 1, 0, &syscall.Overlapped{}) + + if errNo > 0 { + if errNo == ErrorLockViolation || errNo == syscall.ERROR_IO_PENDING { + return false, nil + } + + return false, errNo + } + + *locked = true + + return true, nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 8c717cc1e8..b67c4470d3 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -4,6 +4,9 @@ github.com/cpuguy83/go-md2man/md2man # github.com/fatih/color v1.13.0 ## explicit; go 1.13 github.com/fatih/color +# github.com/gofrs/flock v0.8.1 +## explicit +github.com/gofrs/flock # github.com/inconshreveable/mousetrap v1.0.0 ## explicit github.com/inconshreveable/mousetrap