合并并发控制PR,并将并发数设置为参数

This commit is contained in:
刘铭 2024-07-13 19:08:40 +08:00
parent eb0e3cc274
commit 6fff3fe46a
7 changed files with 10 additions and 7 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -19,6 +19,7 @@ func main() {
username := flag.String("username", "", "username")
password := flag.String("password", "", "password")
tags := flag.Bool("tags", false, "获取tag列表")
syncCount := flag.Int("sync", 3, "并发下载数量")
var registry string
flag.StringVar(&registry, "registry", "registry-1.docker.io", "指定镜像仓库")
@ -67,7 +68,7 @@ func main() {
client.SetClient(http.DefaultClient)
}
var err = client.Install(registry, pkg, tag, *arch, *printInfo, *tags, *username, *password)
var err = client.Install(*syncCount, registry, pkg, tag, *arch, *printInfo, *tags, *username, *password)
if err != nil {
logrus.Fatalln("下载发生错误", err)
}

View File

@ -79,11 +79,13 @@ type TagList struct {
Tags []string
}
type SyncSignal struct{}
func (m *Client) SetClient(c *http.Client) {
m.c = c
}
func (m *Client) Install(_registry, d, tag string, arch string, printInfo bool, onlyGetTag bool, username string, password string) (err error) {
func (m *Client) Install(syncCount int, _registry, d, tag string, arch string, printInfo bool, onlyGetTag bool, username string, password string) (err error) {
var authUrl = _authUrl
var regService = _regService
resp, err := m.c.Get(fmt.Sprintf("https://%s/v2/", _registry))
@ -232,7 +234,7 @@ func (m *Client) Install(_registry, d, tag string, arch string, printInfo bool,
resp.Body.Close()
logrus.Infof("获得Manifest信息共%d层需要下载", len(info.Layers))
err = m.download(_registry, d, tag, info.Config.Digest, authHeader, info.Layers)
err = m.download(syncCount, _registry, d, tag, info.Config.Digest, authHeader, info.Layers)
if err != nil {
goto response
@ -273,7 +275,7 @@ func (m *Client) getTokenWithBasicAuth(url, service, repository, username, passw
return "", err
}
func (m *Client) download(_registry, d, tag string, digest digest.Digest, authHeader http.Header, layers []Layer) (err error) {
func (m *Client) download(syncCount int, _registry, d, tag string, digest digest.Digest, authHeader http.Header, layers []Layer) (err error) {
var tmpDir = fmt.Sprintf("tmp_%s_%s", d, tag)
err = os.MkdirAll(tmpDir, 0777)
if err == nil {
@ -308,9 +310,9 @@ func (m *Client) download(_registry, d, tag string, digest digest.Digest, authHe
parentid := ""
var fakeLayerId string
var downloadStatus = make(map[int]bool)
var notifyChan = make(chan int, 3)
var notifyChan = make(chan int, 1)
//限制并发下载数为3
var ch = make(chan struct{}, 3)
var ch = make(chan SyncSignal, syncCount)
for n, layer := range layers {
namer := sha256.New()
namer.Write([]byte(parentid + "\n" + layer.Digest + "\n"))
@ -332,7 +334,7 @@ func (m *Client) download(_registry, d, tag string, digest digest.Digest, authHe
copyedHeader[k] = v
}
go func(fakeLayerId string, layer Layer, n int, notifyChan chan int, layerInfo *LayerInfo, tmpDir string, _registry string, d string, authHeader http.Header) {
ch <- struct{}{}
ch <- SyncSignal{}
er := m.downloadLayer(fakeLayerId, &layer, layerInfo, tmpDir, _registry, d, authHeader)
if er != nil {
logrus.Errorf("下载第%d/%d层失败:%s", n+1, len(layers), err)