diff --git a/drivers/vdpa/Kconfig b/drivers/vdpa/Kconfig index 4271c408103e36..d7d32b65610218 100644 --- a/drivers/vdpa/Kconfig +++ b/drivers/vdpa/Kconfig @@ -30,9 +30,7 @@ config IFCVF be called ifcvf. config MLX5_VDPA - bool "MLX5 VDPA support library for ConnectX devices" - depends on MLX5_CORE - default n + bool help Support library for Mellanox VDPA drivers. Provides code that is common for all types of VDPA drivers. The following drivers are planned: @@ -40,7 +38,8 @@ config MLX5_VDPA config MLX5_VDPA_NET tristate "vDPA driver for ConnectX devices" - depends on MLX5_VDPA + select MLX5_VDPA + depends on MLX5_CORE default n help VDPA network driver for ConnectX6 and newer. Provides offloading diff --git a/drivers/vdpa/mlx5/net/mlx5_vnet.c b/drivers/vdpa/mlx5/net/mlx5_vnet.c index 70676a6d16914a..74264e59069517 100644 --- a/drivers/vdpa/mlx5/net/mlx5_vnet.c +++ b/drivers/vdpa/mlx5/net/mlx5_vnet.c @@ -1133,15 +1133,17 @@ static void suspend_vq(struct mlx5_vdpa_net *ndev, struct mlx5_vdpa_virtqueue *m if (!mvq->initialized) return; - if (query_virtqueue(ndev, mvq, &attr)) { - mlx5_vdpa_warn(&ndev->mvdev, "failed to query virtqueue\n"); - return; - } if (mvq->fw_state != MLX5_VIRTIO_NET_Q_OBJECT_STATE_RDY) return; if (modify_virtqueue(ndev, mvq, MLX5_VIRTIO_NET_Q_OBJECT_STATE_SUSPEND)) mlx5_vdpa_warn(&ndev->mvdev, "modify to suspend failed\n"); + + if (query_virtqueue(ndev, mvq, &attr)) { + mlx5_vdpa_warn(&ndev->mvdev, "failed to query virtqueue\n"); + return; + } + mvq->avail_idx = attr.available_index; } static void suspend_vqs(struct mlx5_vdpa_net *ndev) @@ -1411,8 +1413,14 @@ static int mlx5_vdpa_get_vq_state(struct vdpa_device *vdev, u16 idx, struct vdpa struct mlx5_virtq_attr attr; int err; - if (!mvq->initialized) - return -EAGAIN; + /* If the virtq object was destroyed, use the value saved at + * the last minute of suspend_vq. This caters for userspace + * that cares about emulating the index after vq is stopped. + */ + if (!mvq->initialized) { + state->avail_index = mvq->avail_idx; + return 0; + } err = query_virtqueue(ndev, mvq, &attr); if (err) { diff --git a/drivers/vhost/vdpa.c b/drivers/vhost/vdpa.c index 796fe979f997f0..62a9bb0efc5589 100644 --- a/drivers/vhost/vdpa.c +++ b/drivers/vhost/vdpa.c @@ -565,6 +565,9 @@ static int vhost_vdpa_map(struct vhost_vdpa *v, perm_to_iommu_flags(perm)); } + if (r) + vhost_iotlb_del_range(dev->iotlb, iova, iova + size - 1); + return r; } @@ -592,21 +595,19 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v, struct vhost_dev *dev = &v->vdev; struct vhost_iotlb *iotlb = dev->iotlb; struct page **page_list; - unsigned long list_size = PAGE_SIZE / sizeof(struct page *); + struct vm_area_struct **vmas; unsigned int gup_flags = FOLL_LONGTERM; - unsigned long npages, cur_base, map_pfn, last_pfn = 0; - unsigned long locked, lock_limit, pinned, i; + unsigned long map_pfn, last_pfn = 0; + unsigned long npages, lock_limit; + unsigned long i, nmap = 0; u64 iova = msg->iova; + long pinned; int ret = 0; if (vhost_iotlb_itree_first(iotlb, msg->iova, msg->iova + msg->size - 1)) return -EEXIST; - page_list = (struct page **) __get_free_page(GFP_KERNEL); - if (!page_list) - return -ENOMEM; - if (msg->perm & VHOST_ACCESS_WO) gup_flags |= FOLL_WRITE; @@ -614,61 +615,86 @@ static int vhost_vdpa_process_iotlb_update(struct vhost_vdpa *v, if (!npages) return -EINVAL; + page_list = kvmalloc_array(npages, sizeof(struct page *), GFP_KERNEL); + vmas = kvmalloc_array(npages, sizeof(struct vm_area_struct *), + GFP_KERNEL); + if (!page_list || !vmas) { + ret = -ENOMEM; + goto free; + } + mmap_read_lock(dev->mm); - locked = atomic64_add_return(npages, &dev->mm->pinned_vm); lock_limit = rlimit(RLIMIT_MEMLOCK) >> PAGE_SHIFT; - - if (locked > lock_limit) { + if (npages + atomic64_read(&dev->mm->pinned_vm) > lock_limit) { ret = -ENOMEM; - goto out; + goto unlock; } - cur_base = msg->uaddr & PAGE_MASK; - iova &= PAGE_MASK; + pinned = pin_user_pages(msg->uaddr & PAGE_MASK, npages, gup_flags, + page_list, vmas); + if (npages != pinned) { + if (pinned < 0) { + ret = pinned; + } else { + unpin_user_pages(page_list, pinned); + ret = -ENOMEM; + } + goto unlock; + } - while (npages) { - pinned = min_t(unsigned long, npages, list_size); - ret = pin_user_pages(cur_base, pinned, - gup_flags, page_list, NULL); - if (ret != pinned) - goto out; - - if (!last_pfn) - map_pfn = page_to_pfn(page_list[0]); - - for (i = 0; i < ret; i++) { - unsigned long this_pfn = page_to_pfn(page_list[i]); - u64 csize; - - if (last_pfn && (this_pfn != last_pfn + 1)) { - /* Pin a contiguous chunk of memory */ - csize = (last_pfn - map_pfn + 1) << PAGE_SHIFT; - if (vhost_vdpa_map(v, iova, csize, - map_pfn << PAGE_SHIFT, - msg->perm)) - goto out; - map_pfn = this_pfn; - iova += csize; + iova &= PAGE_MASK; + map_pfn = page_to_pfn(page_list[0]); + + /* One more iteration to avoid extra vdpa_map() call out of loop. */ + for (i = 0; i <= npages; i++) { + unsigned long this_pfn; + u64 csize; + + /* The last chunk may have no valid PFN next to it */ + this_pfn = i < npages ? page_to_pfn(page_list[i]) : -1UL; + + if (last_pfn && (this_pfn == -1UL || + this_pfn != last_pfn + 1)) { + /* Pin a contiguous chunk of memory */ + csize = last_pfn - map_pfn + 1; + ret = vhost_vdpa_map(v, iova, csize << PAGE_SHIFT, + map_pfn << PAGE_SHIFT, + msg->perm); + if (ret) { + /* + * Unpin the rest chunks of memory on the + * flight with no corresponding vdpa_map() + * calls having been made yet. On the other + * hand, vdpa_unmap() in the failure path + * is in charge of accounting the number of + * pinned pages for its own. + * This asymmetrical pattern of accounting + * is for efficiency to pin all pages at + * once, while there is no other callsite + * of vdpa_map() than here above. + */ + unpin_user_pages(&page_list[nmap], + npages - nmap); + goto out; } - - last_pfn = this_pfn; + atomic64_add(csize, &dev->mm->pinned_vm); + nmap += csize; + iova += csize << PAGE_SHIFT; + map_pfn = this_pfn; } - - cur_base += ret << PAGE_SHIFT; - npages -= ret; + last_pfn = this_pfn; } - /* Pin the rest chunk */ - ret = vhost_vdpa_map(v, iova, (last_pfn - map_pfn + 1) << PAGE_SHIFT, - map_pfn << PAGE_SHIFT, msg->perm); + WARN_ON(nmap != npages); out: - if (ret) { + if (ret) vhost_vdpa_unmap(v, msg->iova, msg->size); - atomic64_sub(npages, &dev->mm->pinned_vm); - } +unlock: mmap_read_unlock(dev->mm); - free_page((unsigned long)page_list); +free: + kvfree(vmas); + kvfree(page_list); return ret; } @@ -810,6 +836,7 @@ static int vhost_vdpa_open(struct inode *inode, struct file *filep) err_init_iotlb: vhost_dev_cleanup(&v->vdev); + kfree(vqs); err: atomic_dec(&v->opened); return r; diff --git a/drivers/vhost/vhost.c b/drivers/vhost/vhost.c index b45519ca66a7e9..9ad45e1d27f0f0 100644 --- a/drivers/vhost/vhost.c +++ b/drivers/vhost/vhost.c @@ -1290,6 +1290,11 @@ static bool vq_access_ok(struct vhost_virtqueue *vq, unsigned int num, vring_used_t __user *used) { + /* If an IOTLB device is present, the vring addresses are + * GIOVAs. Access validation occurs at prefetch time. */ + if (vq->iotlb) + return true; + return access_ok(desc, vhost_get_desc_size(vq, num)) && access_ok(avail, vhost_get_avail_size(vq, num)) && access_ok(used, vhost_get_used_size(vq, num)); @@ -1365,6 +1370,20 @@ bool vhost_log_access_ok(struct vhost_dev *dev) } EXPORT_SYMBOL_GPL(vhost_log_access_ok); +static bool vq_log_used_access_ok(struct vhost_virtqueue *vq, + void __user *log_base, + bool log_used, + u64 log_addr) +{ + /* If an IOTLB device is present, log_addr is a GIOVA that + * will never be logged by log_used(). */ + if (vq->iotlb) + return true; + + return !log_used || log_access_ok(log_base, log_addr, + vhost_get_used_size(vq, vq->num)); +} + /* Verify access for write logging. */ /* Caller should have vq mutex and device mutex */ static bool vq_log_access_ok(struct vhost_virtqueue *vq, @@ -1372,8 +1391,7 @@ static bool vq_log_access_ok(struct vhost_virtqueue *vq, { return vq_memory_access_ok(log_base, vq->umem, vhost_has_feature(vq, VHOST_F_LOG_ALL)) && - (!vq->log_used || log_access_ok(log_base, vq->log_addr, - vhost_get_used_size(vq, vq->num))); + vq_log_used_access_ok(vq, log_base, vq->log_used, vq->log_addr); } /* Can we start vq? */ @@ -1383,10 +1401,6 @@ bool vhost_vq_access_ok(struct vhost_virtqueue *vq) if (!vq_log_access_ok(vq, vq->log_base)) return false; - /* Access validation occurs at prefetch time with IOTLB */ - if (vq->iotlb) - return true; - return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used); } EXPORT_SYMBOL_GPL(vhost_vq_access_ok); @@ -1516,10 +1530,9 @@ static long vhost_vring_set_addr(struct vhost_dev *d, return -EINVAL; /* Also validate log access for used ring if enabled. */ - if ((a.flags & (0x1 << VHOST_VRING_F_LOG)) && - !log_access_ok(vq->log_base, a.log_guest_addr, - sizeof *vq->used + - vq->num * sizeof *vq->used->ring)) + if (!vq_log_used_access_ok(vq, vq->log_base, + a.flags & (0x1 << VHOST_VRING_F_LOG), + a.log_guest_addr)) return -EINVAL; }