diff --git a/src/caffe/mkldnn_memory.cpp b/src/caffe/mkldnn_memory.cpp index b53cac492..13691c4bb 100644 --- a/src/caffe/mkldnn_memory.cpp +++ b/src/caffe/mkldnn_memory.cpp @@ -122,14 +122,13 @@ void MKLDNNMemoryDescriptor::create_reorder_to_prv(void* cpu_ptr CHECK(this->_usr_memory_pd); CHECK(this->_prv_memory_pd); CHECK(this->_reorder_usr2prv_pd); - if (this->_cpu_ptr == NULL) - this->_cpu_ptr = cpu_ptr; - else - CHECK_EQ(this->_cpu_ptr, cpu_ptr); - if(this->_usr_memory == NULL) + + if(this->_usr_memory == NULL || this->_cpu_ptr != cpu_ptr) this->_usr_memory.reset(new memory(*this->_usr_memory_pd, cpu_ptr)); - if(this->_reorder_usr2prv.aprimitive == NULL) + if(this->_reorder_usr2prv.aprimitive == NULL || this->_cpu_ptr != cpu_ptr) this->_reorder_usr2prv.reset(new reorder(*this->_reorder_usr2prv_pd, *this->_usr_memory, *this->get_prv_memory())); + + this->_cpu_ptr = cpu_ptr; } template @@ -157,16 +156,13 @@ void MKLDNNMemoryDescriptor::create_reorder_from_prv(void* cpu_p CHECK(this->_usr_memory_pd); CHECK(this->_prv_memory_pd); CHECK(this->_reorder_prv2usr_pd); - if (this->_cpu_ptr == NULL) - this->_cpu_ptr = cpu_ptr; - else - CHECK_EQ(this->_cpu_ptr, cpu_ptr); - if(this->_usr_memory == NULL) + if(this->_usr_memory == NULL || this->_cpu_ptr != cpu_ptr) this->_usr_memory.reset(new memory(*this->_usr_memory_pd, cpu_ptr)); - if(this->_reorder_prv2usr.aprimitive == NULL) { + if(this->_reorder_prv2usr.aprimitive == NULL || this->_cpu_ptr != cpu_ptr) { CHECK(this->aprimitive()); this->_reorder_prv2usr.aprimitive.reset(new reorder(*this->_reorder_prv2usr_pd, *this->aprimitive(), *this->_usr_memory)); } + this->_cpu_ptr = cpu_ptr; } template