Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 46 additions & 17 deletions cmd/oras/root/cp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package root
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"slices"
"strings"

Expand All @@ -30,6 +32,7 @@ import (
"oras.land/oras-go/v2/registry"
"oras.land/oras-go/v2/registry/remote"
"oras.land/oras-go/v2/registry/remote/auth"
"oras.land/oras-go/v2/registry/remote/errcode"
"oras.land/oras/cmd/oras/internal/argument"
"oras.land/oras/cmd/oras/internal/command"
"oras.land/oras/cmd/oras/internal/display"
Expand Down Expand Up @@ -177,6 +180,7 @@ func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOn
return []string{mountRepo}, nil
}
}

dst, err = copyHandler.StartTracking(dst)
if err != nil {
return desc, err
Expand All @@ -187,34 +191,59 @@ func doCopy(ctx context.Context, copyHandler status.CopyHandler, src oras.ReadOn
err = stopErr
}
}()

// Hook up handlers
extendedCopyGraphOptions.OnCopySkipped = copyHandler.OnCopySkipped
extendedCopyGraphOptions.PreCopy = copyHandler.PreCopy
extendedCopyGraphOptions.PostCopy = copyHandler.PostCopy
extendedCopyGraphOptions.OnMounted = copyHandler.OnMounted

rOpts := oras.DefaultResolveOptions
rOpts.TargetPlatform = opts.Platform.Platform
if opts.recursive {
desc, err = oras.Resolve(ctx, src, opts.From.Reference, rOpts)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to resolve %s: %w", opts.From.Reference, err)

// Define the execution logic as a closure so we can retry it
executeCopy := func(copyOpts oras.ExtendedCopyGraphOptions) (ocispec.Descriptor, error) {
if opts.recursive {
root, resolveErr := oras.Resolve(ctx, src, opts.From.Reference, rOpts)
if resolveErr != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to resolve %s: %w", opts.From.Reference, resolveErr)
}
return root, recursiveCopy(ctx, src, dst, opts.To.Reference, root, copyOpts)
}
err = recursiveCopy(ctx, src, dst, opts.To.Reference, desc, extendedCopyGraphOptions)
} else {

if opts.To.Reference == "" {
desc, err = oras.Resolve(ctx, src, opts.From.Reference, rOpts)
if err != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to resolve %s: %w", opts.From.Reference, err)
}
err = oras.CopyGraph(ctx, src, dst, desc, extendedCopyGraphOptions.CopyGraphOptions)
} else {
copyOptions := oras.CopyOptions{
CopyGraphOptions: extendedCopyGraphOptions.CopyGraphOptions,
root, resolveErr := oras.Resolve(ctx, src, opts.From.Reference, rOpts)
if resolveErr != nil {
return ocispec.Descriptor{}, fmt.Errorf("failed to resolve %s: %w", opts.From.Reference, resolveErr)
}
if opts.Platform.Platform != nil {
copyOptions.WithTargetPlatform(opts.Platform.Platform)
return root, oras.CopyGraph(ctx, src, dst, root, copyOpts.CopyGraphOptions)
}

// Standard copy
stdCopyOpts := oras.CopyOptions{
CopyGraphOptions: copyOpts.CopyGraphOptions,
}
if opts.Platform.Platform != nil {
stdCopyOpts.WithTargetPlatform(opts.Platform.Platform)
}
return oras.Copy(ctx, src, opts.From.Reference, dst, opts.To.Reference, stdCopyOpts)
}

desc, err = executeCopy(extendedCopyGraphOptions)

// Mount failed due to permissions, retry without mounting
if err != nil && extendedCopyGraphOptions.MountFrom != nil {
var copyErr *oras.CopyError
if errors.As(err, &copyErr) && copyErr.Op == "Mount" {
var errResp *errcode.ErrorResponse
if errors.As(copyErr.Err, &errResp) {
if errResp.StatusCode == http.StatusUnauthorized ||
errResp.StatusCode == http.StatusForbidden {
// Disable mounting and retry
extendedCopyGraphOptions.MountFrom = nil
desc, err = executeCopy(extendedCopyGraphOptions)
}
}
desc, err = oras.Copy(ctx, src, opts.From.Reference, dst, opts.To.Reference, copyOptions)
}
}
// leave the CopyError to oerrors.Modifier for prefix processing
Expand Down
116 changes: 116 additions & 0 deletions cmd/oras/root/cp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"net/url"
"os"
"strings"
"sync/atomic"
"testing"

"github.com/opencontainers/go-digest"
Expand Down Expand Up @@ -204,6 +205,121 @@ func Test_doCopy_mounted(t *testing.T) {
}
}

