diff --git a/src/evaluators/state/PHAL_LoadStateField.hpp b/src/evaluators/state/PHAL_LoadStateField.hpp index ec332a178f..6ef1db1b9b 100644 --- a/src/evaluators/state/PHAL_LoadStateField.hpp +++ b/src/evaluators/state/PHAL_LoadStateField.hpp @@ -37,13 +37,11 @@ class LoadStateFieldBase : public PHX::EvaluatorWithBaseImpl, using ExecutionSpace = typename PHX::Device::execution_space; - PHX::MDField data; + PHX::MDField field; std::string fieldName; std::string stateName; MDFieldMemoizer memoizer; - - MDFieldVectorRight dataVec; }; template @@ -65,13 +63,11 @@ class LoadStateField : public PHX::EvaluatorWithBaseImpl, using ExecutionSpace = typename PHX::Device::execution_space; - PHX::MDField data; + PHX::MDField field; std::string fieldName; std::string stateName; MDFieldMemoizer memoizer; - - MDFieldVectorRight dataVec; }; // Shortcut names diff --git a/src/evaluators/state/PHAL_LoadStateField_Def.hpp b/src/evaluators/state/PHAL_LoadStateField_Def.hpp index 654b76750e..3ff7a3f2e8 100644 --- a/src/evaluators/state/PHAL_LoadStateField_Def.hpp +++ b/src/evaluators/state/PHAL_LoadStateField_Def.hpp @@ -22,10 +22,9 @@ LoadStateFieldBase(const Teuchos::ParameterList& p) fieldName = p.get("Field Name"); stateName = p.get("State Name"); - PHX::MDField f(fieldName, p.get >("State Field Layout") ); - data = f; + field = PHX::MDField(fieldName, p.get >("State Field Layout") ); - this->addEvaluatedField(data); + this->addEvaluatedField(field); this->setName("LoadStateField("+stateName+")"+PHX::print()); } @@ -34,7 +33,7 @@ template void LoadStateFieldBase::postRegistrationSetup(typename Traits::SetupData d, PHX::FieldManager& fm) { - this->utils.setFieldData(data,fm); + this->utils.setFieldData(field,fm); d.fill_field_dependencies(this->dependentFields(),this->evaluatedFields()); if (d.memoizer_active()) memoizer.enable_memoizer(); @@ -51,15 +50,13 @@ void LoadStateFieldBase::evaluateFields(typename Trai // whomever changed the data. const auto& stateToLoad = (*workset.stateArrayPtr)[stateName]; auto stateData = stateToLoad.dev(); - const int stateToLoad_size = stateToLoad.size(); - MDFieldVectorRight g(data); - dataVec = g; + ALBANY_ASSERT (stateData.rank() <= 3, "Current implementation supports only views with rank up to 3. If larger rank is needed modify code below"); Kokkos::parallel_for(this->getName(), - Kokkos::RangePolicy(0,data.size()), - KOKKOS_CLASS_LAMBDA(const int i) { - dataVec[i] = (i < stateToLoad_size) ? stateData(i) : 0.0; + Kokkos::MDRangePolicy>({0,0,0},{stateData.extent(0),stateData.extent(1),stateData.extent(2)}), + KOKKOS_CLASS_LAMBDA(const int i, const int j, const int k) { + field.access(i,j,k) = stateData.access(i,j,k); //works also when rank is less than 3 }); } @@ -70,10 +67,10 @@ LoadStateField(const Teuchos::ParameterList& p) fieldName = p.get("Field Name"); stateName = p.get("State Name"); - PHX::MDField f(fieldName, p.get >("State Field Layout") ); - data = f; - this->addEvaluatedField(data); + field = PHX::MDField(fieldName, p.get >("State Field Layout") ); + + this->addEvaluatedField(field); this->setName("Load State Field"+PHX::print()); } @@ -82,7 +79,7 @@ template void LoadStateField::postRegistrationSetup(typename Traits::SetupData d, PHX::FieldManager& fm) { - this->utils.setFieldData(data,fm); + this->utils.setFieldData(field,fm); d.fill_field_dependencies(this->dependentFields(),this->evaluatedFields()); if (d.memoizer_active()) memoizer.enable_memoizer(); @@ -99,15 +96,13 @@ void LoadStateField::evaluateFields(typename Traits::EvalData wor // whomever changed the data. const auto& stateToLoad = (*workset.stateArrayPtr)[stateName]; auto stateData = stateToLoad.dev(); - const int stateToLoad_size = stateToLoad.size(); - MDFieldVectorRight g(data); - dataVec = g; + ALBANY_ASSERT (stateData.rank() <= 3, "Current implementation supports only views with rank up to 3. If larger rank is needed modify code below"); Kokkos::parallel_for(this->getName(), - Kokkos::RangePolicy(0,data.size()), - KOKKOS_CLASS_LAMBDA(const int i) { - dataVec[i] = (i < stateToLoad_size) ? stateData(i) : 0.0; + Kokkos::MDRangePolicy>({0,0,0},{stateData.extent(0),stateData.extent(1),stateData.extent(2)}), + KOKKOS_CLASS_LAMBDA(const int i, const int j, const int k) { + field.access(i,j,k) = stateData.access(i,j,k); //works also when rank is less than 3 }); }