Skip to content

Commit

Permalink
RMDA/odp: Consolidate umem_odp initialization
Browse files Browse the repository at this point in the history
This is done in two different places, consolidate all the post-allocation
initialization into a single function.

Link: https://lore.kernel.org/r/[email protected]
Signed-off-by: Leon Romanovsky <[email protected]>
Signed-off-by: Jason Gunthorpe <[email protected]>
  • Loading branch information
jgunthorpe committed Aug 21, 2019
1 parent fd7dbf0 commit 22d79c9
Showing 1 changed file with 86 additions and 114 deletions.
200 changes: 86 additions & 114 deletions drivers/infiniband/core/umem_odp.c
Original file line number Diff line number Diff line change
Expand Up @@ -171,23 +171,6 @@ static const struct mmu_notifier_ops ib_umem_notifiers = {
.invalidate_range_end = ib_umem_notifier_invalidate_range_end,
};

static void add_umem_to_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;

down_write(&per_mm->umem_rwsem);
/*
* Note that the representation of the intervals in the interval tree
* considers the ending point as contained in the interval, while the
* function ib_umem_end returns the first address which is not
* contained in the umem.
*/
umem_odp->interval_tree.start = ib_umem_start(umem_odp);
umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;
interval_tree_insert(&umem_odp->interval_tree, &per_mm->umem_tree);
up_write(&per_mm->umem_rwsem);
}

static void remove_umem_from_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext_per_mm *per_mm = umem_odp->per_mm;
Expand Down Expand Up @@ -237,33 +220,23 @@ static struct ib_ucontext_per_mm *alloc_per_mm(struct ib_ucontext *ctx,
return ERR_PTR(ret);
}

static int get_per_mm(struct ib_umem_odp *umem_odp)
static struct ib_ucontext_per_mm *get_per_mm(struct ib_umem_odp *umem_odp)
{
struct ib_ucontext *ctx = umem_odp->umem.context;
struct ib_ucontext_per_mm *per_mm;

lockdep_assert_held(&ctx->per_mm_list_lock);

/*
* Generally speaking we expect only one or two per_mm in this list,
* so no reason to optimize this search today.
*/
mutex_lock(&ctx->per_mm_list_lock);
list_for_each_entry(per_mm, &ctx->per_mm_list, ucontext_list) {
if (per_mm->mm == umem_odp->umem.owning_mm)
goto found;
}

per_mm = alloc_per_mm(ctx, umem_odp->umem.owning_mm);
if (IS_ERR(per_mm)) {
mutex_unlock(&ctx->per_mm_list_lock);
return PTR_ERR(per_mm);
return per_mm;
}

found:
umem_odp->per_mm = per_mm;
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);

return 0;
return alloc_per_mm(ctx, umem_odp->umem.owning_mm);
}

static void free_per_mm(struct rcu_head *rcu)
Expand Down Expand Up @@ -304,79 +277,114 @@ static void put_per_mm(struct ib_umem_odp *umem_odp)
mmu_notifier_call_srcu(&per_mm->rcu, free_per_mm);
}

static inline int ib_init_umem_odp(struct ib_umem_odp *umem_odp,
struct ib_ucontext_per_mm *per_mm)
{
struct ib_ucontext *ctx = umem_odp->umem.context;
int ret;

umem_odp->umem.is_odp = 1;
if (!umem_odp->is_implicit_odp) {
size_t pages = ib_umem_odp_num_pages(umem_odp);

if (!pages)
return -EINVAL;

/*
* Note that the representation of the intervals in the
* interval tree considers the ending point as contained in
* the interval, while the function ib_umem_end returns the
* first address which is not contained in the umem.
*/
umem_odp->interval_tree.start = ib_umem_start(umem_odp);
umem_odp->interval_tree.last = ib_umem_end(umem_odp) - 1;

umem_odp->page_list = vzalloc(
array_size(sizeof(*umem_odp->page_list), pages));
if (!umem_odp->page_list)
return -ENOMEM;

umem_odp->dma_list =
vzalloc(array_size(sizeof(*umem_odp->dma_list), pages));
if (!umem_odp->dma_list) {
ret = -ENOMEM;
goto out_page_list;
}
}

mutex_lock(&ctx->per_mm_list_lock);
if (!per_mm) {
per_mm = get_per_mm(umem_odp);
if (IS_ERR(per_mm)) {
ret = PTR_ERR(per_mm);
goto out_unlock;
}
}
umem_odp->per_mm = per_mm;
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);

mutex_init(&umem_odp->umem_mutex);
init_completion(&umem_odp->notifier_completion);

