diff --git a/runner/internal/shim/components/utils.go b/runner/internal/shim/components/utils.go index a4456acaa3..5a5cb886ae 100644 --- a/runner/internal/shim/components/utils.go +++ b/runner/internal/shim/components/utils.go @@ -2,6 +2,8 @@ package components import ( "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" "io" @@ -9,6 +11,7 @@ import ( "os" "os/exec" "path/filepath" + "sort" "strings" "time" @@ -16,79 +19,244 @@ import ( "github.com/dstackai/dstack/runner/internal/common/utils" ) -const downloadTimeout = 10 * time.Minute +const ( + downloadTimeout = 10 * time.Minute + cacheSuffix = ".cache" + etagSuffix = ".etag" + // Max per-version caches kept next to a binary (bounds disk use). + maxCachedVersions = 5 +) +// downloadFile ensures path holds the artifact at url. +// +// Bytes are cached next to path (one cache per URL, validated by ETag), so a repeated +// or forced install of an unchanged version returns 304 and transfers nothing. Cached +// bytes are then chmod+renamed into place; a failed rename retries from cache without +// re-downloading. With force=false an existing path is left as-is. func downloadFile(ctx context.Context, url string, path string, mode os.FileMode, force bool) error { - if _, err := os.Stat(path); err == nil { - if force { - log.Debug(ctx, "file exists, forcing download", "path", path) - } else { + if !force { + if _, err := os.Stat(path); err == nil { log.Debug(ctx, "file exists, skipping download", "path", path) return nil + } else if !os.IsNotExist(err) { + return fmt.Errorf("check file exists: %w", err) } - } else if !os.IsNotExist(err) { - return fmt.Errorf("check file exists: %w", err) } - dir, name := filepath.Split(path) - tempFile, err := os.CreateTemp(dir, fmt.Sprintf(".*-%s", name)) + + // One cache file per URL so several versions can coexist. With a single shared + // cache, a request for a different version would overwrite it and force a + // re-download every time the requested version changes. 
+ key := urlKey(url) + cachePath := fmt.Sprintf("%s.%s%s", path, key, cacheSuffix) + etagPath := fmt.Sprintf("%s.%s%s", path, key, etagSuffix) + + downloaded, err := ensureCached(ctx, url, cachePath, etagPath) if err != nil { - return fmt.Errorf("create temp file for %s: %w", name, err) + return err } - defer func() { - if err := tempFile.Close(); err != nil { - log.Error(ctx, "close temp file", "err", err) + if downloaded { + pruneCaches(ctx, path, maxCachedVersions) + } + + // Install the cached bytes; skip if path already matches the cache. + if !downloaded { + installed, err := sameSize(path, cachePath) + if err != nil { + return err } - if err := os.Remove(tempFile.Name()); err != nil && !errors.Is(err, os.ErrNotExist) { - log.Error(ctx, "remove temp file", "err", err) + if installed { + return nil } - }() + } + return installFile(ctx, cachePath, path, mode) +} - log.Debug(ctx, "downloading", "path", path, "url", url) +// ensureCached makes cachePath hold url's current bytes, downloading only if the cache +// is missing or stale (per the stored ETag). Reports whether it downloaded. +func ensureCached(ctx context.Context, url string, cachePath string, etagPath string) (downloaded bool, err error) { ctx, cancel := context.WithTimeout(ctx, downloadTimeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { - return fmt.Errorf("create download request: %w", err) + return false, fmt.Errorf("create download request: %w", err) + } + // Revalidate only if we have both cached bytes and an ETag for them. 
+ if exists, _ := fileExists(cachePath); exists { + if etag := readETag(etagPath); etag != "" { + req.Header.Set("If-None-Match", etag) + } } resp, err := http.DefaultClient.Do(req) if err != nil { - return fmt.Errorf("execute download request: %w", err) + return false, fmt.Errorf("execute download request: %w", err) } - defer func() { - err := resp.Body.Close() - if err != nil { - log.Error(ctx, "downloadFile: close body error", "err", err) + if cerr := resp.Body.Close(); cerr != nil { + log.Error(ctx, "downloadFile: close body error", "err", cerr) } }() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("unexpected status code %s downloading %s from %s", resp.Status, name, url) + switch resp.StatusCode { + case http.StatusNotModified: + log.Debug(ctx, "cached artifact is up to date, skipping download", "url", url) + return false, nil + case http.StatusOK: + // fall through to download + default: + return false, fmt.Errorf("unexpected status code %s downloading from %s", resp.Status, url) } - written, err := io.Copy(tempFile, resp.Body) + log.Debug(ctx, "downloading", "path", cachePath, "url", url) + written, err := writeAtomic(ctx, cachePath, resp.Body) if err != nil { - log.Error(ctx, "download file", "err", err, "bytes", written, "total", resp.ContentLength) - if err := os.Remove(tempFile.Name()); err != nil { - log.Error(ctx, "remove temp file", "err", err) - } - return fmt.Errorf("copy %s: %w", name, err) + return false, err } - log.Debug(ctx, "file has been downloaded", "path", path, "bytes", written) + log.Debug(ctx, "file has been downloaded", "path", cachePath, "bytes", written) - if err := tempFile.Chmod(mode); err != nil { - return fmt.Errorf("chmod %s: %w", path, err) + // Remember the ETag for next time (best effort; if absent, next run does a full GET). 
+ if etag := resp.Header.Get("ETag"); etag != "" { + if werr := os.WriteFile(etagPath, []byte(etag), 0o644); werr != nil { + log.Warning(ctx, "failed to store etag", "path", etagPath, "err", werr) + } + } else if rerr := os.Remove(etagPath); rerr != nil && !errors.Is(rerr, os.ErrNotExist) { + log.Warning(ctx, "failed to remove stale etag", "path", etagPath, "err", rerr) } + return true, nil +} - if err := os.Rename(tempFile.Name(), path); err != nil { - return fmt.Errorf("move %s to %s: %w", name, path, err) +// installFile copies src to dst (with mode) via an atomic rename. No network, safe to retry. +func installFile(ctx context.Context, src string, dst string, mode os.FileMode) error { + in, err := os.Open(src) + if err != nil { + return fmt.Errorf("open cache %s: %w", src, err) } + defer func() { _ = in.Close() }() + written, err := writeAtomicMode(ctx, dst, in, mode) + if err != nil { + return err + } + log.Debug(ctx, "file has been installed", "path", dst, "bytes", written) return nil } +// writeAtomic streams r into dst via a temp file and an atomic rename. +func writeAtomic(ctx context.Context, dst string, r io.Reader) (int64, error) { + return writeAtomicMode(ctx, dst, r, 0) +} + +// writeAtomicMode streams r into dst via a temp file and an atomic rename, setting the +// file mode (when non-zero) before the rename. 
+func writeAtomicMode(ctx context.Context, dst string, r io.Reader, mode os.FileMode) (int64, error) { + dir, name := filepath.Split(dst) + tmp, err := os.CreateTemp(dir, fmt.Sprintf(".*-%s", name)) + if err != nil { + return 0, fmt.Errorf("create temp file for %s: %w", name, err) + } + defer cleanupTemp(ctx, tmp) + + written, err := io.Copy(tmp, r) + if err != nil { + return written, fmt.Errorf("copy %s: %w", name, err) + } + if mode != 0 { + if err := tmp.Chmod(mode); err != nil { + return written, fmt.Errorf("chmod %s: %w", dst, err) + } + } + if err := tmp.Close(); err != nil { + return written, fmt.Errorf("close %s: %w", name, err) + } + if err := os.Rename(tmp.Name(), dst); err != nil { + return written, fmt.Errorf("move %s to %s: %w", name, dst, err) + } + return written, nil +} + +// cleanupTemp best-effort removes the temp file (already gone after a successful rename). +func cleanupTemp(ctx context.Context, f *os.File) { + _ = f.Close() // may already be closed + if err := os.Remove(f.Name()); err != nil && !errors.Is(err, os.ErrNotExist) { + log.Error(ctx, "remove temp file", "err", err) + } +} + +// pruneCaches keeps the `keep` newest per-version caches next to path (and their .etag), +// removing the rest. 
+func pruneCaches(ctx context.Context, path string, keep int) { + matches, err := filepath.Glob(path + ".*" + cacheSuffix) + if err != nil || len(matches) <= keep { + return + } + type cacheFile struct { + name string + mod time.Time + } + files := make([]cacheFile, 0, len(matches)) + for _, m := range matches { + fi, err := os.Stat(m) + if err != nil { + continue + } + files = append(files, cacheFile{m, fi.ModTime()}) + } + if len(files) <= keep { + return + } + sort.Slice(files, func(i, j int) bool { return files[i].mod.After(files[j].mod) }) + for _, f := range files[keep:] { + if err := os.Remove(f.name); err != nil && !errors.Is(err, os.ErrNotExist) { + log.Warning(ctx, "prune cache: remove", "path", f.name, "err", err) + } + etag := strings.TrimSuffix(f.name, cacheSuffix) + etagSuffix + if err := os.Remove(etag); err != nil && !errors.Is(err, os.ErrNotExist) { + log.Warning(ctx, "prune cache: remove etag", "path", etag, "err", err) + } + } +} + +// urlKey returns a short, filesystem-safe key derived from url, used to name its cache file. +func urlKey(url string) string { + sum := sha256.Sum256([]byte(url)) + return hex.EncodeToString(sum[:])[:16] +} + +func fileExists(path string) (bool, error) { + if _, err := os.Stat(path); err == nil { + return true, nil + } else if errors.Is(err, os.ErrNotExist) { + return false, nil + } else { + return false, err + } +} + +func readETag(path string) string { + data, err := os.ReadFile(path) + if err != nil { + return "" + } + return strings.TrimSpace(string(data)) +} + +// sameSize reports whether a and b both exist with the same size -- a cheap "is path +// already this cached binary?" check (different versions differ in size). 
// Equal size is only a necessary condition, though: two different builds can have
// exactly the same length. When the sizes match, the contents are therefore compared
// as well (SHA-256 of each file) before a is declared identical to b. The size check
// stays as the fast path, so files of different length are never read.
//
// A missing a yields (false, nil); a missing or unreadable b is an error.
func sameSize(a string, b string) (bool, error) {
	ai, err := os.Stat(a)
	if errors.Is(err, os.ErrNotExist) {
		return false, nil
	}
	if err != nil {
		return false, err
	}
	bi, err := os.Stat(b)
	if err != nil {
		return false, err
	}
	if ai.Size() != bi.Size() {
		return false, nil
	}

	aSum, err := fileDigest(a)
	if errors.Is(err, os.ErrNotExist) {
		// a disappeared between the stat and the read: same answer as "not there".
		return false, nil
	}
	if err != nil {
		return false, err
	}
	bSum, err := fileDigest(b)
	if err != nil {
		return false, err
	}
	return aSum == bSum, nil
}

// fileDigest streams the file at path through SHA-256 and returns the digest.
func fileDigest(path string) (sum [sha256.Size]byte, err error) {
	f, err := os.Open(path)
	if err != nil {
		return sum, err
	}
	defer func() { _ = f.Close() }() // read-only handle: a close error carries no information

	h := sha256.New()
	if _, err := io.Copy(h, f); err != nil {
		return sum, err
	}
	h.Sum(sum[:0])
	return sum, nil
}

/*
Remainder of the original patch line, kept verbatim (unchanged context of the first
file, then the header and import block of the newly added test file):

 func checkDstackComponent(ctx context.Context, name ComponentName, pth string) (status ComponentStatus, version string, err error) {
 	exists, err := utils.PathExists(pth)
 	if err != nil {
diff --git a/runner/internal/shim/components/utils_test.go b/runner/internal/shim/components/utils_test.go
new file mode 100644
index 0000000000..2a4f13c13a
--- /dev/null
+++ b/runner/internal/shim/components/utils_test.go
@@ -0,0 +1,171 @@
+package components
+
+import (
+	"context"
+	"io"
+	"net/http"
+	"net/http/httptest"
+	"os"
+	"path/filepath"
+	"sync/atomic"
+	"testing"
+)
+
*/

// Repeatedly forcing an install of an unchanged artifact must not re-transfer the
// body -- the conditional request is answered with 304. This is the core guarantee
// that prevents the runaway re-download loop.
+func TestDownloadFileSkipsUnchangedUnderForce(t *testing.T) { + const body = "runner-binary-v1" + const etag = `"v1"` + var bodyServed int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("If-None-Match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("ETag", etag) + bodyServed++ + _, _ = io.WriteString(w, body) + })) + defer srv.Close() + + path := filepath.Join(t.TempDir(), "dstack-runner") + + for i := 0; i < 5; i++ { + if err := downloadFile(context.Background(), srv.URL, path, 0o755, true); err != nil { + t.Fatalf("attempt %d: downloadFile: %v", i, err) + } + } + + if bodyServed != 1 { + t.Fatalf("body served %d times; want 1 (forced re-installs must not re-transfer unchanged bytes)", bodyServed) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatalf("read installed file: %v", err) + } + if string(got) != body { + t.Fatalf("installed content = %q; want %q", got, body) + } + info, err := os.Stat(path) + if err != nil { + t.Fatal(err) + } + if info.Mode().Perm() != 0o755 { + t.Fatalf("installed mode = %v; want 0755", info.Mode().Perm()) + } +} + +// When the artifact actually changes (different ETag), it must be re-downloaded and +// path updated. 
+func TestDownloadFileRedownloadsWhenChanged(t *testing.T) { + bodies := []string{"v1-bytes", "v2-bytes-longer"} + etags := []string{`"v1"`, `"v2"`} + cur := 0 + var bodyServed int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("If-None-Match") == etags[cur] { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("ETag", etags[cur]) + bodyServed++ + _, _ = io.WriteString(w, bodies[cur]) + })) + defer srv.Close() + + path := filepath.Join(t.TempDir(), "dstack-runner") + + if err := downloadFile(context.Background(), srv.URL, path, 0o755, true); err != nil { + t.Fatal(err) + } + cur = 1 // artifact changes upstream + if err := downloadFile(context.Background(), srv.URL, path, 0o755, true); err != nil { + t.Fatal(err) + } + if bodyServed != 2 { + t.Fatalf("body served %d times; want 2 (a changed artifact must be re-downloaded)", bodyServed) + } + got, _ := os.ReadFile(path) + if string(got) != bodies[1] { + t.Fatalf("installed content = %q; want %q", got, bodies[1]) + } +} + +// Without force, an already-installed file is left untouched and no request is made +// (preserves prior behavior). +func TestDownloadFileSkipsWithoutForce(t *testing.T) { + var requests int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requests++ + w.Header().Set("ETag", `"x"`) + _, _ = io.WriteString(w, "x") + })) + defer srv.Close() + + path := filepath.Join(t.TempDir(), "dstack-runner") + if err := os.WriteFile(path, []byte("preexisting"), 0o755); err != nil { + t.Fatal(err) + } + if err := downloadFile(context.Background(), srv.URL, path, 0o755, false); err != nil { + t.Fatal(err) + } + if requests != 0 { + t.Fatalf("made %d requests; want 0 (existing file without force must not hit the network)", requests) + } +} + +// Force-installing two different versions in turn (e.g. 
two servers expecting +// different versions) must download each version once, then 304 -- not re-download on +// every swap. +func TestDownloadFileFlipFlopDownloadsEachVersionOnce(t *testing.T) { + var served23, served18 atomic.Int64 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var body, etag string + switch r.URL.Path { + case "/0.20.23/runner": + body, etag = "RUNNER-0.20.23", `"e23"` + case "/0.20.18/runner": + body, etag = "RUNNER-0.20.18-x", `"e18"` // intentionally different length + default: + w.WriteHeader(http.StatusNotFound) + return + } + if r.Header.Get("If-None-Match") == etag { + w.WriteHeader(http.StatusNotModified) + return + } + w.Header().Set("ETag", etag) + if r.URL.Path == "/0.20.23/runner" { + served23.Add(1) + } else { + served18.Add(1) + } + _, _ = io.WriteString(w, body) + })) + defer srv.Close() + + path := filepath.Join(t.TempDir(), "dstack-runner") + url23 := srv.URL + "/0.20.23/runner" + url18 := srv.URL + "/0.20.18/runner" + + for _, u := range []string{url23, url18, url23, url18, url23, url18} { + if err := downloadFile(context.Background(), u, path, 0o755, true); err != nil { + t.Fatalf("downloadFile(%s): %v", u, err) + } + } + + if got := served23.Load(); got != 1 { + t.Fatalf("0.20.23 body served %d times; want 1 (flip-flop must not re-download)", got) + } + if got := served18.Load(); got != 1 { + t.Fatalf("0.20.18 body served %d times; want 1 (flip-flop must not re-download)", got) + } + got, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + if string(got) != "RUNNER-0.20.18-x" { + t.Fatalf("installed content = %q; want the last requested version (RUNNER-0.20.18-x)", got) + } +} diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 8d57f73d6a..0a21d78c74 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -926,8 +926,11 @@ def 
get_shim_pre_start_commands( return [ f"dlpath=$(sudo mktemp -t {DSTACK_SHIM_BINARY_NAME}.XXXXXXXXXX)", # -sS -- disable progress meter and warnings, but still show errors (unlike bare -s) - f'sudo curl -sS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"', - f'sudo mv "$dlpath" {dstack_shim_binary_path}', + # -f -- fail (non-zero exit, no error body written) on HTTP errors, so a transient + # 403/5xx is never saved as the shim binary; chain `mv` so a failed download + # is never installed (otherwise it would run as a script -> "Syntax error"). + f'sudo curl -fsS --compressed --connect-timeout 60 --max-time 240 --retry 1 --output "$dlpath" "{url}"' + f' && sudo mv "$dlpath" {dstack_shim_binary_path}', f"sudo chmod +x {dstack_shim_binary_path}", f"{{ sudo chcon system_u:object_r:bin_t:s0 {dstack_shim_binary_path} 2>/dev/null || true; }}", f"sudo mkdir {dstack_working_dir} -p",