func Test_doCopy_mountFallback(t *testing.T) {
// Test that copy falls back to regular upload when mount fails with 401/403
repoFromFallback := "from-fallback"
repoToFallback := "to-fallback"

var uploadSessionStarted atomic.Bool
var blobUploaded atomic.Bool

// test server that returns 401 for mount but allows regular upload
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.URL.Path == fmt.Sprintf("/v2/%s/manifests/%s", repoFromFallback, manifestDigest) &&
r.Method == http.MethodHead:
w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest)
w.Header().Set("Content-Length", fmt.Sprint(len(manifestContent)))
w.WriteHeader(http.StatusOK)
case r.URL.Path == fmt.Sprintf("/v2/%s/manifests/%s", repoFromFallback, manifestDigest) &&
r.Method == http.MethodGet:
w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest)
w.Header().Set("Content-Length", fmt.Sprint(len(manifestContent)))
_, _ = w.Write(manifestContent)
case r.URL.Path == fmt.Sprintf("/v2/%s/blobs/%s", repoFromFallback, configDigest) &&
r.Method == http.MethodGet:
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprint(len(configContent)))
_, _ = w.Write(configContent)
case r.URL.Path == fmt.Sprintf("/v2/%s/manifests/%s", repoToFallback, manifestDigest) &&
r.Method == http.MethodHead:
w.WriteHeader(http.StatusNotFound)
case r.URL.Path == fmt.Sprintf("/v2/%s/blobs/%s", repoToFallback, configDigest) &&
r.Method == http.MethodHead:
// check if blob exists before upload - not found initially
if blobUploaded.Load() {
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", fmt.Sprint(len(configContent)))
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusNotFound)
}
case r.URL.Path == fmt.Sprintf("/v2/%s/blobs/uploads/", repoToFallback) &&
r.URL.Query().Get("mount") != "" &&
r.Method == http.MethodPost:
// Return 401 Unauthorized for mount requests to simulate permission denial
w.WriteHeader(http.StatusUnauthorized)
case r.URL.Path == fmt.Sprintf("/v2/%s/blobs/uploads/", repoToFallback) &&
r.URL.Query().Get("mount") == "" &&
r.Method == http.MethodPost:
// Regular blob upload initiation - allow this
uploadSessionStarted.Store(true)
w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/session123", repoToFallback))
w.WriteHeader(http.StatusAccepted)
case strings.HasPrefix(r.URL.Path, fmt.Sprintf("/v2/%s/blobs/uploads/", repoToFallback)) &&
r.Method == http.MethodPut:
// Blob upload completion
blobUploaded.Store(true)
w.Header().Set("Docker-Content-Digest", configDigest)
w.WriteHeader(http.StatusCreated)
case r.URL.Path == fmt.Sprintf("/v2/%s/manifests/%s", repoToFallback, manifestDigest) &&
r.Method == http.MethodPut:
w.WriteHeader(http.StatusCreated)
case r.URL.Path == fmt.Sprintf("/v2/%s/manifests/%s", repoToFallback, manifestDigest) &&
r.Method == http.MethodGet:
w.Header().Set("Content-Type", ocispec.MediaTypeImageManifest)
w.Header().Set("Content-Length", fmt.Sprint(len(manifestContent)))
_, _ = w.Write(manifestContent)
default:
w.WriteHeader(http.StatusNotAcceptable)
}
}))
defer ts.Close()

uri, _ := url.Parse(ts.URL)
testHost := "localhost:" + uri.Port()

// prepare
pty, child, err := testutils.NewPty()
if err != nil {
t.Fatal(err)
}
defer func() { _ = child.Close() }()
var opts copyOptions
opts.TTY = child
opts.From.Reference = manifestDigest
// mocked repositories
from, err := remote.NewRepository(fmt.Sprintf("%s/%s", testHost, repoFromFallback))
if err != nil {
t.Fatal(err)
}
from.PlainHTTP = true
to, err := remote.NewRepository(fmt.Sprintf("%s/%s", testHost, repoToFallback))
if err != nil {
t.Fatal(err)
}
to.PlainHTTP = true
handler := status.NewTTYCopyHandler(opts.TTY)

// test
_, err = doCopy(context.Background(), handler, from, to, &opts)
if err != nil {
t.Fatal(err)
}

// validate that regular upload was used (fallback succeeded)
if !uploadSessionStarted.Load() {
t.Error("expected regular upload session to be started after mount failure")
}
if !blobUploaded.Load() {
t.Error("expected blob to be uploaded via regular upload after mount failure")
}
// validate output shows "Copied" instead of "Mounted"
if err = testutils.MatchPty(pty, child, "Copied", configMediaType, "100.00%", configDigest); err != nil {
t.Fatal(err)
}
}

func Test_prepareCopyOption_nonIndex(t *testing.T) {
ctx := context.Background()
root := ocispec.Descriptor{
Expand Down
Loading