if (!umem_odp->is_implicit_odp) {
down_write(&per_mm->umem_rwsem);
interval_tree_insert(&umem_odp->interval_tree,
&per_mm->umem_tree);
up_write(&per_mm->umem_rwsem);
}

return 0;

out_unlock:
mutex_unlock(&ctx->per_mm_list_lock);
vfree(umem_odp->dma_list);
out_page_list:
vfree(umem_odp->page_list);
return ret;
}

struct ib_umem_odp *ib_alloc_odp_umem(struct ib_umem_odp *root,
unsigned long addr, size_t size)
{
struct ib_ucontext_per_mm *per_mm = root->per_mm;
struct ib_ucontext *ctx = per_mm->context;
/*
* Caller must ensure that root cannot be freed during the call to
* ib_alloc_odp_umem.
*/
struct ib_umem_odp *odp_data;
struct ib_umem *umem;
int pages = size >> PAGE_SHIFT;
int ret;

if (!size)
return ERR_PTR(-EINVAL);

odp_data = kzalloc(sizeof(*odp_data), GFP_KERNEL);
if (!odp_data)
return ERR_PTR(-ENOMEM);
umem = &odp_data->umem;
umem->context = ctx;
umem->context = root->umem.context;
umem->length = size;
umem->address = addr;
odp_data->page_shift = PAGE_SHIFT;
umem->writable = root->umem.writable;
umem->is_odp = 1;
odp_data->per_mm = per_mm;
umem->owning_mm = per_mm->mm;
mmgrab(umem->owning_mm);

mutex_init(&odp_data->umem_mutex);
init_completion(&odp_data->notifier_completion);

odp_data->page_list =
vzalloc(array_size(pages, sizeof(*odp_data->page_list)));
if (!odp_data->page_list) {
ret = -ENOMEM;
goto out_odp_data;
}
umem->owning_mm = root->umem.owning_mm;
odp_data->page_shift = PAGE_SHIFT;

odp_data->dma_list =
vzalloc(array_size(pages, sizeof(*odp_data->dma_list)));
if (!odp_data->dma_list) {
ret = -ENOMEM;
goto out_page_list;
ret = ib_init_umem_odp(odp_data, root->per_mm);
if (ret) {
kfree(odp_data);
return ERR_PTR(ret);
}

/*
* Caller must ensure that the umem_odp that the per_mm came from
* cannot be freed during the call to ib_alloc_odp_umem.
*/
mutex_lock(&ctx->per_mm_list_lock);
per_mm->odp_mrs_count++;
mutex_unlock(&ctx->per_mm_list_lock);
add_umem_to_per_mm(odp_data);
mmgrab(umem->owning_mm);

return odp_data;

out_page_list:
vfree(odp_data->page_list);
out_odp_data:
mmdrop(umem->owning_mm);
kfree(odp_data);
return ERR_PTR(ret);
}
EXPORT_SYMBOL(ib_alloc_odp_umem);

int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
{
struct ib_umem *umem = &umem_odp->umem;
/*
* NOTE: This must called in a process context where umem->owning_mm
* == current->mm
*/
struct mm_struct *mm = umem->owning_mm;
int ret_val;
struct mm_struct *mm = umem_odp->umem.owning_mm;

if (umem_odp->umem.address == 0 && umem_odp->umem.length == 0)
umem_odp->is_implicit_odp = 1;
Expand All @@ -397,43 +405,7 @@ int ib_umem_odp_get(struct ib_umem_odp *umem_odp, int access)
up_read(&mm->mmap_sem);
}

mutex_init(&umem_odp->umem_mutex);

init_completion(&umem_odp->notifier_completion);

if (!umem_odp->is_implicit_odp) {
if (!ib_umem_odp_num_pages(umem_odp))
return -EINVAL;

umem_odp->page_list =
vzalloc(array_size(sizeof(*umem_odp->page_list),
ib_umem_odp_num_pages(umem_odp)));
if (!umem_odp->page_list)
return -ENOMEM;

umem_odp->dma_list =
vzalloc(array_size(sizeof(*umem_odp->dma_list),
ib_umem_odp_num_pages(umem_odp)));
if (!umem_odp->dma_list) {
ret_val = -ENOMEM;
goto out_page_list;
}
}

ret_val = get_per_mm(umem_odp);
if (ret_val)
goto out_dma_list;

if (!umem_odp->is_implicit_odp)
add_umem_to_per_mm(umem_odp);

return 0;

out_dma_list:
vfree(umem_odp->dma_list);
out_page_list:
vfree(umem_odp->page_list);
return ret_val;
return ib_init_umem_odp(umem_odp, NULL);
}

void ib_umem_odp_release(struct ib_umem_odp *umem_odp)
Expand Down

0 comments on commit 22d79c9

Please sign in to comment.